nikcleju@2
|
1 import numpy as np
|
nikcleju@2
|
2 import scipy.linalg
|
nikcleju@2
|
3 import time
|
nikcleju@2
|
4 import math
|
nikcleju@2
|
5
|
nikcleju@2
|
6
|
nikcleju@2
|
7 #function [s, err_mse, iter_time]=greed_omp_qr(x,A,m,varargin)
|
nikcleju@2
|
def _omp_qr_append_atom(P, m, index, Q, R, k, x):
    """Add dictionary atom ``index`` as column ``k`` of the QR factorisation.

    The atom is synthesised as ``P(e_index)`` and orthogonalised against the
    first ``k`` columns of ``Q`` with a single classical Gram-Schmidt pass.
    ``Q[:, k]`` receives the normalised orthogonal component, ``R[:k, k]`` the
    projection coefficients and ``R[k, k]`` the norm of the orthogonal part;
    ``Q`` and ``R`` are modified in place.

    Returns the coefficient of ``x`` on the new orthonormal vector ``Q[:, k]``.
    """
    mask = np.zeros(m)
    mask[index] = 1
    new_element = P(mask)

    # For k == 0 the slices are empty, so q is simply the atom itself.
    qP = np.dot(Q[:, :k].T, new_element)
    q = new_element - np.dot(Q[:, :k], qP)
    nq = np.linalg.norm(q)
    q = q / nq

    R[:k, k] = qP
    R[k, k] = nq
    Q[:, k] = q
    return np.vdot(q, x)


