Mercurial > hg > pycsalgos
comparison pyCSalgos/GAP/GAP.py @ 27:1a88766113a9
A lot of things.
Fixed problem in Gap
Fixed multiprocessor versions of script (both PP and multiproc)
author | nikcleju |
---|---|
date | Wed, 09 Nov 2011 18:18:42 +0000 |
parents | f0f77d92e0c1 |
children | 4f3bc35195ce |
comparison
equal
deleted
inserted
replaced
26:f0f77d92e0c1 | 27:1a88766113a9 |
---|---|
5 @author: ncleju | 5 @author: ncleju |
6 """ | 6 """ |
7 | 7 |
8 #from numpy import * | 8 #from numpy import * |
9 #from scipy import * | 9 #from scipy import * |
10 import numpy as np | 10 import numpy |
11 import numpy.linalg | |
11 import scipy as sp | 12 import scipy as sp |
12 import scipy.stats #from scipy import stats | 13 #import scipy.stats #from scipy import stats |
13 import scipy.linalg #from scipy import lnalg | 14 #import scipy.linalg #from scipy import lnalg |
14 import math | 15 import math |
15 | 16 |
16 from numpy.random import RandomState | 17 from numpy.random import RandomState |
17 rng = RandomState() | 18 rng = RandomState() |
18 | 19 |
19 def Generate_Analysis_Operator(d, p): | 20 def Generate_Analysis_Operator(d, p): |
20 # generate random tight frame with equal column norms | 21 # generate random tight frame with equal column norms |
21 if p == d: | 22 if p == d: |
22 T = rng.randn(d,d); | 23 T = rng.randn(d,d); |
23 [Omega, discard] = np.qr(T); | 24 [Omega, discard] = numpy.qr(T); |
24 else: | 25 else: |
25 Omega = rng.randn(p, d); | 26 Omega = rng.randn(p, d); |
26 T = np.zeros((p, d)); | 27 T = numpy.zeros((p, d)); |
27 tol = 1e-8; | 28 tol = 1e-8; |
28 max_j = 200; | 29 max_j = 200; |
29 j = 1; | 30 j = 1; |
30 while (sum(sum(abs(T-Omega))) > np.dot(tol,np.dot(p,d)) and j < max_j): | 31 while (sum(sum(abs(T-Omega))) > numpy.dot(tol,numpy.dot(p,d)) and j < max_j): |
31 j = j + 1; | 32 j = j + 1; |
32 T = Omega; | 33 T = Omega; |
33 [U, S, Vh] = sp.linalg.svd(Omega); | 34 [U, S, Vh] = numpy.linalg.svd(Omega); |
34 V = Vh.T | 35 V = Vh.T |
35 #Omega = U * [eye(d); zeros(p-d,d)] * V'; | 36 #Omega = U * [eye(d); zeros(p-d,d)] * V'; |
36 Omega2 = np.dot(np.dot(U, np.concatenate((np.eye(d), np.zeros((p-d,d))))), V.transpose()) | 37 Omega2 = numpy.dot(numpy.dot(U, numpy.concatenate((numpy.eye(d), numpy.zeros((p-d,d))))), V.transpose()) |
37 #Omega = diag(1./sqrt(diag(Omega*Omega')))*Omega; | 38 #Omega = diag(1./sqrt(diag(Omega*Omega')))*Omega; |
38 Omega = np.dot(np.diag(1.0 / np.sqrt(np.diag(np.dot(Omega2,Omega2.transpose())))), Omega2) | 39 Omega = numpy.dot(numpy.diag(1.0 / numpy.sqrt(numpy.diag(numpy.dot(Omega2,Omega2.transpose())))), Omega2) |
39 #end | 40 #end |
40 ##disp(j); | 41 ##disp(j); |
41 #end | 42 #end |
42 return Omega | 43 return Omega |
43 | 44 |
65 # for i = 1:size(Omega,1) | 66 # for i = 1:size(Omega,1) |
66 # Omega(i,:) = Omega(i,:) / norm(Omega(i,:)); | 67 # Omega(i,:) = Omega(i,:) / norm(Omega(i,:)); |
67 # end | 68 # end |
68 | 69 |
69 #Init | 70 #Init |
70 LambdaMat = np.zeros((k,numvectors)) | 71 LambdaMat = numpy.zeros((k,numvectors)) |
71 x0 = np.zeros((d,numvectors)) | 72 x0 = numpy.zeros((d,numvectors)) |
72 y = np.zeros((m,numvectors)) | 73 y = numpy.zeros((m,numvectors)) |
73 realnoise = np.zeros((m,numvectors)) | 74 realnoise = numpy.zeros((m,numvectors)) |
74 | 75 |
75 M = rng.randn(m,d); | 76 M = rng.randn(m,d); |
76 | 77 |
77 #for i=1:numvectors | 78 #for i=1:numvectors |
78 for i in range(0,numvectors): | 79 for i in range(0,numvectors): |
82 if normstr == 'l0': | 83 if normstr == 'l0': |
83 # Unchanged | 84 # Unchanged |
84 | 85 |
85 #Lambda=randperm(p); | 86 #Lambda=randperm(p); |
86 Lambda = rng.permutation(int(p)); | 87 Lambda = rng.permutation(int(p)); |
87 Lambda = np.sort(Lambda[0:k]); | 88 Lambda = numpy.sort(Lambda[0:k]); |
88 LambdaMat[:,i] = Lambda; # store for output | 89 LambdaMat[:,i] = Lambda; # store for output |
89 | 90 |
90 # The signal is drawn at random from the null-space defined by the rows | 91 # The signal is drawn at random from the null-space defined by the rows |
91 # of the matreix Omega(Lambda,:) | 92 # of the matreix Omega(Lambda,:) |
92 [U,D,Vh] = sp.linalg.svd(Omega[Lambda,:]); | 93 [U,D,Vh] = numpy.linalg.svd(Omega[Lambda,:]); |
93 V = Vh.T | 94 V = Vh.T |
94 NullSpace = V[:,k:]; | 95 NullSpace = V[:,k:]; |
95 #print np.dot(NullSpace, rng.randn(d-k,1)).shape | 96 #print numpy.dot(NullSpace, rng.randn(d-k,1)).shape |
96 #print x0[:,i].shape | 97 #print x0[:,i].shape |
97 x0[:,i] = np.squeeze(np.dot(NullSpace, rng.randn(d-k,1))); | 98 x0[:,i] = numpy.squeeze(numpy.dot(NullSpace, rng.randn(d-k,1))); |
98 # Nic: add orthogonality noise | 99 # Nic: add orthogonality noise |
99 # orthonoiseSNRdb = 6; | 100 # orthonoiseSNRdb = 6; |
100 # n = randn(p,1); | 101 # n = randn(p,1); |
101 # #x0(:,i) = x0(:,i) + n / norm(n)^2 * norm(x0(:,i))^2 / 10^(orthonoiseSNRdb/10); | 102 # #x0(:,i) = x0(:,i) + n / norm(n)^2 * norm(x0(:,i))^2 / 10^(orthonoiseSNRdb/10); |
102 # n = n / norm(n)^2 * norm(Omega * x0(:,i))^2 / 10^(orthonoiseSNRdb/10); | 103 # n = n / norm(n)^2 * norm(Omega * x0(:,i))^2 / 10^(orthonoiseSNRdb/10); |
113 print('Nic says: not implemented yet') | 114 print('Nic says: not implemented yet') |
114 raise Exception('Nic says: not implemented yet') | 115 raise Exception('Nic says: not implemented yet') |
115 #end | 116 #end |
116 | 117 |
117 # Acquire measurements | 118 # Acquire measurements |
118 y[:,i] = np.dot(M, x0[:,i]) | 119 y[:,i] = numpy.dot(M, x0[:,i]) |
119 | 120 |
120 # Add noise | 121 # Add noise |
121 t_norm = np.linalg.norm(y[:,i],2); | 122 t_norm = numpy.linalg.norm(y[:,i],2); |
122 n = np.squeeze(rng.randn(m, 1)); | 123 n = numpy.squeeze(rng.randn(m, 1)); |
123 y[:,i] = y[:,i] + noiselevel * t_norm * n / np.linalg.norm(n, 2); | 124 y[:,i] = y[:,i] + noiselevel * t_norm * n / numpy.linalg.norm(n, 2); |
124 realnoise[:,i] = noiselevel * t_norm * n / np.linalg.norm(n, 2) | 125 realnoise[:,i] = noiselevel * t_norm * n / numpy.linalg.norm(n, 2) |
125 #end | 126 #end |
126 | 127 |
127 return x0,y,M,LambdaMat,realnoise | 128 return x0,y,M,LambdaMat,realnoise |
128 | 129 |
129 ##################### | 130 ##################### |
171 if params['l2solver'] == 'pseudoinverse': | 172 if params['l2solver'] == 'pseudoinverse': |
172 #if strcmp(class(M), 'double') && strcmp(class(Omega), 'double') | 173 #if strcmp(class(M), 'double') && strcmp(class(Omega), 'double') |
173 if M.dtype == 'float64' and Omega.dtype == 'double': | 174 if M.dtype == 'float64' and Omega.dtype == 'double': |
174 while True: | 175 while True: |
175 alpha = math.sqrt(lagmult); | 176 alpha = math.sqrt(lagmult); |
176 xhat = np.linalg.lstsq(np.concatenate((M, alpha*Omega[Lambdahat,:])), np.concatenate((y, np.zeros(Lambdahat.size))))[0] | 177 xhat = numpy.linalg.lstsq(numpy.concatenate((M, alpha*Omega[Lambdahat,:])), numpy.concatenate((y, numpy.zeros(Lambdahat.size))))[0] |
177 temp = np.linalg.norm(y - np.dot(M,xhat), 2); | 178 temp = numpy.linalg.norm(y - numpy.dot(M,xhat), 2); |
178 #disp(['fidelity error=', num2str(temp), ' lagmult=', num2str(lagmult)]); | 179 #disp(['fidelity error=', num2str(temp), ' lagmult=', num2str(lagmult)]); |
179 if temp <= params['noise_level']: | 180 if temp <= params['noise_level']: |
180 was_feasible = True; | 181 was_feasible = True; |
181 if was_infeasible: | 182 if was_infeasible: |
182 break; | 183 break; |
189 break; | 190 break; |
190 lagmult = lagmult/lagmultfactor; | 191 lagmult = lagmult/lagmultfactor; |
191 if lagmult < lagmultmin or lagmult > lagmultmax: | 192 if lagmult < lagmultmin or lagmult > lagmultmax: |
192 break; | 193 break; |
193 xprev = xhat.copy(); | 194 xprev = xhat.copy(); |
194 arepr = np.dot(Omega[Lambdahat, :], xhat); | 195 arepr = numpy.dot(Omega[Lambdahat, :], xhat); |
195 return xhat,arepr,lagmult; | 196 return xhat,arepr,lagmult; |
196 | 197 |
197 | 198 |
198 ######################################################################## | 199 ######################################################################## |
199 ## Computation using conjugate gradient method. | 200 ## Computation using conjugate gradient method. |
200 ######################################################################## | 201 ######################################################################## |
201 #if strcmp(class(MH),'function_handle') | 202 #if strcmp(class(MH),'function_handle') |
202 if hasattr(MH, '__call__'): | 203 if hasattr(MH, '__call__'): |
203 b = MH(y); | 204 b = MH(y); |
204 else: | 205 else: |
205 b = np.dot(MH, y); | 206 b = numpy.dot(MH, y); |
206 | 207 |
207 norm_b = np.linalg.norm(b, 2); | 208 norm_b = numpy.linalg.norm(b, 2); |
208 xhat = xinit.copy(); | 209 xhat = xinit.copy(); |
209 xprev = xinit.copy(); | 210 xprev = xinit.copy(); |
210 residual = TheHermitianMatrix(xhat, M, MH, Omega, OmegaH, Lambdahat, lagmult) - b; | 211 residual = TheHermitianMatrix(xhat, M, MH, Omega, OmegaH, Lambdahat, lagmult) - b; |
211 direction = -residual; | 212 direction = -residual; |
212 iter = 0; | 213 iter = 0; |
213 | 214 |
214 while iter < params.max_inner_iteration: | 215 while iter < params.max_inner_iteration: |
215 iter = iter + 1; | 216 iter = iter + 1; |
216 alpha = np.linalg.norm(residual,2)**2 / np.dot(direction.T, TheHermitianMatrix(direction, M, MH, Omega, OmegaH, Lambdahat, lagmult)); | 217 alpha = numpy.linalg.norm(residual,2)**2 / numpy.dot(direction.T, TheHermitianMatrix(direction, M, MH, Omega, OmegaH, Lambdahat, lagmult)); |
217 xhat = xhat + alpha*direction; | 218 xhat = xhat + alpha*direction; |
218 prev_residual = residual.copy(); | 219 prev_residual = residual.copy(); |
219 residual = TheHermitianMatrix(xhat, M, MH, Omega, OmegaH, Lambdahat, lagmult) - b; | 220 residual = TheHermitianMatrix(xhat, M, MH, Omega, OmegaH, Lambdahat, lagmult) - b; |
220 beta = np.linalg.norm(residual,2)**2 / np.linalg.norm(prev_residual,2)**2; | 221 beta = numpy.linalg.norm(residual,2)**2 / numpy.linalg.norm(prev_residual,2)**2; |
221 direction = -residual + beta*direction; | 222 direction = -residual + beta*direction; |
222 | 223 |
223 if np.linalg.norm(residual,2)/norm_b < params['l2_accuracy']*(lagmult**(accuracy_adjustment_exponent)) or iter == params['max_inner_iteration']: | 224 if numpy.linalg.norm(residual,2)/norm_b < params['l2_accuracy']*(lagmult**(accuracy_adjustment_exponent)) or iter == params['max_inner_iteration']: |
224 #if strcmp(class(M), 'function_handle') | 225 #if strcmp(class(M), 'function_handle') |
225 if hasattr(M, '__call__'): | 226 if hasattr(M, '__call__'): |
226 temp = np.linalg.norm(y-M(xhat), 2); | 227 temp = numpy.linalg.norm(y-M(xhat), 2); |
227 else: | 228 else: |
228 temp = np.linalg.norm(y-np.dot(M,xhat), 2); | 229 temp = numpy.linalg.norm(y-numpy.dot(M,xhat), 2); |
229 | 230 |
230 #if strcmp(class(Omega), 'function_handle') | 231 #if strcmp(class(Omega), 'function_handle') |
231 if hasattr(Omega, '__call__'): | 232 if hasattr(Omega, '__call__'): |
232 u = Omega(xhat); | 233 u = Omega(xhat); |
233 u = math.sqrt(lagmult)*np.linalg.norm(u(Lambdahat), 2); | 234 u = math.sqrt(lagmult)*numpy.linalg.norm(u(Lambdahat), 2); |
234 else: | 235 else: |
235 u = math.sqrt(lagmult)*np.linalg.norm(Omega[Lambdahat,:]*xhat, 2); | 236 u = math.sqrt(lagmult)*numpy.linalg.norm(Omega[Lambdahat,:]*xhat, 2); |
236 | 237 |
237 | 238 |
238 #disp(['residual=', num2str(norm(residual,2)), ' norm_b=', num2str(norm_b), ' omegapart=', num2str(u), ' fidelity error=', num2str(temp), ' lagmult=', num2str(lagmult), ' iter=', num2str(iter)]); | 239 #disp(['residual=', num2str(norm(residual,2)), ' norm_b=', num2str(norm_b), ' omegapart=', num2str(u), ' fidelity error=', num2str(temp), ' lagmult=', num2str(lagmult), ' iter=', num2str(iter)]); |
239 | 240 |
240 if temp <= params['noise_level']: | 241 if temp <= params['noise_level']: |
283 #if strcmp(class(Omega),'function_handle') | 284 #if strcmp(class(Omega),'function_handle') |
284 if hasattr(Omega, '__call__'): | 285 if hasattr(Omega, '__call__'): |
285 temp = Omega(xhat); | 286 temp = Omega(xhat); |
286 arepr = temp(Lambdahat); | 287 arepr = temp(Lambdahat); |
287 else: ## here Omega is assumed to be a matrix | 288 else: ## here Omega is assumed to be a matrix |
288 arepr = np.dot(Omega[Lambdahat, :], xhat); | 289 arepr = numpy.dot(Omega[Lambdahat, :], xhat); |
289 | 290 |
290 return xhat,arepr,lagmult | 291 return xhat,arepr,lagmult |
291 | 292 |
292 | 293 |
293 ## | 294 ## |
297 def TheHermitianMatrix(x, M, MH, Omega, OmegaH, L, lm): | 298 def TheHermitianMatrix(x, M, MH, Omega, OmegaH, L, lm): |
298 #if strcmp(class(M), 'function_handle') | 299 #if strcmp(class(M), 'function_handle') |
299 if hasattr(M, '__call__'): | 300 if hasattr(M, '__call__'): |
300 w = MH(M(x)); | 301 w = MH(M(x)); |
301 else: ## M and MH are matrices | 302 else: ## M and MH are matrices |
302 w = np.dot(np.dot(MH, M), x); | 303 w = numpy.dot(numpy.dot(MH, M), x); |
303 | 304 |
304 if hasattr(Omega, '__call__'): | 305 if hasattr(Omega, '__call__'): |
305 v = Omega(x); | 306 v = Omega(x); |
306 vt = np.zeros(v.size); | 307 vt = numpy.zeros(v.size); |
307 vt[L] = v[L].copy(); | 308 vt[L] = v[L].copy(); |
308 w = w + lm*OmegaH(vt); | 309 w = w + lm*OmegaH(vt); |
309 else: ## Omega is assumed to be a matrix and OmegaH is its conjugate transpose | 310 else: ## Omega is assumed to be a matrix and OmegaH is its conjugate transpose |
310 w = w + lm*np.dot(np.dot(OmegaH[:, L],Omega[L, :]),x); | 311 w = w + lm*numpy.dot(numpy.dot(OmegaH[:, L],Omega[L, :]),x); |
311 | 312 |
312 return w | 313 return w |
313 | 314 |
314 def GAP(y, M, MH, Omega, OmegaH, params, xinit): | 315 def GAP(y, M, MH, Omega, OmegaH, params, xinit): |
315 #function [xhat, Lambdahat] = GAP(y, M, MH, Omega, OmegaH, params, xinit) | 316 #function [xhat, Lambdahat] = GAP(y, M, MH, Omega, OmegaH, params, xinit) |
387 # p = length(Omega(zeros(d,1))); | 388 # p = length(Omega(zeros(d,1))); |
388 #else ## Omega is a matrix | 389 #else ## Omega is a matrix |
389 # p = size(Omega, 1); | 390 # p = size(Omega, 1); |
390 #end | 391 #end |
391 if hasattr(Omega, '__call__'): | 392 if hasattr(Omega, '__call__'): |
392 p = Omega(np.zeros((d,1))) | 393 p = Omega(numpy.zeros((d,1))) |
393 else: | 394 else: |
394 p = Omega.shape[0] | 395 p = Omega.shape[0] |
395 | 396 |
396 | 397 |
397 iter = 0 | 398 iter = 0 |
398 lagmult = 1e-4 | 399 lagmult = 1e-4 |
399 #Lambdahat = 1:p; | 400 #Lambdahat = 1:p; |
400 Lambdahat = np.arange(p) | 401 Lambdahat = numpy.arange(p) |
401 #while iter < params.num_iteration | 402 #while iter < params.num_iteration |
402 while iter < params["num_iteration"]: | 403 while iter < params["num_iteration"]: |
403 iter = iter + 1 | 404 iter = iter + 1 |
404 #[xhat, analysis_repr, lagmult] = ArgminOperL2Constrained(y, M, MH, Omega, OmegaH, Lambdahat, xinit, lagmult, params); | 405 #[xhat, analysis_repr, lagmult] = ArgminOperL2Constrained(y, M, MH, Omega, OmegaH, Lambdahat, xinit, lagmult, params); |
405 xhat,analysis_repr,lagmult = ArgminOperL2Constrained(y, M, MH, Omega, OmegaH, Lambdahat, xinit, lagmult, params) | 406 xhat,analysis_repr,lagmult = ArgminOperL2Constrained(y, M, MH, Omega, OmegaH, Lambdahat, xinit, lagmult, params) |
406 #[to_be_removed, maxcoef] = FindRowsToRemove(analysis_repr, params.greedy_level); | 407 #[to_be_removed, maxcoef] = FindRowsToRemove(analysis_repr, params.greedy_level); |
407 to_be_removed, maxcoef = FindRowsToRemove(analysis_repr, params["greedy_level"]) | 408 to_be_removed,maxcoef = FindRowsToRemove(analysis_repr, params["greedy_level"]) |
408 #disp(['** maxcoef=', num2str(maxcoef), ' target=', num2str(params.stopping_coefficient_size), ' rows_remaining=', num2str(length(Lambdahat)), ' lagmult=', num2str(lagmult)]); | 409 #disp(['** maxcoef=', num2str(maxcoef), ' target=', num2str(params.stopping_coefficient_size), ' rows_remaining=', num2str(length(Lambdahat)), ' lagmult=', num2str(lagmult)]); |
409 #print '** maxcoef=',maxcoef,' target=',params['stopping_coefficient_size'],' rows_remaining=',Lambdahat.size,' lagmult=',lagmult | 410 #print '** maxcoef=',maxcoef,' target=',params['stopping_coefficient_size'],' rows_remaining=',Lambdahat.size,' lagmult=',lagmult |
410 if check_stopping_criteria(xhat, xinit, maxcoef, lagmult, Lambdahat, params): | 411 if check_stopping_criteria(xhat, xinit, maxcoef, lagmult, Lambdahat, params): |
411 break | 412 break |
412 | 413 |
413 xinit = xhat.copy() | 414 xinit = xhat.copy() |
414 #Lambdahat[to_be_removed] = [] | 415 #Lambdahat[to_be_removed] = [] |
415 # TODO: find what why squeeze() is needed here!! | 416 Lambdahat = numpy.delete(Lambdahat.squeeze(),to_be_removed) |
416 if len(to_be_removed) != 0: | |
417 Lambdahat = np.delete(Lambdahat.squeeze(),to_be_removed) | |
418 | 417 |
419 #n = sqrt(d); | 418 #n = sqrt(d); |
420 #figure(9); | 419 #figure(9); |
421 #RR = zeros(2*n, n-1); | 420 #RR = zeros(2*n, n-1); |
422 #RR(Lambdahat) = 1; | 421 #RR(Lambdahat) = 1; |
429 #imshow(XD); | 428 #imshow(XD); |
430 #figure(10); | 429 #figure(10); |
431 #imshow(reshape(real(xhat), n, n)); | 430 #imshow(reshape(real(xhat), n, n)); |
432 | 431 |
433 #return; | 432 #return; |
434 return xhat, Lambdahat | 433 return xhat,Lambdahat |
435 | 434 |
436 def FindRowsToRemove(analysis_repr, greedy_level): | 435 def FindRowsToRemove(analysis_repr, greedy_level): |
437 #function [to_be_removed, maxcoef] = FindRowsToRemove(analysis_repr, greedy_level) | 436 #function [to_be_removed, maxcoef] = FindRowsToRemove(analysis_repr, greedy_level) |
438 | 437 |
439 #abscoef = abs(analysis_repr(:)); | 438 #abscoef = abs(analysis_repr(:)); |
440 abscoef = np.abs(analysis_repr) | 439 abscoef = numpy.abs(analysis_repr) |
441 #n = length(abscoef); | 440 #n = length(abscoef); |
442 n = abscoef.size | 441 n = abscoef.size |
443 #maxcoef = max(abscoef); | 442 #maxcoef = max(abscoef); |
444 maxcoef = abscoef.max() | 443 maxcoef = abscoef.max() |
445 if greedy_level >= 1: | 444 if greedy_level >= 1: |
447 qq = sp.stats.mstats.mquantile(abscoef, 1.0-greedy_level/n, 0.5, 0.5) | 446 qq = sp.stats.mstats.mquantile(abscoef, 1.0-greedy_level/n, 0.5, 0.5) |
448 else: | 447 else: |
449 qq = maxcoef*greedy_level | 448 qq = maxcoef*greedy_level |
450 | 449 |
451 #to_be_removed = find(abscoef >= qq); | 450 #to_be_removed = find(abscoef >= qq); |
452 to_be_removed = np.nonzero(abscoef >= qq) | 451 # [0] needed because nonzero() returns a tuple of arrays! |
452 to_be_removed = numpy.nonzero(abscoef >= qq)[0] | |
453 #return; | 453 #return; |
454 return to_be_removed, maxcoef | 454 return to_be_removed,maxcoef |
455 | 455 |
456 def check_stopping_criteria(xhat, xinit, maxcoef, lagmult, Lambdahat, params): | 456 def check_stopping_criteria(xhat, xinit, maxcoef, lagmult, Lambdahat, params): |
457 #function r = check_stopping_criteria(xhat, xinit, maxcoef, lagmult, Lambdahat, params) | 457 #function r = check_stopping_criteria(xhat, xinit, maxcoef, lagmult, Lambdahat, params) |
458 | 458 |
459 #if isfield(params, 'stopping_coefficient_size') && maxcoef < params.stopping_coefficient_size | 459 #if isfield(params, 'stopping_coefficient_size') && maxcoef < params.stopping_coefficient_size |
463 #if isfield(params, 'stopping_lagrange_multiplier_size') && lagmult > params.stopping_lagrange_multiplier_size | 463 #if isfield(params, 'stopping_lagrange_multiplier_size') && lagmult > params.stopping_lagrange_multiplier_size |
464 if ('stopping_lagrange_multiplier_size' in params) and lagmult > params['stopping_lagrange_multiplier_size']: | 464 if ('stopping_lagrange_multiplier_size' in params) and lagmult > params['stopping_lagrange_multiplier_size']: |
465 return 1 | 465 return 1 |
466 | 466 |
467 #if isfield(params, 'stopping_relative_solution_change') && norm(xhat-xinit)/norm(xhat) < params.stopping_relative_solution_change | 467 #if isfield(params, 'stopping_relative_solution_change') && norm(xhat-xinit)/norm(xhat) < params.stopping_relative_solution_change |
468 if ('stopping_relative_solution_change' in params) and np.linalg.norm(xhat-xinit)/np.linalg.norm(xhat) < params['stopping_relative_solution_change']: | 468 if ('stopping_relative_solution_change' in params) and numpy.linalg.norm(xhat-xinit)/numpy.linalg.norm(xhat) < params['stopping_relative_solution_change']: |
469 return 1 | 469 return 1 |
470 | 470 |
471 #if isfield(params, 'stopping_cosparsity') && length(Lambdahat) < params.stopping_cosparsity | 471 #if isfield(params, 'stopping_cosparsity') && length(Lambdahat) < params.stopping_cosparsity |
472 if ('stopping_cosparsity' in params) and Lambdahat.size() < params['stopping_cosparsity']: | 472 if ('stopping_cosparsity' in params) and Lambdahat.size() < params['stopping_cosparsity']: |
473 return 1 | 473 return 1 |