Mercurial > hg > pycsalgos
comparison pyCSalgos/OMP/omp_QR.py @ 2:735a0e24575c
Organized folders: added tests, apps, matlab, docs folders. Added __init__.py files
author | nikcleju |
---|---|
date | Fri, 21 Oct 2011 13:53:49 +0000 |
parents | |
children | e1da5140c9a5 |
comparison
equal
deleted
inserted
replaced
1:2a2abf5092f8 | 2:735a0e24575c |
---|---|
1 import numpy as np | |
2 import scipy.linalg | |
3 import time | |
4 import math | |
5 | |
6 | |
7 #function [s, err_mse, iter_time]=greed_omp_qr(x,A,m,varargin) | |
def _qr_append_atom(P, Q, R, k, atom, m):
    """Append dictionary column ``atom`` as column ``k`` of the QR factors.

    The column is obtained as P(e_atom), orthogonalised against the k columns
    already stored in Q[:, :k] (Gram-Schmidt), normalised and written to
    Q[:, k]. The projection coefficients go to R[:k, k] and the norm of the
    orthogonalised vector to R[k, k]. Q and R are modified in place.

    Returns the new normalised column q.
    """
    mask = np.zeros(m)
    mask[atom] = 1
    new_element = P(mask)
    # For k == 0 the slices are empty, so qP is empty and nothing is removed.
    qP = np.dot(Q[:, :k].T, new_element)
    q = new_element - np.dot(Q[:, :k], qP)
    nq = np.linalg.norm(q)
    q = q / nq
    R[:k, k] = qP
    R[k, k] = nq
    Q[:, k] = q
    return q