def greed_omp_qr(x, A, m, opts=None):
    """Orthogonal Matching Pursuit based on an incremental QR factorisation.

    Translated to Python by Nic on 19.10.2011 from the Matlab code
    ``greed_omp_qr`` by Thomas Blumensath (sparsity toolbox Version 0.1,
    created April 2007, Copyright (c) 2007 Thomas Blumensath, The University
    of Edinburgh, thomas.blumensath@ed.ac.uk; part of the toolbox was
    developed with the support of EPSRC Grant D000246/1 -- please read
    COPYRIGHT.m for terms and conditions).  Related Matlab routines:
    greed_omp_chol, greed_omp_cg, greed_omp_cgp, greed_omp_pinv,
    greed_omp_linsolve, greed_gp, greed_nomp.

    In each iteration the atom whose inner product with the current residual
    has the largest magnitude is added to the support, the QR factorisation
    of the selected atoms is extended by one column and the residual is
    updated.  The non-zero entries of ``s`` are finally obtained by
    back-substitution, i.e. by orthogonally projecting ``x`` onto the
    selected atoms.

    Parameters
    ----------
    x : 1-D array_like
        Observation vector of length n to be decomposed.
    A : 2-D array or callable
        Either an n x m matrix, or a callable implementing ``z -> A z``; in
        the latter case ``opts['P_trans']`` must implement ``r -> A' r``.
    m : int
        Length of the solution vector ``s``.
    opts : mapping, optional
        Options; unknown keys are ignored and the mapping is not modified.

        ``'stopCrit'``  ``'M'`` (extract exactly ``stopTol`` elements),
                        ``'corr'`` (stop when the largest correlation between
                        residual and atoms is below ``stopTol``), ``'mse'``
                        (stop when the residual mean squared error is below
                        ``stopTol``) or ``'mse_change'`` (stop when the change
                        in mse, divided by ``x'x/n``, is below ``stopTol``).
                        Default ``'M'``.
        ``'stopTol'``   value for the stopping criterion. Default ceil(n/4).
        ``'P_trans'``   callable adjoint of ``A``; required if ``A`` is callable.
        ``'maxIter'``   maximum number of iterations. Default n.
        ``'verbose'``   print progress information. Default False.
        ``'start_val'`` length-m vector; its non-zero entries form the initial
                        support (a partial solution to start from).
        ``'nargout'``   number of values to return (1, 2 or 3). Default 3.

    Returns
    -------
    s : ndarray
        Solution vector of length m.
    err_mse : ndarray, only if nargout >= 2
        Mean squared residual error after each iteration.
    iter_time : ndarray, only if nargout == 3
        Seconds elapsed since the start of the main loop at the end of each
        iteration.

    Returns None (after printing a message) if ``x`` is not a vector or if
    ``nargout`` is not 1, 2 or 3.

    Raises
    ------
    TypeError
        If ``stopTol``/``maxIter`` are not numeric, ``P_trans`` is not
        callable, or ``A`` is callable but ``P_trans`` is missing.
    ValueError
        If ``stopCrit`` is unknown or ``start_val`` does not have m entries.
    """
    # ----------------------------------------------------------------------
    # Default values and initialisation
    # ----------------------------------------------------------------------
    x = np.asarray(x)
    if x.ndim != 1:
        print('x must be a vector.')
        return None
    n = x.size

    # Work on a private copy so the caller's mapping is never mutated; the
    # historical default of an empty list is accepted as "no options".
    opts = {} if opts is None else dict(opts)

    # Emulates Matlab's nargout: decides which diagnostics are computed.
    nargout = opts.get('nargout', 3)
    if nargout == 0:
        print('Please assign output variable.')
        return None
    if nargout not in (1, 2, 3):
        print('Too many output arguments specified')
        return None
    comp_err = nargout >= 2
    comp_time = nargout == 3

    # ----------------------------------------------------------------------
    # Look through options
    # ----------------------------------------------------------------------
    stop_crit = opts.get('stopCrit', 'M')
    if stop_crit not in ('M', 'corr', 'mse', 'mse_change'):
        raise ValueError('stopCrit must be char string [M, corr, mse, mse_change]. Exiting.')

    stop_tol = int(math.ceil(n / 4.0))
    if 'stopTol' in opts:
        if not hasattr(opts['stopTol'], '__int__'):  # numeric check
            raise TypeError('stopTol must be number. Exiting.')
        stop_tol = opts['stopTol']

    Pt = None
    if 'P_trans' in opts:
        if not callable(opts['P_trans']):
            raise TypeError('P_trans must be function _handle. Exiting.')
        Pt = opts['P_trans']

    max_iter = n
    if 'maxIter' in opts:
        if not hasattr(opts['maxIter'], '__int__'):  # numeric check
            raise TypeError('maxIter must be a number. Exiting.')
        max_iter = opts['maxIter']

    verbose = bool(opts.get('verbose', False))

    s_initial = np.zeros(m)
    if 'start_val' in opts:
        start_val = np.asarray(opts['start_val']).ravel()
        if start_val.size != m:
            raise ValueError('start_val must be a vector of length m. Exiting.')
        s_initial = start_val

    # Reported only now, once the 'verbose' option has actually been read.
    if verbose:
        print('Initialising...')

    # Upper bound on the number of main-loop iterations; sizes the buffers.
    # Must be a positive int: numpy rejects float sizes and the loop always
    # runs at least once.
    bound = stop_tol if stop_crit == 'M' else max_iter
    max_m = max(1, int(math.ceil(bound)))

    err_mse = np.zeros(max_m) if comp_err else np.array([])
    iter_time = np.zeros(max_m) if comp_time else np.array([])

    # ----------------------------------------------------------------------
    # Make P and Pt functions
    # ----------------------------------------------------------------------
    if callable(A):
        if Pt is None:
            raise TypeError('If P is a function handle, Pt needs to be specified.')
        P = A
    else:
        # TODO: should check here if A is a matrix
        def P(z):
            return np.dot(A, z)

        def Pt(z):
            return np.dot(A.T, z)

    # ----------------------------------------------------------------------
    # Random check to see if dictionary is normalised
    # ----------------------------------------------------------------------
    mask = np.zeros(m)
    mask[int(np.random.rand() * m)] = 1  # int index: float indices are rejected
    if abs(1 - np.linalg.norm(P(mask))) > 1e-3:
        print('Dictionary appears not to have unit norm columns.')

    # ----------------------------------------------------------------------
    # Check if we have enough memory and initialise
    # ----------------------------------------------------------------------
    # Atoms from start_val occupy the first QR columns, so reserve room for
    # them on top of the main-loop iterations.
    support = np.nonzero(s_initial)[0].tolist()
    ncols = len(support) + max_m
    try:
        Q = np.zeros((n, ncols))
        R = np.zeros((ncols, ncols))
    except (MemoryError, ValueError):
        print('Variable size is too large. Please try greed_omp_chol algorithm or reduce MAXITER.')
        raise

    # ----------------------------------------------------------------------
    # Do we start from zero or not?
    # ----------------------------------------------------------------------
    sigsize = np.vdot(x, x) / float(n)
    s = np.array(s_initial, dtype=np.result_type(s_initial, np.float64))
    z = []
    for k, index in enumerate(support):
        z.append(_omp_qr_append_atom(P, m, index, Q, R, k, x))
    k = len(support)  # index of the next free QR column
    # With an empty initial support this is simply a float copy of x.
    residual = x - np.dot(Q[:, :k], z)
    old_err = None  # set at the end of each iteration, read from the 2nd on

    # ----------------------------------------------------------------------
    # Main algorithm
    # ----------------------------------------------------------------------
    if verbose:
        print('Main iterations...')
    start_time = time.time()
    last_report = 0.0  # seconds since start_time of the last progress print
    correlations = Pt(residual)
    done = False
    it = 1

    while not done:
        # Select the new element, excluding atoms already in the support.
        correlations[support] = 0
        best = int(np.abs(correlations).argmax())
        support.append(best)

        # Extend the QR factorisation and update the residual.
        z.append(_omp_qr_append_atom(P, m, best, Q, R, k, x))
        residual = residual - Q[:, k] * z[k]
        correlations = Pt(residual)
        k += 1

        err = np.vdot(residual, residual) / float(n)
        if comp_err:
            err_mse[it - 1] = err
        if comp_time:
            iter_time[it - 1] = time.time() - start_time

        # ------------------------------------------------------------------
        # Are we done yet?
        # ------------------------------------------------------------------
        progress = None
        if stop_crit == 'M':
            done = it >= stop_tol
            progress = '%d iterations to go' % (stop_tol - it)
        elif stop_crit == 'mse':
            done = err < stop_tol
            progress = '%g mse' % err
        elif stop_crit == 'mse_change' and it >= 2:
            change = (old_err - err) / sigsize
            done = change < stop_tol
            progress = '%g mse change' % change
        elif stop_crit == 'corr':
            max_corr = np.abs(correlations).max()
            done = max_corr < stop_tol
            progress = '%g corr' % max_corr

        # Progress is reported at most every 10 seconds (Matlab: toc-t>10).
        if not done and verbose and progress is not None:
            elapsed = time.time() - start_time
            if elapsed - last_report > 10:
                print('Iteration %d. --- %s' % (it, progress))
                last_report = elapsed

        # Also stop if residual gets too small or maxIter reached.  Without
        # comp_err the first iteration is exempt, as in the original code.
        if err < 1e-14 and (comp_err or it > 1):
            done = True
            if verbose:
                print('Stopping. Exact signal representation found!')

        if it >= max_iter:
            done = True
            if verbose:
                print('Stopping. Maximum number of iterations reached!')

        # If not done, take another round.
        if not done:
            it += 1
            old_err = err

    # ----------------------------------------------------------------------
    # Now we can solve for s by back-substitution (R is upper triangular)
    # ----------------------------------------------------------------------
    s[support] = scipy.linalg.solve_triangular(R[:k, :k], np.array(z))

    # ----------------------------------------------------------------------
    # Only return as many elements as iterations
    # ----------------------------------------------------------------------
    if comp_err:
        err_mse = err_mse[:it]
    if comp_time:
        iter_time = iter_time[:it]
    if verbose:
        print('Done')

    if nargout == 1:
        return s
    if nargout == 2:
        return s, err_mse
    return s, err_mse, iter_time
|
nikcleju@2
|
768
|
nikcleju@2
|
769 # Change history
|
nikcleju@2
|
770 #
|
nikcleju@2
|
771 # 8 February: the algorithm no longer stops if the dictionary is not normalised.
|
nikcleju@2
|
772
|
nikcleju@2
|
773 # End of greed_omp_qr() function
|
nikcleju@2
|
774 #--------------------------------
|
nikcleju@2
|
775
|
nikcleju@2
|
776
|
nikcleju@2
|
def omp_qr(x, dict, D, natom, tolerance):
    """Recover a sparse coefficient vector for x using the QR implementation of OMP.

    Orthogonal Matching Pursuit greedily selects one dictionary atom per
    iteration. An incremental QR factorisation of the selected atoms is kept,
    so the final least-squares solve reduces to two triangular solves.

    Parameters
    ----------
    x : ndarray, shape (msize,) or (msize, 1)
        Measurements.
    dict : ndarray, shape (msize, dictsize)
        Dictionary. The QR updates assume unit-norm columns: they use
        ``R[0, 0] = 1`` and ``R[k, k] = sqrt(1 - ||w||^2)``.
        The name shadows the builtin but is kept for keyword callers.
    D : ndarray, shape (dictsize, dictsize)
        Gramian of the dictionary, i.e. ``dict.T @ dict``.
    natom : int
        Maximum number of iterations (atoms to select).
    tolerance : float
        Relative error tolerance. Iteration stops once the residual energy
        is no longer greater than ``tolerance * ||x||^2``.

    Returns
    -------
    x_hat : ndarray, shape (dictsize, 1)
        Coefficient vector such that ``dict @ x_hat`` approximates x.
        Non-zero only at the indices in ``gamma``.
    gamma : list of int
        Indices of the selected atoms, in selection order.

    Iteration also stops early if the next atom is (numerically) in the span
    of the atoms already selected, because the QR update would otherwise
    divide by ~0 and produce NaNs.

    For more information, see http://media.aau.dk/null_space_pursuits/2011/10/efficient-omp.html
    """
    msize, dictsize = dict.shape
    # Work on a flat view so both (msize,) and (msize, 1) inputs are handled.
    xv = np.asarray(x).reshape(-1)
    normr2 = np.vdot(xv, xv)  # residual energy ||r||^2, initially ||x||^2
    normtol2 = tolerance * normr2
    R = np.zeros((natom, natom))
    Q = np.zeros((msize, natom))
    gamma = []
    x_hat = np.zeros((dictsize, 1))

    # Squared distance of a (unit-norm) atom to the span of the selected
    # atoms below which it is treated as linearly dependent.
    dependence_tol = 1e-12

    # Correlations of x with every atom. `projections` is then kept equal to
    # the correlations of the current residual with every atom.
    origprojections = np.dot(xv, dict)
    projections = origprojections.copy()

    k = 0
    while (normr2 > normtol2) and (k < natom):
        # Atom most correlated with the current residual.
        newgam = int(np.argmax(np.abs(projections)))
        atom = dict[:, newgam]
        if k == 0:
            R[0, 0] = 1
            q = atom.copy()
        else:
            # New column of R: solve R^T w = D[gamma, newgam].
            w = scipy.linalg.solve_triangular(R[0:k, 0:k], D[gamma, newgam], trans=1)
            rkk2 = 1 - np.vdot(w, w)
            if rkk2 <= dependence_tol:
                # The atom is (numerically) in the span of the selected atoms,
                # so the residual cannot be reduced any further.
                break
            R[0:k, k] = w
            R[k, k] = np.sqrt(rkk2)
            # Gram-Schmidt step against the previous orthonormal columns.
            # This equals (I - Q_k Q_k^T) atom without forming an m x m matrix.
            Qk = Q[:, 0:k]
            q = (atom - np.dot(Qk, np.dot(Qk.T, atom))) / R[k, k]
        gamma.append(newgam)
        Q[:, k] = q
        # Rank-one updates: r <- r - (x.q) q, so every correlation drops by
        # (x.q) * (q^T dict) and the residual energy drops by (x.q)^2.
        xq = np.dot(xv, q)
        projections -= xq * np.dot(q, dict)
        normr2 -= xq * xq
        k += 1

    if k == 0:
        # Nothing selected (e.g. x == 0): the solution is all zeros.
        return x_hat, gamma

    # Least squares on the selected atoms.
    # Since dict[:, gamma] = Q R, solve (R^T R) s = dict[:, gamma]^T x.
    tempR = R[0:k, 0:k]
    w = scipy.linalg.solve_triangular(tempR, origprojections[gamma], trans=1)
    x_hat[gamma, 0] = scipy.linalg.solve_triangular(tempR, w)

    return x_hat, gamma