import math
import time

import numpy as np
import scipy.linalg
# greed_omp_qr: Orthogonal Matching Pursuit algorithm based on QR
# factorisation.
#
# Nic: translated to Python on 19.10.2011. Original Matlab code by
# Thomas Blumensath.
#
# Copyright (c) 2007 Thomas Blumensath
# The University of Edinburgh
# Email: thomas.blumensath@ed.ac.uk
# Comments and bug reports welcome
#
# This file is part of sparsity Version 0.1
# Created: April 2007
#
# Part of this toolbox was developed with the support of EPSRC Grant
# D000246/1
#
# Please read COPYRIGHT.m for terms and conditions.

# Stopping criteria accepted by greed_omp_qr (same set as the Matlab original).
_OMP_QR_STOP_CRITERIA = ('M', 'corr', 'mse', 'mse_change')


def _omp_qr_options(opts, n, m):
    """Validate the options dict of greed_omp_qr and return its settings.

    Returns ``(stopcrit, stoptol, maxiter, verbose, s_initial, pt)``, where
    ``pt`` is the 'P_trans' callable or None, and ``s_initial`` is a fresh
    floating-point copy of 'start_val' (zeros(m) when not given).
    Unknown keys are ignored.
    """
    stopcrit = opts.get('stopCrit', 'M')
    if stopcrit not in _OMP_QR_STOP_CRITERIA:
        raise ValueError('stopCrit must be char string [M, corr, mse, mse_change]. Exiting.')

    stoptol = math.ceil(n / 4.0)
    if 'stopTol' in opts:
        if not hasattr(opts['stopTol'], '__int__'):  # crude "is numeric" test
            raise TypeError('stopTol must be number. Exiting.')
        stoptol = opts['stopTol']

    pt = None
    if 'P_trans' in opts:
        if not callable(opts['P_trans']):
            raise TypeError('P_trans must be function _handle. Exiting.')
        pt = opts['P_trans']

    maxiter = n
    if 'maxIter' in opts:
        if not hasattr(opts['maxIter'], '__int__'):  # crude "is numeric" test
            raise TypeError('maxIter must be a number. Exiting.')
        maxiter = opts['maxIter']

    # TODO: should check here that verbose is a boolean.
    verbose = opts.get('verbose', False)

    if 'start_val' in opts:
        start = np.asarray(opts['start_val'])
        if start.size != m:
            raise ValueError('start_val must be a vector of length m. Exiting.')
        # Promote to floating point so the back-substituted coefficients are
        # not truncated when an integer start vector is supplied.
        s_initial = np.array(start, dtype=np.result_type(start, float)).ravel()
    else:
        s_initial = np.zeros(m)

    return stopcrit, stoptol, maxiter, verbose, s_initial, pt


def _omp_qr_atom(P, m, index):
    """Return dictionary column ``index``, computed as P applied to a unit vector."""
    mask = np.zeros(m)
    mask[index] = 1
    return P(mask)


def _omp_qr_append_atom(Q, R, k, new_element):
    """Orthonormalise ``new_element`` against ``Q[:, :k]`` (one classical
    Gram-Schmidt pass), store it as column ``k`` of Q and R, and return the
    new unit vector.

    If ``new_element`` lies in span(Q[:, :k]) its norm is 0 and the division
    produces nan/inf, exactly as in the original implementation.
    """
    if k > 0:
        qP = np.dot(Q[:, :k].T, new_element)
        q = new_element - np.dot(Q[:, :k], qP)
        R[:k, k] = qP
    else:
        q = new_element
    nq = np.linalg.norm(q)
    q = q / nq
    R[k, k] = nq
    Q[:, k] = q
    return q


