comparison pyCSalgos/OMP/omp_QR.py @ 2:735a0e24575c

Organized folders: added tests, apps, matlab, docs folders. Added __init__.py files
author nikcleju
date Fri, 21 Oct 2011 13:53:49 +0000
parents
children e1da5140c9a5
comparison
equal deleted inserted replaced
1:2a2abf5092f8 2:735a0e24575c
1 import numpy as np
2 import scipy.linalg
3 import time
4 import math
5
6
def _qr_append_column(Q, R, k, new_element):
    """Append new_element as column k of the thin QR factorisation (Q, R).

    One classical Gram-Schmidt step: new_element is orthogonalised against
    Q[:, :k], normalised and stored in Q[:, k]; the projection coefficients
    go to R[:k, k] and the norm to R[k, k]. Q and R are modified in place.

    Returns the new unit-norm column q.
    """
    Qk = Q[:, :k]
    qP = np.dot(Qk.T, new_element)
    q = new_element - np.dot(Qk, qP)
    nq = np.linalg.norm(q)
    q = q / nq
    R[:k, k] = qP
    R[k, k] = nq
    Q[:, k] = q
    return q


def greed_omp_qr(x, A, m, opts=None):
    """Orthogonal Matching Pursuit based on an incremental QR factorisation.

    Python translation (Nic, 19.10.2011) of the Matlab function
    ``[s, err_mse, iter_time] = greed_omp_qr(x, P, m, varargin)`` by
    Thomas Blumensath.

    In each iteration the atom (column of P) with the largest absolute inner
    product with the current residual is added to the support. x is then
    orthogonally projected onto the selected atoms through a QR
    factorisation that is extended by one column per iteration. The non-zero
    entries of s are finally obtained by back-substitution.

    Parameters
    ----------
    x : 1-D ndarray
        Observation vector to be decomposed (length n).
    A : (n, m) ndarray or callable
        Dictionary P. If callable, ``A(z)`` applies P and ``opts['P_trans']``
        must be a callable applying its transpose.
    m : int
        Length of the solution vector s.
    opts : dict, optional
        Options. Unknown keys are ignored and the caller's dict is not
        modified.

        ``'stopCrit'``
            One of ``'M'`` (default), ``'corr'``, ``'mse'`` or
            ``'mse_change'``.
        ``'stopTol'``
            Value for the stopping criterion (default ``ceil(n/4)``).
        ``'P_trans'``
            Callable transpose of P (required if A is callable).
        ``'maxIter'``
            Maximum number of iterations (default n).
        ``'verbose'``
            Print progress messages (default False).
        ``'start_val'``
            Vector of length m. Its non-zero entries form the initial
            support, and their values are re-estimated by the final
            least-squares solve.
        ``'nargout'``
            Number of values returned: 1, 2 or 3 (default 3).

    Stopping criteria
    -----------------
    M          : extract exactly M = stopTol elements.
    corr       : stop when the maximum correlation between the residual and
                 the atoms is below stopTol.
    mse        : stop when the mean squared error of the residual is below
                 stopTol.
    mse_change : stop when the decrease in mse, relative to x'x/n, falls
                 below stopTol.
    Independently of the criterion, the algorithm stops when the residual
    mse drops below 1e-14 or after maxIter iterations.

    Returns
    -------
    s : ndarray
        Solution vector of length m.
    err_mse : ndarray
        Residual mse after each iteration (returned if nargout >= 2).
    iter_time : ndarray
        Elapsed seconds after each iteration (returned if nargout == 3).

    If x is not 1-D, or nargout is not 1, 2 or 3, a message is printed and
    None is returned. A warning is printed if a randomly chosen column of P
    does not have unit norm.

    Raises
    ------
    TypeError
        On non-numeric stopTol/maxIter or non-callable P_trans, and if A is
        callable but no callable P_trans is given.
    ValueError
        If start_val does not have m elements.

    See also: greed_omp_chol, greed_omp_cg, greed_omp_cgp, greed_omp_pinv,
    greed_omp_linsolve, greed_gp, greed_nomp.

    Copyright (c) 2007 Thomas Blumensath, The University of Edinburgh
    (thomas.blumensath@ed.ac.uk). Part of sparsity Version 0.1, created
    April 2007. Part of this toolbox was developed with the support of
    EPSRC Grant D000246/1. Please read COPYRIGHT.m for terms and conditions.
    """
    if x.ndim != 1:
        print('x must be a vector.')
        return
    n = x.size

    # Private copy: never modify the caller's dict, and keep accepting the
    # historical default of an empty list.
    opts = dict(opts) if opts else {}

    # ------------------------------------------------------------------
    # Number of outputs (emulates Matlab's nargout)
    # ------------------------------------------------------------------
    nargout = opts.get('nargout', 3)
    if nargout == 0:
        print('Please assign output variable.')
        return
    if nargout not in (1, 2, 3):
        print('Too many output arguments specified')
        return
    comp_err = nargout >= 2
    comp_time = nargout == 3

    # ------------------------------------------------------------------
    # Options
    # ------------------------------------------------------------------
    STOPCRIT = opts.get('stopCrit', 'M')
    STOPTOL = int(math.ceil(n / 4.0))
    if 'stopTol' in opts:
        if not hasattr(opts['stopTol'], '__int__'):  # numeric check
            raise TypeError('stopTol must be number. Exiting.')
        STOPTOL = opts['stopTol']
    Pt = None
    if 'P_trans' in opts:
        if not callable(opts['P_trans']):
            raise TypeError('P_trans must be function _handle. Exiting.')
        Pt = opts['P_trans']
    MAXITER = n
    if 'maxIter' in opts:
        if not hasattr(opts['maxIter'], '__int__'):  # numeric check
            raise TypeError('maxIter must be a number. Exiting.')
        MAXITER = opts['maxIter']
    # TODO: should check that verbose is a boolean
    verbose = opts.get('verbose', False)
    s_initial = np.zeros(m)
    initial_given = False
    if 'start_val' in opts:
        # TODO: should check that start_val is numeric
        if opts['start_val'].size != m:
            raise ValueError('start_val must be a vector of length m. Exiting.')
        s_initial = opts['start_val']
        initial_given = True

    if verbose:
        print('Initialising...')

    # float() avoids Python 2 integer division for integer-valued x
    sigsize = np.vdot(x, x) / float(n)

    # Size of the per-iteration buffers. ceil() because the loop runs until
    # the (possibly non-integer) bound is reached; at least one iteration is
    # always performed.
    if STOPCRIT == 'M':
        maxM = int(math.ceil(STOPTOL))
    else:
        maxM = int(math.ceil(MAXITER))
    maxM = max(maxM, 1)
    err_mse = np.zeros(maxM) if comp_err else np.array([])
    iter_time = np.zeros(maxM) if comp_time else np.array([])

    # ------------------------------------------------------------------
    # Forward (P) and transpose (Pt) operators
    # ------------------------------------------------------------------
    if callable(A):
        if not callable(Pt):
            raise TypeError('If P is a function handle, Pt also needs to be a function handle.')
        P = A
    else:
        # TODO: should check here that A is a matrix
        def P(z):
            return np.dot(A, z)

        def Pt(z):
            return np.dot(A.T, z)

    # Random check to see if the dictionary is normalised
    mask = np.zeros(m)
    mask[int(np.random.rand() * m)] = 1
    nP = np.linalg.norm(P(mask))
    if abs(1 - nP) > 1e-3:
        print('Dictionary appears not to have unit norm columns.')

    # ------------------------------------------------------------------
    # Allocate the QR factors. They need room for the initial support as
    # well as for every atom added by the main loop.
    # ------------------------------------------------------------------
    IN = [int(i) for i in np.nonzero(s_initial)[0]] if initial_given else []
    ncols = maxM + len(IN)
    try:
        Q = np.zeros((n, ncols))
        R = np.zeros((ncols, ncols))
    except (MemoryError, ValueError):
        print('Variable size is too large. Please try greed_omp_chol algorithm or reduce MAXITER.')
        raise

    # Float copy so the final least-squares values are not truncated when
    # start_val has an integer dtype.
    s = s_initial.astype(np.result_type(s_initial, float))

    # ------------------------------------------------------------------
    # Start from the given partial solution, if any
    # ------------------------------------------------------------------
    z = []
    for col, idx in enumerate(IN):
        mask = np.zeros(m)
        mask[idx] = 1
        q = _qr_append_column(Q, R, col, P(mask))
        z.append(np.vdot(q, x))
    k = len(IN)  # number of columns currently in the factorisation
    if k > 0:
        Residual = x - np.dot(Q[:, :k], z)
        oldERR = np.vdot(Residual, Residual) / float(n)
    else:
        Residual = x.copy()
        oldERR = sigsize

    # ------------------------------------------------------------------
    # Main algorithm
    # ------------------------------------------------------------------
    if verbose:
        print('Main iterations...')
    tic = time.time()
    last_report = tic
    DR = Pt(Residual)
    iteration = 1  # 1-based, as in the Matlab original
    done = False

    while not done:
        # Select the atom most correlated with the residual, excluding atoms
        # already in the support.
        DR[IN] = 0
        I = int(np.abs(DR).argmax())
        IN.append(I)

        # Extract the new element and extend the QR factorisation
        mask = np.zeros(m)
        mask[I] = 1
        q = _qr_append_column(Q, R, k, P(mask))
        z.append(np.vdot(q, x))
        k += 1

        # New residual. q is orthogonal to earlier columns, so only its own
        # component has to be removed.
        Residual = Residual - q * z[-1]
        DR = Pt(Residual)

        ERR = np.vdot(Residual, Residual) / float(n)
        now = time.time()
        if comp_err:
            err_mse[iteration - 1] = ERR
        if comp_time:
            iter_time[iteration - 1] = now - tic

        # Are we done yet?
        progress = None
        if STOPCRIT == 'M':
            done = iteration >= STOPTOL
            progress = '%d iterations to go' % (STOPTOL - iteration)
        elif STOPCRIT == 'mse':
            done = ERR < STOPTOL
            progress = '%g mse' % ERR
        elif STOPCRIT == 'mse_change' and iteration >= 2:
            change = (oldERR - ERR) / sigsize
            done = change < STOPTOL
            progress = '%g mse change' % change
        elif STOPCRIT == 'corr':
            max_corr = np.abs(DR).max()
            done = max_corr < STOPTOL
            progress = '%g corr' % max_corr

        # Progress report at most every 10 seconds (Matlab: toc - t > 10)
        if not done and verbose and progress is not None and now - last_report > 10:
            print('Iteration %d. --- %s' % (iteration, progress))
            last_report = now

        # Also stop if the residual is numerically zero or maxIter is reached
        if ERR < 1e-14:
            done = True
            if verbose:
                print('Stopping. Exact signal representation found!')
        if iteration >= MAXITER:
            done = True
            if verbose:
                print('Stopping. Maximum number of iterations reached!')

        if not done:
            iteration += 1
            oldERR = ERR

    # ------------------------------------------------------------------
    # Solve R * s(IN) = z by back-substitution (R is upper triangular)
    # ------------------------------------------------------------------
    s[IN] = scipy.linalg.solve_triangular(R[:k, :k], np.array(z))

    # Only return as many entries as iterations performed
    if comp_err:
        err_mse = err_mse[:iteration]
    if comp_time:
        iter_time = iter_time[:iteration]
    if verbose:
        print('Done')

    if nargout == 1:
        return s
    if nargout == 2:
        return s, err_mse
    return s, err_mse, iter_time
