Mercurial > hg > pycsalgos
changeset 48:f6824da67b51
Added NESTA algotuple and NestaError class
author | nikcleju |
---|---|
date | Thu, 01 Dec 2011 00:22:41 +0000 |
parents | e6a5f2173015 |
children | 04f8ae8d3eef |
files | pyCSalgos/NESTA/NESTA.py scripts/ABSapprox.py |
diffstat | 2 files changed, 38 insertions(+), 17 deletions(-) [+] |
line wrap: on
line diff
--- a/pyCSalgos/NESTA/NESTA.py Wed Nov 30 23:41:46 2011 +0000 +++ b/pyCSalgos/NESTA/NESTA.py Thu Dec 01 00:22:41 2011 +0000 @@ -8,6 +8,9 @@ import numpy import math +class NestaError(Exception): + pass + #function [xk,niter,residuals,outputData,opts] =NESTA(A,At,b,muf,delta,opts) def NESTA(A,At,b,muf,delta,opts=None): # [xk,niter,residuals,outputData] =NESTA(A,At,b,muf,delta,opts) @@ -353,7 +356,7 @@ if numpy.linalg.norm(AAtz - z) / numpy.linalg.norm(z) > 1e-8: #error('Measurement matrix A must be a partial isometry: AA''=I'); print 'Measurement matrix A must be a partial isometry: AA''=I' - raise + raise NestaError('Measurement matrix A must be a partial isometry: AA''=I') #end #end @@ -373,7 +376,7 @@ if delta > 0 and USV is None: #error('delta must be zero for non-projections'); print 'delta must be zero for non-projections' - raise + raise NesteError('delta must be zero for non-projections') #end #if isa(AAtinv,'function_handle') if hasattr(AAtinv,'__call__'): @@ -421,7 +424,7 @@ elif TypeMin.lower() == 'tv': #mu0 = ValMUTv(Ux_ref); print 'Nic: TODO: not implemented yet' - raise + raise NestaError('Nic: TODO: not implemented yet') # -- If U was set by the user and normU not supplied, then calcuate norm(U) #if U_userSet && isempty(normU) @@ -567,7 +570,7 @@ #printf('Variable #s is #f, should be at least #f\n',... # field,var,mn); error('variable out-of-bounds'); print 'Variable',field,'is',var,', should be at least',mn - raise + raise NestaError('setOpts error: value too small') #end #end #if nargin >= 4 && ~isempty(mx) @@ -576,7 +579,7 @@ #printf('Variable #s is #f, should be at least #f\n',... # field,var,mn); error('variable out-of-bounds'); print 'Variable',field,'is',var,', should be at most',mx - raise + raise NestaError('setOpts error: value too large') #end #end #opts.(field) = var; @@ -1095,7 +1098,7 @@ if delta < 0: print 'delta must be greater or equal to zero' - raise + raise NestaError('delta must be greater or equal to zero') if hasattr(A,'__call__'): Atfun = At; @@ -1179,7 +1182,7 @@ #---- TV Minimization if TypeMin == 'TV': print 'Nic:TODO: TV minimization not yet implemented!' - raise + raise NestaError('Nic:TODO: TV minimization not yet implemented!') #if strcmpi(TypeMin,'TV') # Lmu = 8*Lmu; # Dv = spdiags([reshape([-ones(n-1,n); zeros(1,n)],N,1) ... @@ -1279,7 +1282,7 @@ #if lambdaY > 0, disp('lambda is positive!'); keyboard; end if lambdaY > 0: print 'lambda is positive!' - raise + raise NestaError('lambda is positive!') yk = xk + projection; Ayk = Afun(yk); # DEBUGGING @@ -1384,7 +1387,7 @@ projection,projIter,lambdaZ = fastProjection(Q,S,V,dfp,bp,deltap, .999*lambdaZ ) if lambdaZ > 0: print 'lambda is positive!' - raise + raise NestaError('lambda is positive!') zk = projection.copy(); # zk = SLmu1*projection; Azk = Afun(zk); @@ -1442,7 +1445,7 @@ if abs(fx)>1e20 or abs(residuals[k,0]) >1e20 or numpy.isnan(fx): #error('Nesta: possible divergence or NaN. Bad estimate of ||A''A||?'); print 'Nesta: possible divergence or NaN. Bad estimate of ||A''A||?' - raise + raise NestaError('Nesta: possible divergence or NaN. Bad estimate of ||A''A||?') #end #end
--- a/scripts/ABSapprox.py Wed Nov 30 23:41:46 2011 +0000 +++ b/scripts/ABSapprox.py Thu Dec 01 00:22:41 2011 +0000 @@ -17,6 +17,7 @@ import pyCSalgos.SL0.SL0_approx import pyCSalgos.OMP.omp_QR import pyCSalgos.RecomTST.RecommendedTST +import pyCSalgos.NESTA.NESTA #========================== # Algorithm functions @@ -56,6 +57,19 @@ L = 10 return np.dot(D , pyCSalgos.SL0.SL0_approx.SL0_approx_analysis(Aeps,Aexact,y,epsilon,sigmamin,sigma_decrease_factor,mu_0,L)) +def run_nesta(y,M,Omega,epsilon): + + U,S,V = np.linalg.svd(M, full_matrices = True) + V = V.T # Make like Matlab + m,n = M.shape # Make like Matlab + S = np.hstack((np.diag(S), np.zeros((m,n-m)))) + + opt_muf = 1e-3 + optsUSV = {'U':U, 'S':S, 'V':V} + opts = {'U':Omega, 'Ut':Omega.T.copy(), 'USV':optsUSV, 'TolVar':1e-5, 'Verbose':0} + return pyCSalgos.NESTA.NESTA.NESTA(M, None, y, opt_muf, epsilon, opts)[0] + + def run_sl0(y,M,Omega,D,U,S,Vt,epsilon,lbd): N,n = Omega.shape @@ -121,6 +135,7 @@ sl0 = (run_sl0, 'SL0a') sl0analysis = (run_sl0_analysis, 'SL0a2') bpanalysis = (run_bp_analysis, 'BPa2') +nesta = (run_nesta, 'NESTA') bp = (run_bp, 'BP') ompeps = (run_ompeps, 'OMPeps') tst = (run_tst, 'TST') @@ -147,18 +162,19 @@ # Useful for short testing def stdtest(): # Define which algorithms to run - algosN = gap, # tuple of algorithms not depending on lambda - algosL = sl0,bp # tuple of algorithms depending on lambda (our ABS approach) + algosN = nesta, # tuple of algorithms not depending on lambda + #algosL = sl0,bp # tuple of algorithms depending on lambda (our ABS approach) + algosL = () d = 50.0 sigma = 2.0 - #deltas = np.array([0.05, 0.45, 0.95]) - #rhos = np.array([0.05, 0.45, 0.95]) + deltas = np.array([0.05, 0.45, 0.95]) + rhos = np.array([0.05, 0.45, 0.95]) #deltas = np.array([0.95]) - deltas = np.arange(0.05,1.,0.05) - rhos = np.array([0.05]) + #deltas = np.arange(0.05,1.,0.05) + #rhos = np.array([0.05]) numvects = 10; # Number of vectors to generate - SNRdb = 7.; # This is norm(signal)/norm(noise), so power, not energy + SNRdb = 20.; # This is norm(signal)/norm(noise), so power, not energy # Values for lambda #lambdas = [0 10.^linspace(-5, 4, 10)]; lambdas = np.array([0., 0.0001, 0.01, 1, 100, 10000]) @@ -504,6 +520,8 @@ xrec[strname][:,iy] = algofunc(y[:,iy],M,Omega,epsilon) except pyCSalgos.BP.l1qec.l1qecInputValueError as e: print "Caught exception when running algorithm",strname," :",e.message + except pyCSalgos.NESTA.NESTA.NestaError as e: + print "Caught exception when running algorithm",strname," :",e.message err[strname][iy] = np.linalg.norm(x0[:,iy] - xrec[strname][:,iy]) relerr[strname][iy] = err[strname][iy] / np.linalg.norm(x0[:,iy]) for algofunc,strname in algosN: