Mercurial > hg > pycsalgos
comparison scripts/ABSapprox.py @ 30:5f46ff51c7ff
Added console output detailing parameters
Prepared for complete run
10.11.2011
author | nikcleju |
---|---|
date | Thu, 10 Nov 2011 22:43:11 +0000 |
parents | bc2a96a03b0a |
children | 829bf04c92af |
comparison
equal
deleted
inserted
replaced
29:bc2a96a03b0a | 30:5f46ff51c7ff |
---|---|
10 import math | 10 import math |
11 | 11 |
12 import pyCSalgos | 12 import pyCSalgos |
13 import pyCSalgos.GAP.GAP | 13 import pyCSalgos.GAP.GAP |
14 import pyCSalgos.SL0.SL0_approx | 14 import pyCSalgos.SL0.SL0_approx |
| | 15 import pyCSalgos.OMP.omp_QR |
| | 16 import pyCSalgos.RecomTST.RecommendedTST |
15 | 17 |
16 #========================== | 18 #========================== |
17 # Algorithm functions | 19 # Algorithm functions |
18 #========================== | 20 #========================== |
19 def run_gap(y,M,Omega,epsilon): | 21 def run_gap(y,M,Omega,epsilon): |
54 sigma_decrease_factor = 0.5 | 56 sigma_decrease_factor = 0.5 |
55 mu_0 = 2 | 57 mu_0 = 2 |
56 L = 10 | 58 L = 10 |
57 return pyCSalgos.SL0.SL0_approx.SL0_approx(aggD,aggy,epsilon,sigmamin,sigma_decrease_factor,mu_0,L) | 59 return pyCSalgos.SL0.SL0_approx.SL0_approx(aggD,aggy,epsilon,sigmamin,sigma_decrease_factor,mu_0,L) |
58 | 60 |
| | 61 def run_ompeps(y,M,Omega,D,U,S,Vt,epsilon,lbd): |
| | 62 |
| | 63 N,n = Omega.shape |
| | 64 #D = np.linalg.pinv(Omega) |
| | 65 #U,S,Vt = np.linalg.svd(D) |
| | 66 aggDupper = np.dot(M,D) |
| | 67 aggDlower = Vt[-(N-n):,:] |
| | 68 aggD = np.concatenate((aggDupper, lbd * aggDlower)) |
| | 69 aggy = np.concatenate((y, np.zeros(N-n))) |
| | 70 |
| | 71 opts = dict() |
| | 72 opts['stopCrit'] = 'mse' |
| | 73 opts['stopTol'] = epsilon**2 / aggy.size |
| | 74 return pyCSalgos.OMP.omp_QR.greed_omp_qr(aggy,aggD,aggD.shape[1],opts)[0] |
| | 75 |
| | 76 def run_tst(y,M,Omega,D,U,S,Vt,epsilon,lbd): |
| | 77 |
| | 78 N,n = Omega.shape |
| | 79 #D = np.linalg.pinv(Omega) |
| | 80 #U,S,Vt = np.linalg.svd(D) |
| | 81 aggDupper = np.dot(M,D) |
| | 82 aggDlower = Vt[-(N-n):,:] |
| | 83 aggD = np.concatenate((aggDupper, lbd * aggDlower)) |
| | 84 aggy = np.concatenate((y, np.zeros(N-n))) |
| | 85 |
| | 86 return pyCSalgos.RecomTST.RecommendedTST.RecommendedTST(aggD, aggy, nsweep=3000, tol=epsilon / np.linalg.norm(aggy)) |
| | 87 |
59 #========================== | 88 #========================== |
60 # Define tuples (algorithm function, name) | 89 # Define tuples (algorithm function, name) |
61 #========================== | 90 #========================== |
62 gap = (run_gap, 'GAP') | 91 gap = (run_gap, 'GAP') |
63 sl0 = (run_sl0, 'SL0_approx') | 92 sl0 = (run_sl0, 'SL0a') |
64 bp = (run_bp, 'BP') | 93 bp = (run_bp, 'BP') |
| | 94 ompeps = (run_ompeps, 'OMPeps') |
| | 95 tst = (run_tst, 'TST') |
65 | 96 |
66 # Define which algorithms to run | 97 # Define which algorithms to run |
67 # 1. Algorithms not depending on lambda | 98 # 1. Algorithms not depending on lambda |
68 algosN = gap, # tuple | 99 algosN = gap, # tuple |
69 # 2. Algorithms depending on lambda (our ABS approach) | 100 # 2. Algorithms depending on lambda (our ABS approach) |
70 algosL = sl0, # tuple | 101 algosL = sl0,bp,ompeps,tst # tuple |
71 | 102 |
72 #========================== | 103 #========================== |
73 # Interface functions | 104 # Interface functions |
74 #========================== | 105 #========================== |
75 def run_multiproc(ncpus=None): | 106 def run_multiproc(ncpus=None): |
76 d,sigma,deltas,rhos,lambdas,numvects,SNRdb,dosavedata,savedataname = standard_params() | 107 d,sigma,deltas,rhos,lambdas,numvects,SNRdb,dosavedata,savedataname,doshowplot,dosaveplot,saveplotbase,saveplotexts = standard_params() |
77 run_multi(algosN, algosL, d,sigma,deltas,rhos,lambdas,numvects,SNRdb,dosavedata=dosavedata,savedataname=savedataname,\ | 108 run_multi(algosN, algosL, d,sigma,deltas,rhos,lambdas,numvects,SNRdb,dosavedata=dosavedata,savedataname=savedataname,\ |
78 doparallel=True, ncpus=ncpus) | 109 doparallel=True, ncpus=ncpus,\ |
| | 110 doshowplot=doshowplot,dosaveplot=dosaveplot,saveplotbase=saveplotbase,saveplotexts=saveplotexts) |
79 | 111 |
80 def run(): | 112 def run(): |
81 d,sigma,deltas,rhos,lambdas,numvects,SNRdb,dosavedata,savedataname = standard_params() | 113 d,sigma,deltas,rhos,lambdas,numvects,SNRdb,dosavedata,savedataname,doshowplot,dosaveplot,saveplotbase,saveplotexts = standard_params() |
82 run_multi(algosN, algosL, d,sigma,deltas,rhos,lambdas,numvects,SNRdb,dosavedata=dosavedata,savedataname=savedataname,\ | 114 run_multi(algosN, algosL, d,sigma,deltas,rhos,lambdas,numvects,SNRdb,dosavedata=dosavedata,savedataname=savedataname,\ |
83 doparallel=False) | 115 doparallel=False,\ |
| | 116 doshowplot=doshowplot,dosaveplot=dosaveplot,saveplotbase=saveplotbase,saveplotexts=saveplotexts) |
84 | 117 |
85 def standard_params(): | 118 def standard_params(): |
86 #Set up standard experiment parameters | 119 #Set up standard experiment parameters |
87 d = 50.0; | 120 d = 50.0; |
88 sigma = 2.0 | 121 sigma = 2.0 |
89 #deltas = np.arange(0.05,1.,0.05) | 122 #deltas = np.arange(0.05,1.,0.05) |
90 #rhos = np.arange(0.05,1.,0.05) | 123 #rhos = np.arange(0.05,1.,0.05) |
91 deltas = np.array([0.05, 0.45, 0.95]) | 124 #deltas = np.array([0.05, 0.45, 0.95]) |
92 rhos = np.array([0.05, 0.45, 0.95]) | 125 #rhos = np.array([0.05, 0.45, 0.95]) |
93 #deltas = np.array([0.05]) | 126 deltas = np.array([0.05]) |
94 #rhos = np.array([0.05]) | 127 rhos = np.array([0.05]) |
95 #delta = 0.8; | 128 #delta = 0.8; |
96 #rho = 0.15; | 129 #rho = 0.15; |
97 numvects = 100; # Number of vectors to generate | 130 numvects = 100; # Number of vectors to generate |
98 SNRdb = 20.; # This is norm(signal)/norm(noise), so power, not energy | 131 SNRdb = 20.; # This is norm(signal)/norm(noise), so power, not energy |
99 # Values for lambda | 132 # Values for lambda |
100 #lambdas = [0 10.^linspace(-5, 4, 10)]; | 133 #lambdas = [0 10.^linspace(-5, 4, 10)]; |
101 #lambdas = np.concatenate((np.array([0]), 10**np.linspace(-5, 4, 10))) | 134 #lambdas = np.concatenate((np.array([0]), 10**np.linspace(-5, 4, 10))) |
102 lambdas = np.array([0., 0.0001, 0.01, 1, 100, 10000]) | 135 lambdas = np.array([0., 0.0001, 0.01, 1, 100, 10000]) |
103 | 136 |
104 dosavedata = True | 137 dosavedata = True |
105 savedataname = 'ABSapprox.mat' | 138 savedataname = 'approx_pt_std1.mat' |
| | 139 |
| | 140 doshowplot = False |
| | 141 dosaveplot = True |
| | 142 saveplotbase = 'approx_pt_std1_' |
| | 143 saveplotexts = ('png','pdf','eps') |
106 | 144 |
107 | 145 |
108 return d,sigma,deltas,rhos,lambdas,numvects,SNRdb,dosavedata,savedataname | 146 return d,sigma,deltas,rhos,lambdas,numvects,SNRdb,dosavedata,savedataname,\ |
| | 147 doshowplot,dosaveplot,saveplotbase,saveplotexts |
109 | 148 |
110 #========================== | 149 #========================== |
111 # Main functions | 150 # Main functions |
112 #========================== | 151 #========================== |
113 def run_multi(algosN, algosL, d, sigma, deltas, rhos, lambdas, numvects, SNRdb, | 152 def run_multi(algosN, algosL, d, sigma, deltas, rhos, lambdas, numvects, SNRdb, |
114 doparallel=False, ncpus=None,\ | 153 doparallel=False, ncpus=None,\ |
115 doshowplot=False, dosaveplot=False, saveplotbase=None, saveplotexts=None,\ | 154 doshowplot=False, dosaveplot=False, saveplotbase=None, saveplotexts=None,\ |
116 dosavedata=False, savedataname=None): | 155 dosavedata=False, savedataname=None): |
| | 156 |
| | 157 print "This is analysis recovery ABS approximation script by Nic" |
| | 158 print "Running phase transition ( run_multi() )" |
117 | 159 |
118 if doparallel: | 160 if doparallel: |
119 from multiprocessing import Pool | 161 from multiprocessing import Pool |
120 | 162 |
121 # TODO: load different engine for matplotlib that allows saving without showing | 163 if dosaveplot or doshowplot: |
122 try: | 164 try: |
123 import matplotlib.pyplot as plt | 165 import matplotlib |
124 except: | 166 if doshowplot: |
125 dosaveplot = False | 167 print "Importing matplotlib with default (GUI) backend... ", |
126 doshowplot = False | 168 else: |
127 if dosaveplot and doshowplot: | 169 print "Importing matplotlib with \"Cairo\" backend... ", |
128 import matplotlib.cm as cm | 170 matplotlib.use('Cairo') |
| | 171 import matplotlib.pyplot as plt |
| | 172 import matplotlib.cm as cm |
| | 173 print "OK" |
| | 174 except: |
| | 175 print "FAIL" |
| | 176 print "Importing matplotlib.pyplot failed. No figures at all" |
| | 177 print "Try selecting a different backend" |
| | 178 doshowplot = False |
| | 179 dosaveplot = False |
| | 180 |
| | 181 # Print summary of parameters |
| | 182 print "Parameters:" |
| | 183 if doparallel: |
| | 184 if ncpus is None: |
| | 185 print " Running in parallel with default threads using \"multiprocessing\" package" |
| | 186 else: |
| | 187 print " Running in parallel with",ncpus,"threads using \"multiprocessing\" package" |
| | 188 else: |
| | 189 print "Running single thread" |
| | 190 if doshowplot: |
| | 191 print " Showing figures" |
| | 192 else: |
| | 193 print " Not showing figures" |
| | 194 if dosaveplot: |
| | 195 print " Saving figures as "+saveplotbase+"* with extensions ",saveplotexts |
| | 196 else: |
| | 197 print " Not saving figures" |
| | 198 print " Running algorithms",[algotuple[1] for algotuple in algosN],[algotuple[1] for algotuple in algosL] |
129 | 199 |
130 nalgosN = len(algosN) | 200 nalgosN = len(algosN) |
131 nalgosL = len(algosL) | 201 nalgosL = len(algosL) |
132 | 202 |
133 meanmatrix = dict() | 203 meanmatrix = dict() |
136 for i,algo in zip(np.arange(nalgosL),algosL): | 206 for i,algo in zip(np.arange(nalgosL),algosL): |
137 meanmatrix[algo[1]] = np.zeros((lambdas.size, rhos.size, deltas.size)) | 207 meanmatrix[algo[1]] = np.zeros((lambdas.size, rhos.size, deltas.size)) |
138 | 208 |
139 # Prepare parameters | 209 # Prepare parameters |
140 jobparams = [] | 210 jobparams = [] |
| | 211 print " (delta, rho) pairs to be run:" |
141 for idelta,delta in zip(np.arange(deltas.size),deltas): | 212 for idelta,delta in zip(np.arange(deltas.size),deltas): |
142 for irho,rho in zip(np.arange(rhos.size),rhos): | 213 for irho,rho in zip(np.arange(rhos.size),rhos): |
143 | 214 |
144 # Generate data and operator | 215 # Generate data and operator |
145 Omega,x0,y,M,realnoise = generateData(d,sigma,delta,rho,numvects,SNRdb) | 216 Omega,x0,y,M,realnoise = generateData(d,sigma,delta,rho,numvects,SNRdb) |
146 | 217 |
147 #Save the parameters, and run after | 218 #Save the parameters, and run after |
148 print "***** delta = ",delta," rho = ",rho | 219 print " delta = ",delta," rho = ",rho |
149 jobparams.append((algosN,algosL, Omega,y,lambdas,realnoise,M,x0)) | 220 jobparams.append((algosN,algosL, Omega,y,lambdas,realnoise,M,x0)) |
150 | 221 |
| | 222 print "End of parameters" |
| | 223 |
151 # Run | 224 # Run |
152 jobresults = [] | 225 jobresults = [] |
153 if doparallel: | 226 if doparallel: |
154 pool = Pool(4) | 227 pool = Pool(4) |
155 jobresults = pool.map(run_once_tuple,jobparams) | 228 jobresults = pool.map(run_once_tuple,jobparams) |
209 for ilbd in np.arange(lambdas.size): | 282 for ilbd in np.arange(lambdas.size): |
210 plt.figure() | 283 plt.figure() |
211 plt.imshow(meanmatrix[algoname][ilbd], cmap=cm.gray, interpolation='nearest',origin='lower') | 284 plt.imshow(meanmatrix[algoname][ilbd], cmap=cm.gray, interpolation='nearest',origin='lower') |
212 if dosaveplot: | 285 if dosaveplot: |
213 for ext in saveplotexts: | 286 for ext in saveplotexts: |
214 plt.savefig(saveplotbase + algoname + lambdas[ilbd] + '.' + ext) | 287 plt.savefig(saveplotbase + algoname + ('_lbd%.0e' % lambdas[ilbd]) + '.' + ext) |
215 if doshowplot: | 288 if doshowplot: |
216 plt.show() | 289 plt.show() |
217 | 290 |
218 print "Finished." | 291 print "Finished." |
219 | 292 |