def greed_omp_qr(x, A, m, opts=None):
    """Orthogonal Matching Pursuit using an incrementally updated QR factorisation.

    In each iteration the atom (column of A) whose inner product with the
    current residual has the largest magnitude is added to the support.  The
    selected atoms are orthonormalised (A[:, support] = Q R), so the residual
    is always x minus its orthogonal projection onto the selected atoms.  The
    non-zero entries of s are finally obtained by back-substitution on R.
    Only real-valued data is supported (Q and R are real arrays).

    Parameters
    ----------
    x : numpy.ndarray
        Observation vector to be decomposed; must be 1-D (length n).
    A : numpy.ndarray or callable
        Either an n x m matrix, or a function computing A @ s.  In the latter
        case the 'P_trans' option (computing A^T @ r) is required.
    m : int
        Length of the solution vector s.
    opts : dict, optional
        Options; unknown keys are ignored and the dict is never modified.

        - 'stopCrit'  : 'M' (default), 'corr', 'mse' or 'mse_change'
            - M          : extract exactly M = stopTol new elements
            - corr       : stop when max |A^T residual| < stopTol
            - mse        : stop when the residual mse < stopTol
            - mse_change : stop when the decrease in mse, divided by x'x/n,
              is < stopTol (checked from the 2nd iteration on)
        - 'stopTol'   : value for the stopping criterion (default ceil(n/4))
        - 'P_trans'   : callable computing A^T @ r (required if A is callable)
        - 'maxIter'   : maximum number of iterations (default n)
        - 'verbose'   : print progress messages (default False)
        - 'start_val' : length-m vector; its non-zeros form an initial support
        - 'nargout'   : 1, 2 or 3 (default 3), number of returned values

    Returns
    -------
    s, or (s, err_mse), or (s, err_mse, iter_time), depending on 'nargout'.
        s         : solution vector of length m
        err_mse   : mse of the residual after each iteration
        iter_time : seconds elapsed since the main loop started, per iteration
    Returns None (after printing a message) if x is not 1-D or if 'nargout'
    is not 1, 2 or 3.

    Raises
    ------
    ValueError
        Unknown 'stopCrit', or 'start_val' not of length m.
    TypeError
        Non-numeric 'stopTol'/'maxIter', non-callable 'P_trans', or callable
        A without a callable 'P_trans'.
    """
    if x.ndim != 1:
        print('x must be a vector.')
        return None
    n = x.size

    # Work on a copy so the caller's dict is never modified.  A falsy value
    # (None, or the [] default of earlier versions) means "no options".
    opts = dict(opts) if opts else {}

    # Number of requested outputs (emulates Matlab's nargout).
    nargout = opts.get('nargout', 3)
    if nargout == 0:
        print('Please assign output variable.')
        return None
    if nargout not in (1, 2, 3):
        print('Too many output arguments specified')
        return None
    comp_err = nargout >= 2
    comp_time = nargout == 3

    stopcrit, stoptol, maxiter, verbose, s, Pt = _omp_qr_options(opts, n, m)
    if verbose:
        print('Initialising...')

    # Make P (synthesis) and Pt (analysis) functions.
    if callable(A):
        if not callable(Pt):
            raise TypeError('If P is a function handle, Pt also needs to be a function handle.')
        P = A
    else:
        # TODO: should check here that A is an n x m matrix.
        P = lambda z: np.dot(A, z)
        Pt = lambda z: np.dot(A.T, z)
    # NOTE: the toolbox's random "dictionary has unit-norm columns" check is
    # disabled (it was already commented out in the translation).

    # Upper bound on the number of iterations: the loop stops once
    # it >= maxIter (and, for 'M', once it >= stopTol), and runs at least once.
    if stopcrit == 'M':
        max_new = min(stoptol, maxiter)
    else:
        max_new = maxiter
    max_new = max(int(math.ceil(max_new)), 1)

    # Initial support from start_val (empty when starting from zero).
    support = [int(i) for i in np.nonzero(s)[0]]
    n_cols = len(support) + max_new
    try:
        Q = np.zeros((n, n_cols))
        R = np.zeros((n_cols, n_cols))
    except (MemoryError, ValueError):
        print('Variable size is too large. Please try greed_omp_chol algorithm or reduce MAXITER.')
        raise
    err_mse = np.zeros(max_new) if comp_err else np.array([])
    iter_time = np.zeros(max_new) if comp_time else np.array([])

    # Orthonormalise the initial support (if any) and project x onto it.
    # z[j] = <q_j, x> are the coefficients of x in the orthonormal basis Q.
    z = []
    for k, index in enumerate(support):
        q = _omp_qr_append_atom(Q, R, k, _omp_qr_atom(P, m, index))
        z.append(np.vdot(q, x))
    k = len(support)  # column of Q/R that the next selected atom occupies
    if support:
        residual = x - np.dot(Q[:, :k], np.asarray(z))
    else:
        residual = x.copy()

    sigsize = np.vdot(x, x) / n
    old_err = np.vdot(residual, residual) / n

    # Main algorithm.
    if verbose:
        print('Main iterations...')
    tic = time.time()
    t_last = tic  # time of the last progress message (Matlab: t = 0 after tic)
    DR = Pt(residual)
    it = 1
    while True:
        # Select the atom most correlated with the residual, excluding atoms
        # that are already in the support.
        if support:
            DR[support] = 0
        index = int(np.abs(DR).argmax())
        support.append(index)

        # Orthogonalise the new atom and update the QR factorisation.
        q = _omp_qr_append_atom(Q, R, k, _omp_qr_atom(P, m, index))
        z.append(np.vdot(q, x))
        k += 1

        # New residual: q is orthogonal to the previous basis vectors, so
        # removing its component keeps residual = x - Q Q^T x.
        residual = residual - q * z[-1]
        DR = Pt(residual)

        err = np.vdot(residual, residual) / n
        if comp_err:
            err_mse[it - 1] = err
        if comp_time:
            iter_time[it - 1] = time.time() - tic

        # Are we done yet?
        status = None
        if stopcrit == 'M':
            done = it >= stoptol
            status = 'Iteration %d. --- %d iterations to go' % (it, stoptol - it)
        elif stopcrit == 'mse':
            done = err < stoptol
            status = 'Iteration %d. --- %g mse' % (it, err)
        elif stopcrit == 'mse_change':
            # Needs a previous iteration to compare against.
            done = False
            if it >= 2:
                change = (old_err - err) / sigsize
                done = change < stoptol
                status = 'Iteration %d. --- %g mse change' % (it, change)
        else:  # 'corr'
            corr = np.abs(DR).max()
            done = corr < stoptol
            status = 'Iteration %d. --- %g corr' % (it, corr)

        # Progress report at most every 10 seconds (as in the Matlab code).
        if not done and verbose and status is not None and time.time() - t_last > 10.0:
            print(status)
            t_last = time.time()

        # Also stop if the residual gets too small or maxIter is reached.
        if err < 1e-14:
            done = True
            if verbose:
                print('Stopping. Exact signal representation found!')
        if it >= maxiter:
            done = True
            if verbose:
                print('Stopping. Maximum number of iterations reached!')

        if done:
            break
        it += 1
        old_err = err

    # A[:, support] = Q[:, :k] R[:k, :k] and z = Q^T x, so the coefficients on
    # the support solve the upper-triangular system R c = z.
    s[support] = scipy.linalg.solve_triangular(R[:k, :k], np.asarray(z))

    # Only return as many elements as iterations (it iterations were run).
    if comp_err:
        err_mse = err_mse[:it]
    if comp_time:
        iter_time = iter_time[:it]
    if verbose:
        print('Done')

    if nargout == 1:
        return s
    if nargout == 2:
        return s, err_mse
    return s, err_mse, iter_time
