comparison scripts/ABSapprox.py @ 30:5f46ff51c7ff

Added console output detailing parameters Prepared for complete run 10.11.2011
author nikcleju
date Thu, 10 Nov 2011 22:43:11 +0000
parents bc2a96a03b0a
children 829bf04c92af
comparison
equal deleted inserted replaced
29:bc2a96a03b0a 30:5f46ff51c7ff
10 import math 10 import math
11 11
12 import pyCSalgos 12 import pyCSalgos
13 import pyCSalgos.GAP.GAP 13 import pyCSalgos.GAP.GAP
14 import pyCSalgos.SL0.SL0_approx 14 import pyCSalgos.SL0.SL0_approx
15 import pyCSalgos.OMP.omp_QR
16 import pyCSalgos.RecomTST.RecommendedTST
15 17
16 #========================== 18 #==========================
17 # Algorithm functions 19 # Algorithm functions
18 #========================== 20 #==========================
19 def run_gap(y,M,Omega,epsilon): 21 def run_gap(y,M,Omega,epsilon):
54 sigma_decrease_factor = 0.5 56 sigma_decrease_factor = 0.5
55 mu_0 = 2 57 mu_0 = 2
56 L = 10 58 L = 10
57 return pyCSalgos.SL0.SL0_approx.SL0_approx(aggD,aggy,epsilon,sigmamin,sigma_decrease_factor,mu_0,L) 59 return pyCSalgos.SL0.SL0_approx.SL0_approx(aggD,aggy,epsilon,sigmamin,sigma_decrease_factor,mu_0,L)
58 60
61 def run_ompeps(y,M,Omega,D,U,S,Vt,epsilon,lbd):
62
63 N,n = Omega.shape
64 #D = np.linalg.pinv(Omega)
65 #U,S,Vt = np.linalg.svd(D)
66 aggDupper = np.dot(M,D)
67 aggDlower = Vt[-(N-n):,:]
68 aggD = np.concatenate((aggDupper, lbd * aggDlower))
69 aggy = np.concatenate((y, np.zeros(N-n)))
70
71 opts = dict()
72 opts['stopCrit'] = 'mse'
73 opts['stopTol'] = epsilon**2 / aggy.size
74 return pyCSalgos.OMP.omp_QR.greed_omp_qr(aggy,aggD,aggD.shape[1],opts)[0]
75
76 def run_tst(y,M,Omega,D,U,S,Vt,epsilon,lbd):
77
78 N,n = Omega.shape
79 #D = np.linalg.pinv(Omega)
80 #U,S,Vt = np.linalg.svd(D)
81 aggDupper = np.dot(M,D)
82 aggDlower = Vt[-(N-n):,:]
83 aggD = np.concatenate((aggDupper, lbd * aggDlower))
84 aggy = np.concatenate((y, np.zeros(N-n)))
85
86 return pyCSalgos.RecomTST.RecommendedTST.RecommendedTST(aggD, aggy, nsweep=3000, tol=epsilon / np.linalg.norm(aggy))
87
59 #========================== 88 #==========================
60 # Define tuples (algorithm function, name) 89 # Define tuples (algorithm function, name)
61 #========================== 90 #==========================
62 gap = (run_gap, 'GAP') 91 gap = (run_gap, 'GAP')
63 sl0 = (run_sl0, 'SL0_approx') 92 sl0 = (run_sl0, 'SL0a')
64 bp = (run_bp, 'BP') 93 bp = (run_bp, 'BP')
94 ompeps = (run_ompeps, 'OMPeps')
95 tst = (run_tst, 'TST')
65 96
66 # Define which algorithms to run 97 # Define which algorithms to run
67 # 1. Algorithms not depending on lambda 98 # 1. Algorithms not depending on lambda
68 algosN = gap, # tuple 99 algosN = gap, # tuple
69 # 2. Algorithms depending on lambda (our ABS approach) 100 # 2. Algorithms depending on lambda (our ABS approach)
70 algosL = sl0, # tuple 101 algosL = sl0,bp,ompeps,tst # tuple
71 102
72 #========================== 103 #==========================
73 # Interface functions 104 # Interface functions
74 #========================== 105 #==========================
75 def run_multiproc(ncpus=None): 106 def run_multiproc(ncpus=None):
76 d,sigma,deltas,rhos,lambdas,numvects,SNRdb,dosavedata,savedataname = standard_params() 107 d,sigma,deltas,rhos,lambdas,numvects,SNRdb,dosavedata,savedataname,doshowplot,dosaveplot,saveplotbase,saveplotexts = standard_params()
77 run_multi(algosN, algosL, d,sigma,deltas,rhos,lambdas,numvects,SNRdb,dosavedata=dosavedata,savedataname=savedataname,\ 108 run_multi(algosN, algosL, d,sigma,deltas,rhos,lambdas,numvects,SNRdb,dosavedata=dosavedata,savedataname=savedataname,\
78 doparallel=True, ncpus=ncpus) 109 doparallel=True, ncpus=ncpus,\
110 doshowplot=doshowplot,dosaveplot=dosaveplot,saveplotbase=saveplotbase,saveplotexts=saveplotexts)
79 111
80 def run(): 112 def run():
81 d,sigma,deltas,rhos,lambdas,numvects,SNRdb,dosavedata,savedataname = standard_params() 113 d,sigma,deltas,rhos,lambdas,numvects,SNRdb,dosavedata,savedataname,doshowplot,dosaveplot,saveplotbase,saveplotexts = standard_params()
82 run_multi(algosN, algosL, d,sigma,deltas,rhos,lambdas,numvects,SNRdb,dosavedata=dosavedata,savedataname=savedataname,\ 114 run_multi(algosN, algosL, d,sigma,deltas,rhos,lambdas,numvects,SNRdb,dosavedata=dosavedata,savedataname=savedataname,\
83 doparallel=False) 115 doparallel=False,\
116 doshowplot=doshowplot,dosaveplot=dosaveplot,saveplotbase=saveplotbase,saveplotexts=saveplotexts)
84 117
85 def standard_params(): 118 def standard_params():
86 #Set up standard experiment parameters 119 #Set up standard experiment parameters
87 d = 50.0; 120 d = 50.0;
88 sigma = 2.0 121 sigma = 2.0
89 #deltas = np.arange(0.05,1.,0.05) 122 #deltas = np.arange(0.05,1.,0.05)
90 #rhos = np.arange(0.05,1.,0.05) 123 #rhos = np.arange(0.05,1.,0.05)
91 deltas = np.array([0.05, 0.45, 0.95]) 124 #deltas = np.array([0.05, 0.45, 0.95])
92 rhos = np.array([0.05, 0.45, 0.95]) 125 #rhos = np.array([0.05, 0.45, 0.95])
93 #deltas = np.array([0.05]) 126 deltas = np.array([0.05])
94 #rhos = np.array([0.05]) 127 rhos = np.array([0.05])
95 #delta = 0.8; 128 #delta = 0.8;
96 #rho = 0.15; 129 #rho = 0.15;
97 numvects = 100; # Number of vectors to generate 130 numvects = 100; # Number of vectors to generate
98 SNRdb = 20.; # This is norm(signal)/norm(noise), so power, not energy 131 SNRdb = 20.; # This is norm(signal)/norm(noise), so power, not energy
99 # Values for lambda 132 # Values for lambda
100 #lambdas = [0 10.^linspace(-5, 4, 10)]; 133 #lambdas = [0 10.^linspace(-5, 4, 10)];
101 #lambdas = np.concatenate((np.array([0]), 10**np.linspace(-5, 4, 10))) 134 #lambdas = np.concatenate((np.array([0]), 10**np.linspace(-5, 4, 10)))
102 lambdas = np.array([0., 0.0001, 0.01, 1, 100, 10000]) 135 lambdas = np.array([0., 0.0001, 0.01, 1, 100, 10000])
103 136
104 dosavedata = True 137 dosavedata = True
105 savedataname = 'ABSapprox.mat' 138 savedataname = 'approx_pt_std1.mat'
139
140 doshowplot = False
141 dosaveplot = True
142 saveplotbase = 'approx_pt_std1_'
143 saveplotexts = ('png','pdf','eps')
106 144
107 145
108 return d,sigma,deltas,rhos,lambdas,numvects,SNRdb,dosavedata,savedataname 146 return d,sigma,deltas,rhos,lambdas,numvects,SNRdb,dosavedata,savedataname,\
147 doshowplot,dosaveplot,saveplotbase,saveplotexts
109 148
110 #========================== 149 #==========================
111 # Main functions 150 # Main functions
112 #========================== 151 #==========================
113 def run_multi(algosN, algosL, d, sigma, deltas, rhos, lambdas, numvects, SNRdb, 152 def run_multi(algosN, algosL, d, sigma, deltas, rhos, lambdas, numvects, SNRdb,
114 doparallel=False, ncpus=None,\ 153 doparallel=False, ncpus=None,\
115 doshowplot=False, dosaveplot=False, saveplotbase=None, saveplotexts=None,\ 154 doshowplot=False, dosaveplot=False, saveplotbase=None, saveplotexts=None,\
116 dosavedata=False, savedataname=None): 155 dosavedata=False, savedataname=None):
156
157 print "This is analysis recovery ABS approximation script by Nic"
158 print "Running phase transition ( run_multi() )"
117 159
118 if doparallel: 160 if doparallel:
119 from multiprocessing import Pool 161 from multiprocessing import Pool
120 162
121 # TODO: load different engine for matplotlib that allows saving without showing 163 if dosaveplot or doshowplot:
122 try: 164 try:
123 import matplotlib.pyplot as plt 165 import matplotlib
124 except: 166 if doshowplot:
125 dosaveplot = False 167 print "Importing matplotlib with default (GUI) backend... ",
126 doshowplot = False 168 else:
127 if dosaveplot and doshowplot: 169 print "Importing matplotlib with \"Cairo\" backend... ",
128 import matplotlib.cm as cm 170 matplotlib.use('Cairo')
171 import matplotlib.pyplot as plt
172 import matplotlib.cm as cm
173 print "OK"
174 except:
175 print "FAIL"
176 print "Importing matplotlib.pyplot failed. No figures at all"
177 print "Try selecting a different backend"
178 doshowplot = False
179 dosaveplot = False
180
181 # Print summary of parameters
182 print "Parameters:"
183 if doparallel:
184 if ncpus is None:
185 print " Running in parallel with default threads using \"multiprocessing\" package"
186 else:
187 print " Running in parallel with",ncpus,"threads using \"multiprocessing\" package"
188 else:
189 print "Running single thread"
190 if doshowplot:
191 print " Showing figures"
192 else:
193 print " Not showing figures"
194 if dosaveplot:
195 print " Saving figures as "+saveplotbase+"* with extensions ",saveplotexts
196 else:
197 print " Not saving figures"
198 print " Running algorithms",[algotuple[1] for algotuple in algosN],[algotuple[1] for algotuple in algosL]
129 199
130 nalgosN = len(algosN) 200 nalgosN = len(algosN)
131 nalgosL = len(algosL) 201 nalgosL = len(algosL)
132 202
133 meanmatrix = dict() 203 meanmatrix = dict()
136 for i,algo in zip(np.arange(nalgosL),algosL): 206 for i,algo in zip(np.arange(nalgosL),algosL):
137 meanmatrix[algo[1]] = np.zeros((lambdas.size, rhos.size, deltas.size)) 207 meanmatrix[algo[1]] = np.zeros((lambdas.size, rhos.size, deltas.size))
138 208
139 # Prepare parameters 209 # Prepare parameters
140 jobparams = [] 210 jobparams = []
211 print " (delta, rho) pairs to be run:"
141 for idelta,delta in zip(np.arange(deltas.size),deltas): 212 for idelta,delta in zip(np.arange(deltas.size),deltas):
142 for irho,rho in zip(np.arange(rhos.size),rhos): 213 for irho,rho in zip(np.arange(rhos.size),rhos):
143 214
144 # Generate data and operator 215 # Generate data and operator
145 Omega,x0,y,M,realnoise = generateData(d,sigma,delta,rho,numvects,SNRdb) 216 Omega,x0,y,M,realnoise = generateData(d,sigma,delta,rho,numvects,SNRdb)
146 217
147 #Save the parameters, and run after 218 #Save the parameters, and run after
148 print "***** delta = ",delta," rho = ",rho 219 print " delta = ",delta," rho = ",rho
149 jobparams.append((algosN,algosL, Omega,y,lambdas,realnoise,M,x0)) 220 jobparams.append((algosN,algosL, Omega,y,lambdas,realnoise,M,x0))
150 221
222 print "End of parameters"
223
151 # Run 224 # Run
152 jobresults = [] 225 jobresults = []
153 if doparallel: 226 if doparallel:
154 pool = Pool(4) 227 pool = Pool(4)
155 jobresults = pool.map(run_once_tuple,jobparams) 228 jobresults = pool.map(run_once_tuple,jobparams)
209 for ilbd in np.arange(lambdas.size): 282 for ilbd in np.arange(lambdas.size):
210 plt.figure() 283 plt.figure()
211 plt.imshow(meanmatrix[algoname][ilbd], cmap=cm.gray, interpolation='nearest',origin='lower') 284 plt.imshow(meanmatrix[algoname][ilbd], cmap=cm.gray, interpolation='nearest',origin='lower')
212 if dosaveplot: 285 if dosaveplot:
213 for ext in saveplotexts: 286 for ext in saveplotexts:
214 plt.savefig(saveplotbase + algoname + lambdas[ilbd] + '.' + ext) 287 plt.savefig(saveplotbase + algoname + ('_lbd%.0e' % lambdas[ilbd]) + '.' + ext)
215 if doshowplot: 288 if doshowplot:
216 plt.show() 289 plt.show()
217 290
218 print "Finished." 291 print "Finished."
219 292