def greed_omp_qr(x, A, m, opts=None):
    """Orthogonal Matching Pursuit (OMP) based on an incremental QR factorisation.

    Translated to Python on 19.10.2011 by Nic from the original Matlab code
    ``greed_omp_qr`` by Thomas Blumensath.

    In each iteration the atom (column of P) with the largest absolute inner
    product with the current residual is selected. The selected atoms are
    orthogonalised, the QR factors Q and R are updated, and the residual is
    reduced by its component along the new orthogonal direction. When the
    algorithm stops, the non-zero entries of s are obtained by back
    substitution with R.

    Parameters
    ----------
    x : 1-D numpy array of length n
        Observation vector to be decomposed.
    A : n x m matrix, or callable z -> P z
        If callable, the ``P_trans`` option must be a callable z -> P^T z.
        If A is a matrix, any ``P_trans`` option is ignored.
    m : int
        Length of s.
    opts : dict, optional
        Options, all optional (unknown keys are ignored; the dict passed in
        is not modified):

          key         values                            default
          ----------  --------------------------------  ----------
          stopCrit    'M', 'corr', 'mse', 'mse_change'  'M'
          stopTol     number                            ceil(n/4)
          P_trans     callable                          -
          maxIter     number                            n
          verbose     bool                              False
          start_val   vector of length m                zeros
          nargout     1, 2 or 3                         3

        Stopping criteria:
          'M'           extract exactly M = stopTol elements;
          'corr'        stop when max |P^T residual| < stopTol;
          'mse'         stop when the mean squared error of the residual
                        is below stopTol;
          'mse_change'  stop when the decrease of the mse, divided by
                        x'x/n, falls below stopTol (from iteration 2 on).
        The algorithm also stops when the mse drops below 1e-14 or after
        maxIter iterations.
        ``start_val`` allows starting from a partial solution: its non-zero
        positions are used as the initial support.

    Returns
    -------
    s                         if nargout == 1
    (s, err_mse)              if nargout == 2
    (s, err_mse, iter_time)   if nargout == 3 (default)

    s is the solution vector. err_mse holds the mse of the residual after
    each iteration, and iter_time the elapsed seconds after each iteration.
    If x is not 1-D, or nargout is not 1, 2 or 3, a message is printed and
    None is returned.

    Raises
    ------
    TypeError
        If stopTol or maxIter is not numeric, or P_trans is not callable.
        Also raised if A is callable but no callable P_trans was given.
    ValueError
        If start_val does not have m elements.

    Copyright (c) 2007 Thomas Blumensath, The University of Edinburgh
    (thomas.blumensath@ed.ac.uk). This file is part of sparsity Version 0.1,
    created April 2007. Part of this toolbox was developed with the support
    of EPSRC Grant D000246/1. Please read COPYRIGHT.m for terms and
    conditions.
    """
    ###########################################################################
    # Default values and initialisation
    ###########################################################################
    if x.ndim != 1:
        print('x must be a vector.')
        return
    n = x.size

    # Private copy: the caller's dict (or the shared default) is never mutated.
    opts = {} if opts is None else dict(opts)

    sigsize = np.vdot(x, x) / n
    initial_given = False
    err_mse = np.array([])
    iter_time = np.array([])
    STOPCRIT = 'M'
    STOPTOL = int(math.ceil(n / 4.0))
    MAXITER = n
    verbose = False
    s_initial = np.zeros(m)
    Pt = None

    ###########################################################################
    # Output variables (emulates Matlab's nargout)
    ###########################################################################
    nargout = opts.get('nargout', 3)
    if nargout == 0:
        print('Please assign output variable.')
        return
    if nargout not in (1, 2, 3):
        print('Too many output arguments specified')
        return
    comp_err = nargout >= 2
    comp_time = nargout == 3

    ###########################################################################
    # Look through options (unknown options are silently ignored)
    ###########################################################################
    if 'stopCrit' in opts:
        STOPCRIT = opts['stopCrit']
    if 'stopTol' in opts:
        if not hasattr(opts['stopTol'], '__int__'):  # check if numeric
            raise TypeError('stopTol must be number. Exiting.')
        STOPTOL = opts['stopTol']
    if 'P_trans' in opts:
        if not hasattr(opts['P_trans'], '__call__'):  # check if callable
            raise TypeError('P_trans must be function _handle. Exiting.')
        Pt = opts['P_trans']
    if 'maxIter' in opts:
        if not hasattr(opts['maxIter'], '__int__'):  # check if numeric
            raise TypeError('maxIter must be a number. Exiting.')
        MAXITER = opts['maxIter']
    if 'verbose' in opts:
        verbose = opts['verbose']
    if 'start_val' in opts:
        if np.size(opts['start_val']) != m:
            raise ValueError('start_val must be a vector of length m. Exiting.')
        s_initial = np.asarray(opts['start_val'])
        initial_given = True

    if verbose:
        print('Initialising...')

    # Upper bound on the number of main-loop iterations. It sizes the
    # err_mse/iter_time buffers, which are indexed at least once.
    maxM = max(1, int(math.ceil(STOPTOL if STOPCRIT == 'M' else MAXITER)))
    if comp_err:
        err_mse = np.zeros(maxM)
    if comp_time:
        iter_time = np.zeros(maxM)

    ###########################################################################
    # Make P and Pt functions
    ###########################################################################
    if hasattr(A, '__call__'):
        if not hasattr(Pt, '__call__'):
            raise TypeError('If P is a function handle, Pt also needs to be a function handle.')
        P = A
    else:
        P = lambda z: np.dot(A, z)
        Pt = lambda z: np.dot(A.T, z)

    ###########################################################################
    # Random check to see if the dictionary is normalised
    ###########################################################################
    mask = np.zeros(m)
    mask[int(math.floor(np.random.rand() * m))] = 1
    if abs(1 - np.linalg.norm(P(mask))) > 1e-3:
        print('Dictionary appears not to have unit norm columns.')

    ###########################################################################
    # Initial support and memory allocation
    ###########################################################################
    IN = np.nonzero(s_initial)[0].tolist() if initial_given else []
    # Leave room for the initial atoms as well as the new ones.
    ncols = maxM + len(IN)
    try:
        Q = np.zeros((n, ncols))
        R = np.zeros((ncols, ncols))
    except (MemoryError, ValueError):
        print('Variable size is too large. Please try greed_omp_chol algorithm or reduce MAXITER.')
        raise

    ###########################################################################
    # Do we start from zero or not?
    ###########################################################################
    z = []
    for col, atom in enumerate(IN):
        q = _qr_append_atom(P, Q, R, col, atom, m)
        z.append(np.vdot(q, x))
    k = len(IN)  # zero-based index of the next free column of Q / R
    # Float (or complex) copy so the final s[IN] assignment is not truncated.
    s = np.array(s_initial, dtype=np.result_type(s_initial, float))
    if IN:
        Residual = x - np.dot(Q[:, :k], np.array(z))
        oldERR = np.vdot(Residual, Residual) / n
    else:
        Residual = x.copy()
        oldERR = sigsize

    ###########################################################################
    # Main algorithm
    ###########################################################################
    if verbose:
        print('Main iterations...')
    tic = time.time()
    last_report = 0.0  # elapsed seconds at the last progress message
    DR = Pt(Residual)
    done = False
    iteration = 1  # one-based, as in the Matlab original

    while not done:
        # Select the new element (already selected atoms are excluded).
        DR[IN] = 0
        I = int(np.abs(DR).argmax())
        IN.append(I)

        q = _qr_append_atom(P, Q, R, k, I, m)
        z.append(np.vdot(q, x))

        # New residual
        Residual = Residual - q * z[k]
        DR = Pt(Residual)

        ERR = np.vdot(Residual, Residual) / n
        elapsed = time.time() - tic
        if comp_err:
            err_mse[iteration - 1] = ERR
        if comp_time:
            iter_time[iteration - 1] = elapsed

        # Progress messages at most every 10 seconds (Matlab: toc - t > 10).
        report = verbose and (elapsed - last_report) > 10.0

        #######################################################################
        # Are we done yet?
        #######################################################################
        if STOPCRIT == 'M':
            if iteration >= STOPTOL:
                done = True
            elif report:
                print('Iteration %d. --- %d iterations to go' % (iteration, STOPTOL - iteration))
                last_report = elapsed
        elif STOPCRIT == 'mse':
            if ERR < STOPTOL:
                done = True
            elif report:
                print('Iteration %d. --- %g mse' % (iteration, ERR))
                last_report = elapsed
        elif STOPCRIT == 'mse_change' and iteration >= 2:
            # oldERR is the mse of the previous iteration (== err_mse[iteration-2]).
            change = (oldERR - ERR) / sigsize
            if change < STOPTOL:
                done = True
            elif report:
                print('Iteration %d. --- %g mse change' % (iteration, change))
                last_report = elapsed
        elif STOPCRIT == 'corr':
            maxcorr = np.abs(DR).max()
            if maxcorr < STOPTOL:
                done = True
            elif report:
                print('Iteration %d. --- %g corr' % (iteration, maxcorr))
                last_report = elapsed

        # Also stop if the residual gets too small or maxIter is reached.
        if (comp_err or iteration > 1) and ERR < 1e-14:
            done = True
            if verbose:
                print('Stopping. Exact signal representation found!')

        if iteration >= MAXITER:
            done = True
            if verbose:
                print('Stopping. Maximum number of iterations reached!')

        if not done:
            iteration += 1
            oldERR = ERR

        k += 1

    ###########################################################################
    # Solve for s by back-substitution (R is upper triangular)
    ###########################################################################
    s[IN] = scipy.linalg.solve_triangular(R[:k, :k], np.array(z))

    ###########################################################################
    # Only return as many elements as iterations
    ###########################################################################
    if comp_err:
        err_mse = err_mse[:iteration]
    if comp_time:
        iter_time = iter_time[:iteration]
    if verbose:
        print('Done')

    if nargout == 1:
        return s
    if nargout == 2:
        return s, err_mse
    return s, err_mse, iter_time