768
769 # Change history
770 #
771 # 8 February: the algorithm no longer stops if the dictionary is not normalised.
772
773 # End of greed_omp_qr() function
774 #--------------------------------
775
776
def omp_qr(x, dict, D, natom, tolerance):
    """Recover a sparse coefficient vector for x using a QR implementation of OMP.

    Parameters
    ----------
    x : ndarray
        Measurements, as a 1-D vector of length msize or an (msize, 1)
        column vector.
    dict : (msize, dictsize) ndarray
        Dictionary. The QR updates assume unit-norm columns: R[0, 0] is set
        to 1 and R[k, k] = sqrt(1 - ||w||^2). The parameter name shadows the
        builtin ``dict``; it is kept for callers that pass it by keyword.
    D : (dictsize, dictsize) ndarray
        Gramian of the dictionary, i.e. ``dict.T @ dict``.
    natom : int
        Maximum number of atoms to select (iterations).
    tolerance : float
        Relative error tolerance. Iterations stop once the residual energy
        is no larger than ``tolerance * ||x||^2``.

    Returns
    -------
    x_hat : (dictsize, 1) ndarray
        Estimated coefficient vector. It is non-zero only at the indices in
        gamma, where it holds the least-squares fit of x on those atoms.
    gamma : list
        Indices of the selected atoms, in order of selection.

    Selection also stops early if the next atom is numerically in the span
    of the atoms already selected.

    For more information, see http://media.aau.dk/null_space_pursuits/2011/10/efficient-omp.html
    """
    # Work on a flat view so 1-D vectors and (msize, 1) columns both work.
    xv = np.asarray(x).ravel()
    msize, dictsize = dict.shape
    normr2 = np.vdot(xv, xv)
    normtol2 = tolerance * normr2
    R = np.zeros((natom, natom))
    Q = np.zeros((msize, natom))
    gamma = []

    # Initial projections <x, d_j> for every atom; kept for the final solve.
    origprojections = np.dot(xv, dict)
    projections = origprojections.copy()

    k = 0
    while normr2 > normtol2 and k < natom:
        # Index of the maximum-magnitude projection of the residual
        newgam = np.argmax(np.abs(projections))
        atom = dict[:, newgam]
        if k == 0:
            rkk = 1.0
            q = atom.copy()
        else:
            # Off-diagonal column of R from the Gramian: R^T w = D[gamma, newgam]
            w = scipy.linalg.solve_triangular(R[0:k, 0:k], D[gamma, newgam], trans=1)
            rkk2 = 1.0 - np.vdot(w, w)
            if rkk2 <= np.finfo(float).eps:
                # New atom is (numerically) in the span of the selected
                # ones; adding it would divide by ~0 and produce NaNs.
                break
            rkk = np.sqrt(rkk2)
            R[0:k, k] = w
            Qk = Q[:, 0:k]
            q = (atom - np.dot(Qk, np.dot(Qk.T, atom))) / rkk
        R[k, k] = rkk
        Q[:, k] = q
        gamma.append(newgam)

        # Rank-1 updates: x^T q q^T dict == (x.q) (q^T dict), so the dense
        # msize x msize projector never needs to be formed.
        xq = np.dot(xv, q)
        projections -= xq * np.dot(q, dict)
        normr2 -= xq ** 2

        k += 1

    x_hat = np.zeros((dictsize, 1))
    if k == 0:
        # Nothing selected (zero input or natom == 0): empty solution.
        return x_hat, gamma

    # Least-squares coefficients: solve R^T w = dict[:, gamma]^T x, then R c = w.
    tempR = R[0:k, 0:k]
    w = scipy.linalg.solve_triangular(tempR, origprojections[gamma], trans=1)
    x_hat[gamma, 0] = scipy.linalg.solve_triangular(tempR, w)

    return x_hat, gamma
833
834 k += 1
835
836 # build solution
837 tempR = R[0:k,0:k]
838 w = scipy.linalg.solve_triangular(tempR,origprojectionsT[gamma[0:k]],trans=1)
839 x_hat = np.zeros((dictsize,1))
840 x_hat[gamma[0:k]] = scipy.linalg.solve_triangular(tempR,w)
841
842 return x_hat, gamma