|
nikcleju@2
|
767
|
nikcleju@2
|
768 # Change history
|
nikcleju@2
|
769 #
|
nikcleju@2
|
770 # 8 February: the algorithm no longer stops if the dictionary is not normalised.
|
nikcleju@2
|
771
|
nikcleju@2
|
772 # End of greed_omp_qr() function
|
nikcleju@2
|
773 #--------------------------------
|
nikcleju@2
|
774
|
nikcleju@2
|
775
|
nikcleju@2
|
def omp_qr(x, dict, D, natom, tolerance):
    """Recover x using a QR implementation of Orthogonal Matching Pursuit.

    The QR factorisation of the selected atoms is grown one column per
    iteration.  The new column of the triangular factor ``R`` is obtained
    from the precomputed Gramian ``D`` rather than by re-orthogonalising
    from scratch.

    Parameters
    ----------
    x : ndarray
        Real-valued measurement vector of length ``msize``.  Both a 1-D
        array ``(msize,)`` and a column ``(msize, 1)`` are accepted.
    dict : ndarray
        Dictionary of shape ``(msize, dictsize)``.  The factor updates
        (``R[0, 0] = 1`` and ``R[k, k] = sqrt(1 - w.w)``) assume every
        column has unit Euclidean norm.  (The name shadows the builtin and
        is kept only for backward compatibility with keyword callers.)
    D : ndarray
        Gramian of the dictionary, i.e. ``dict.T @ dict``, with shape
        ``(dictsize, dictsize)``.
    natom : int
        Maximum number of atoms to select (number of iterations).
    tolerance : float
        Relative stopping tolerance.  Iteration stops once the squared
        residual norm is no larger than ``tolerance * ||x||**2``.

    Returns
    -------
    x_hat : ndarray, shape (dictsize, 1)
        Estimated coefficients.  Entries at the indices in ``gamma`` hold
        the least-squares fit of ``x`` on the selected atoms; all other
        entries are zero.
    gamma : list of int
        Indices of the selected atoms, in selection order.

    For more information, see
    http://media.aau.dk/null_space_pursuits/2011/10/efficient-omp.html
    """
    msize, dictsize = dict.shape
    # Flatten to 1-D (as float) so column and 1-D inputs follow one code path.
    xvec = np.asarray(x, dtype=float).ravel()
    normr2 = np.vdot(xvec, xvec)  # squared residual norm ||r||^2
    normtol2 = tolerance * normr2
    R = np.zeros((natom, natom))
    Q = np.zeros((msize, natom))
    gamma = []

    # Initial correlations x^T dict.  `projections` tracks the residual
    # correlations r^T dict as atoms are added.
    origprojections = np.dot(xvec, dict)
    projections = origprojections.copy()

    k = 0
    while normr2 > normtol2 and k < natom:
        # Pick the atom most correlated with the current residual.
        newgam = int(np.argmax(np.abs(projections)))
        gamma.append(newgam)
        atom = dict[:, newgam]

        # Extend the QR factorisation with the new atom.
        if k == 0:
            R[0, 0] = 1
            Q[:, 0] = atom
        else:
            # Solve R_k^T w = D[gamma, newgam], which gives w = Q_k^T atom
            # without touching Q.
            w = scipy.linalg.solve_triangular(
                R[:k, :k], D[gamma[:k], newgam], trans=1)
            R[:k, k] = w
            R[k, k] = np.sqrt(1 - np.vdot(w, w))
            Qk = Q[:, :k]
            # Orthogonalise the new atom against the already selected ones.
            Q[:, k] = (atom - np.dot(Qk, np.dot(Qk.T, atom))) / R[k, k]

        # Remove the component of x along the new orthonormal direction
        # q_k.  This is the rank-one update (x^T q_k) q_k^T, applied without
        # forming the msize x msize matrix q_k q_k^T.
        qk = Q[:, k]
        coef = np.dot(xvec, qk)
        projections -= coef * np.dot(qk, dict)
        normr2 -= coef * coef

        k += 1

    x_hat = np.zeros((dictsize, 1))
    if k == 0:
        # Nothing was selected (e.g. x == 0 or natom == 0): the estimate is zero.
        return x_hat, gamma

    # The least-squares coefficients on the support solve
    # R^T R s = dict[:, gamma]^T x: a forward solve then a back solve.
    tempR = R[:k, :k]
    w = scipy.linalg.solve_triangular(tempR, origprojections[gamma], trans=1)
    x_hat[gamma, 0] = scipy.linalg.solve_triangular(tempR, w)

    return x_hat, gamma