768 | |
769 # Change history | |
770 # | |
771 # 8 February: the algorithm no longer stops if the dictionary is not normalised. | |
772 | |
773 # End of greed_omp_qr() function | |
774 #-------------------------------- | |
775 | |
776 | |
def omp_qr(x, dict, D, natom, tolerance):
    """Recover x using a QR implementation of OMP.

    Parameters
    ----------
    x : measurements, either a 1-D array of length msize or an (msize, 1)
        column vector
    dict : (msize, dictsize) dictionary. Its columns are assumed to have
        unit norm: the diagonal of R is computed as sqrt(1 - ||w||^2).
    D : Gramian of the dictionary (dict.T @ dict), indexed as D[i, j]
    natom : maximum number of iterations (atoms selected)
    tolerance : relative error tolerance; iterations stop once the residual
        energy is <= tolerance * ||x||^2

    Returns
    -------
    x_hat : (dictsize, 1) array, estimate of x
    gamma : list of the indices where x_hat is non-zero, in selection order

    For more information, see http://media.aau.dk/null_space_pursuits/2011/10/efficient-omp.html
    """
    # Flatten x so that both 1-D and column-vector inputs work; a 1-D x used
    # to break the final assignment into x_hat as soon as k > 1.
    x = np.asarray(x).ravel()
    msize, dictsize = dict.shape
    normr2 = np.vdot(x, x)         # residual energy ||r||^2
    normtol2 = tolerance * normr2
    R = np.zeros((natom, natom))
    Q = np.zeros((msize, natom))
    gamma = []
    x_hat = np.zeros((dictsize, 1))

    # Initial projections of x on all atoms; `projections` tracks the
    # projections of the current residual.
    origprojections = np.dot(x, dict)
    projections = origprojections.copy()

    k = 0
    while (normr2 > normtol2) and (k < natom):
        # Find the index of the maximum-magnitude projection.
        newgam = np.argmax(np.abs(projections ** 2))
        gamma.append(newgam)
        atom = dict[:, newgam]

        # Update the QR factorisation: the new column of R solves
        # R[:k,:k]^T w = D[gamma[:k], newgam]. For k == 0 it is empty,
        # giving R[0,0] = 1 and Q[:,0] = atom.
        if k == 0:
            w = np.zeros(0)
        else:
            w = scipy.linalg.solve_triangular(R[0:k, 0:k], D[gamma[0:k], newgam], trans=1)
        R[0:k, k] = w
        R[k, k] = np.sqrt(1 - np.vdot(w, w))
        Q[:, k] = (atom - np.dot(Q[:, 0:k], np.dot(Q[:, 0:k].T, atom))) / R[k, k]

        # Rank-one updates of the projections and residual energy:
        # x^T q q^T dict and x^T q q^T x.
        xq = np.dot(x, Q[:, k])
        projections -= xq * np.dot(Q[:, k], dict)
        normr2 -= xq * xq

        k += 1

    if k == 0:
        # Nothing selected (e.g. x == 0): the estimate is all zeros.
        return x_hat, gamma

    # Build the solution: solve R^T w = dict[:, gamma]^T x, then R c = w.
    tempR = R[0:k, 0:k]
    w = scipy.linalg.solve_triangular(tempR, origprojections[gamma], trans=1)
    x_hat[gamma, 0] = scipy.linalg.solve_triangular(tempR, w)

    return x_hat, gamma
828 QtempQtempT += QkQkT | |
829 # update projections | |
830 projections -= np.dot(xTQkQkT,dict) | |
831 # update residual energy | |
832 normr2 -= np.dot(xTQkQkT,x) | |
833 | |
834 k += 1 | |
835 | |
836 # build solution | |
837 tempR = R[0:k,0:k] | |
838 w = scipy.linalg.solve_triangular(tempR,origprojectionsT[gamma[0:k]],trans=1) | |
839 x_hat = np.zeros((dictsize,1)) | |
840 x_hat[gamma[0:k]] = scipy.linalg.solve_triangular(tempR,w) | |
841 | |
842 return x_hat, gamma |