Mercurial > hg > absrec
comparison test_exact_v2.py @ 15:a27cfe83fe12
Changing, changing, trying to get a common framework for batch jobs
author | Nic Cleju <nikcleju@gmail.com> |
---|---|
date | Tue, 20 Mar 2012 17:18:23 +0200 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
14:f2eb027ed101 | 15:a27cfe83fe12 |
---|---|
1 # -*- coding: utf-8 -*- | |
2 """ | |
3 Created on Sat Nov 05 18:08:40 2011 | |
4 | |
5 @author: Nic | |
6 """ | |
7 | |
8 import numpy | |
9 import scipy.io | |
10 import math | |
11 import os | |
12 import time | |
13 | |
# Choose a matplotlib backend before pyplot is imported: the default (GUI)
# backend on Windows, the non-interactive "Cairo" backend elsewhere
# (presumably so the script also runs on machines without a display).
try:
    import matplotlib
    if os.name == 'nt':
        print("Importing matplotlib with default (GUI) backend... ")
    else:
        print("Importing matplotlib with \"Cairo\" backend... ")
        matplotlib.use('Cairo')
    import matplotlib.pyplot as plt
    import matplotlib.cm as cm
    import matplotlib.colors as mcolors
except Exception as exc:
    # Plotting is optional, so a failure here is reported and the script goes
    # on.  ``Exception`` (not a bare ``except:``) so that Ctrl-C / sys.exit()
    # during the import are not swallowed.
    print("FAIL")
    print("Importing matplotlib.pyplot failed. No figures at all")
    print("Reason: %s" % exc)
    print("Try selecting a different backend")
28 | |
29 import multiprocessing | |
30 import sys | |
# Handle on this very module.  initProcess(), run_once_tuple() and run_once()
# read and write the shared state below through it.
currmodule = sys.modules[__name__]

# Lock serialising console output between worker processes.  It stays None
# unless run() creates one for a parallel run.
printLock = None

# Process-shared counter of finished tasks.  Typecode 'I' is an unsigned int
# (see the multiprocessing / array documentation).
proccount = multiprocessing.Value('I', 0)
36 | |
37 import stdparams_exact | |
38 import AnalysisGenerate | |
39 | |
40 # For exceptions | |
41 import pyCSalgos.BP.l1eq_pd | |
42 import pyCSalgos.NESTA.NESTA | |
43 | |
44 | |
def initProcess(share, ntasks, printLock):
    """
    Pool initializer function (multiprocessing).

    Runs once in every worker process.  It publishes the shared objects as
    module-level globals, which is where run_once_tuple() and run_once()
    look them up (through ``currmodule``).
    see http://stackoverflow.com/questions/1675766/how-to-combine-pool-map-with-array-shared-memory-in-python-multiprocessing

    Parameters
    ----------
    share : multiprocessing.Value
        Shared counter of finished tasks; stored as ``proccount``.
    ntasks : int
        Total number of tasks; stored as ``ntasks``.
    printLock : multiprocessing.Lock
        Lock serialising console output; stored as ``printLock``.
    """
    currmodule = sys.modules[__name__]
    currmodule.proccount = share
    currmodule.ntasks = ntasks
    # The task functions read ``currmodule.printLock``.  Storing the lock only
    # under ``_printLock`` left ``printLock`` at its import-time value (None)
    # in workers that re-import this module (e.g. the "spawn" start method
    # used on Windows), so their output was never synchronised.
    currmodule.printLock = printLock
    # Old name kept for backward compatibility.
    currmodule._printLock = printLock
56 | |
57 | |
def generateTaskParams(globalparams):
    """
    Build the list of simulation tasks, one per (delta, rho) grid point.

    The grid is walked delta-major: the outer loop runs over
    ``globalparams['deltas']`` and the inner loop over
    ``globalparams['rhos']``.  processResults() consumes the results in that
    same order.  For every grid point a new analysis operator and a new data
    set are drawn with AnalysisGenerate (``numvects`` is passed through to the
    data generator).

    Returns a list of ``(algos, Omega, y, M, x0)`` tuples, i.e. the positional
    arguments of run_once().
    """
    SNRdb = globalparams['SNRdb']
    sigma = globalparams['sigma']
    d = globalparams['d']
    deltas = globalparams['deltas']
    rhos = globalparams['rhos']
    numvects = globalparams['numvects']
    algos = globalparams['algos']

    # Noise level corresponding to the requested SNR (given in dB)
    noiselevel = 1.0 / (10.0 ** (SNRdb / 10.0))
    # p depends only on sigma and d, m only on delta: compute each once
    p = round(sigma * d)

    tasks = []
    for delta in deltas:
        m = round(delta * d)
        for rho in rhos:
            l = round(d - rho * m)

            # Fresh operator for every grid point.
            # Disabled experiment - make Omega more coherent via its SVD:
            #   U,S,Vt = numpy.linalg.svd(Omega)
            #   Sdnew = S * (1+numpy.arange(S.size))  # Make D coherent, not Omega!
            #   Snew = numpy.vstack((numpy.diag(Sdnew), numpy.zeros((Omega.shape[0] - Omega.shape[1], Omega.shape[1]))))
            #   Omega = numpy.dot(U , numpy.dot(Snew,Vt))
            Omega = AnalysisGenerate.Generate_Analysis_Operator(d, p)

            x0, y, M, _Lambda, _realnoise = AnalysisGenerate.Generate_Data_Known_Omega(
                Omega, d, p, m, l, noiselevel, numvects, 'l0')
            tasks.append((algos, Omega, y, M, x0))

    return tasks
95 | |
def processResults(params, taskresults):
    """
    Collect the raw per-task results into one grid per algorithm.

    ``taskresults`` must be in the order produced by generateTaskParams()
    (delta outer, rho inner).  Each entry is an ``(mrelerr, elapsed)`` pair of
    dicts keyed by algorithm name, as returned by run_once().

    Returns a dict with:
      'meanmatrix' : {algorithm name: array of shape (rhos.size, deltas.size)},
                     with negative or NaN entries replaced by 0
      'elapsed'    : {algorithm name: running time summed over all tasks}
    """
    deltas = params['deltas']
    rhos = params['rhos']
    algos = params['algos']

    names = [algotuple[1] for algotuple in algos]
    meanmatrix = {name: numpy.zeros((rhos.size, deltas.size)) for name in names}
    elapsed = {name: 0 for name in names}

    # Tasks were generated delta-major, so walk the results in the same order
    position = 0
    for idelta, _delta in enumerate(deltas):
        for irho, _rho in enumerate(rhos):
            taskrelerr, taskelapsed = taskresults[position]
            position += 1
            for name in names:
                grid = meanmatrix[name]
                grid[irho, idelta] = taskrelerr[name]
                value = grid[irho, idelta]
                # Invalid scores (negative or NaN) are clamped to zero
                if value < 0 or math.isnan(value):
                    grid[irho, idelta] = 0
                elapsed[name] = elapsed[name] + taskelapsed[name]

    return {'meanmatrix': meanmatrix, 'elapsed': elapsed}
127 | |
def saveSim(params, procresults):
    """
    Save the simulation parameters and processed results to a .mat file.

    The target file is ``params['savedataname']``.  Algorithm names
    (``algotuple[1]`` of every entry in ``params['algos']``) are stored as a
    MATLAB cell array under 'algonames'.  Saving is best-effort: a failure is
    reported on stdout and does not raise.
    """
    tosave = {
        'meanmatrix': procresults['meanmatrix'],
        'elapsed': procresults['elapsed'],
    }
    for key in ('d', 'sigma', 'deltas', 'rhos', 'numvects', 'SNRdb',
                'saveplotbase', 'saveplotexts'):
        tosave[key] = params[key]

    # savemat writes an object array as a cell array.  Use the builtin
    # ``object``: the ``numpy.object`` alias was deprecated in NumPy 1.20 and
    # removed in NumPy 1.24.
    algonames = numpy.zeros((len(params['algos']),), dtype=object)
    for idx, algotuple in enumerate(params['algos']):
        algonames[idx] = algotuple[1]
    tosave['algonames'] = algonames

    try:
        scipy.io.savemat(params['savedataname'], tosave)
    except Exception as e:
        # Keep the best-effort behaviour, but say why saving failed instead of
        # silently swallowing everything (including Ctrl-C) as a bare except did.
        print("Save error: %s" % e)
157 | |
def loadSim(savedataname):
    """
    Load a simulation written by saveSim() from a .mat file.

    Returns ``(params, procresults)``:
      params      : 'd', 'sigma', 'deltas', 'rhos', 'numvects', 'SNRdb',
                    'saveplotexts' as returned by scipy.io.loadmat,
                    'saveplotbase' as a string, and 'algonames' as a list of
                    strings
      procresults : 'meanmatrix' (a record indexable by algorithm name) and
                    'elapsed'
    """
    mdict = scipy.io.loadmat(savedataname)

    procresults = {
        # The dict saved by saveSim comes back as a 1x1 struct array
        'meanmatrix': mdict['meanmatrix'][0, 0],
        'elapsed': mdict['elapsed'],
    }

    params = {}
    for key in ('d', 'sigma', 'deltas', 'rhos', 'numvects', 'SNRdb'):
        params[key] = mdict[key]
    params['saveplotbase'] = mdict['saveplotbase'][0]
    params['saveplotexts'] = mdict['saveplotexts']

    # The algorithm-name cell array is Nx1 or 1xN depending on the ``oned_as``
    # setting of the scipy version that wrote it (older scipy defaulted to
    # columns, current scipy to rows).  ravel() handles both; indexing with
    # [:, 0] silently kept only the first name of a row array.
    params['algonames'] = [algoname[0] for algoname in mdict['algonames'].ravel()]

    return params, procresults
184 | |
def plot(savedataname):
    """
    Plot the results stored in a .mat file written by saveSim().

    For every algorithm, the result grid (rows follow rho, columns follow
    delta, as built by processResults) is drawn as a grayscale image with the
    colour scale fixed to [0, 1].  It is saved as
    ``saveplotbase + algoname + '.' + ext`` for every extension in
    ``saveplotexts``.
    """
    params, procresults = loadSim(savedataname)
    meanmatrix = procresults['meanmatrix']
    saveplotbase = params['saveplotbase']
    saveplotexts = params['saveplotexts']

    for algoname in params['algonames']:
        fig = plt.figure()
        try:
            plt.imshow(meanmatrix[algoname], cmap=cm.gray, norm=mcolors.Normalize(0, 1),
                       interpolation='nearest', origin='lower')
            for ext in saveplotexts:
                plt.savefig(saveplotbase + algoname + '.' + ext, bbox_inches='tight')
        finally:
            # pyplot keeps every figure alive until it is closed; without this
            # each algorithm leaked one figure.
            plt.close(fig)
201 | |
#==========================
# Main function
#==========================
def run(params):
    """
    Run a full phase-transition simulation with the given parameters.

    Generates one task per (delta, rho) grid point and runs the tasks either
    serially or on a multiprocessing Pool.  It then processes the results
    with processResults() and saves them with saveSim().

    ``params`` keys read here:
      'algos'        : list of (function, name) pairs
      'ncpus'        : None -> Pool with the default number of workers
                       (serial if only one CPU), 1 -> serial,
                       >1 -> Pool with that many workers
      'savedataname' : output file, checked up front
    plus the keys read by generateTaskParams(), processResults() and saveSim().
    """
    print("This is analysis recovery ABS approximation script by Nic")
    print("Running phase transition ( run_multi() )")

    algos = params['algos']
    ncpus = params['ncpus']
    # Looked up now so that a missing output name fails before the (long)
    # simulation instead of only being reported by saveSim() at the end.
    savedataname = params['savedataname']

    if ncpus is None:
        print(" Running in parallel with default %s threads using \"multiprocessing\" package"
              % multiprocessing.cpu_count())
        doparallel = multiprocessing.cpu_count() > 1
    elif ncpus > 1:
        print(" Running in parallel with %s threads using \"multiprocessing\" package" % ncpus)
        doparallel = True
    elif ncpus == 1:
        print("Running single thread")
        doparallel = False
    else:
        print("Wrong number of threads, exiting")
        return

    # Print summary of parameters
    print("Parameters:")
    print(" Running algorithms %s" % ([algotuple[1] for algotuple in algos],))
    print(" Saving results to %s" % (savedataname,))

    taskparams = generateTaskParams(params)

    # Shared state read by run_once_tuple().  The finished-task counter is a
    # module-level object, so restart it; otherwise a second run() in the
    # same process would continue counting from where the previous one ended.
    currmodule.ntasks = len(taskparams)
    currmodule.proccount.value = 0

    if doparallel:
        currmodule.printLock = multiprocessing.Lock()
        pool = multiprocessing.Pool(ncpus, initializer=initProcess,
                                    initargs=(currmodule.proccount, currmodule.ntasks,
                                              currmodule.printLock))
        try:
            taskresults = pool.map(run_once_tuple, taskparams)
        finally:
            # map() has returned (or raised), so no work is pending: stop the
            # workers explicitly instead of relying on garbage collection.
            pool.terminate()
            pool.join()
    else:
        taskresults = [run_once_tuple(taskparam) for taskparam in taskparams]

    procresults = processResults(params, taskresults)
    saveSim(params, procresults)

    print("Finished.")
266 | |
def run_once_tuple(t):
    """
    Run one task given as an argument tuple, then report progress.

    ``t`` is one entry of generateTaskParams(), i.e. the positional arguments
    of run_once().  After the task finishes, the shared finished-task counter
    ``proccount`` is incremented and a progress banner is printed.  When a
    print lock exists (parallel runs), both happen while holding it, so the
    counter update is consistent and output from different workers does not
    interleave.

    Returns whatever run_once() returns.
    """
    results = run_once(*t)

    lock = currmodule.printLock
    if lock:
        # ``with`` releases the lock even if printing raises; a lock left
        # held would block every other worker.
        with lock:
            _report_task_finished()
    else:
        # Serial run: no lock exists (calling release() on None would crash)
        _report_task_finished()

    return results


def _report_task_finished():
    """Increment the shared finished-task counter and print a progress banner."""
    currmodule.proccount.value = currmodule.proccount.value + 1
    done = currmodule.proccount.value
    total = currmodule.ntasks
    print("================================")
    print("Finished task %s / %s tasks remaining %s / %s" % (done, total, total - done, total))
    print("================================")
281 | |
def run_once(algos, Omega, y, M, x0, exactthr=1e-6):
    """
    Run a single task: reconstruct every signal with every algorithm.

    Parameters
    ----------
    algos : sequence of (callable, str)
        ``(func, name)`` pairs; ``func(y[:, i], M, Omega)`` must return the
        reconstructed signal for measurement column ``i``.
    Omega : 2-D array
        Analysis operator; its number of columns is the signal dimension.
    y : 2-D array
        Measurements, one signal per column.
    M : measurement matrix, passed unchanged to every algorithm.
    x0 : 2-D array
        Original signals, one per column, in the same order as ``y``.
    exactthr : float, optional
        A reconstruction counts as exact when its relative error
        ``||x0_i - xrec_i|| / ||x0_i||`` is below this value (default 1e-6).

    Returns
    -------
    (mrelerr, elapsed) : pair of dicts keyed by algorithm name
        ``mrelerr[name]`` is the fraction of signals recovered exactly;
        ``elapsed[name]`` is the total wall-clock time spent in the algorithm.

    If an algorithm raises ``pyCSalgos.BP.l1eq_pd.l1eqNotImplementedError``,
    the exception is reported and that signal's estimate stays at zero.
    """
    d = Omega.shape[1]
    nvects = y.shape[1]

    def _locked_print(message):
        # Serialise output with the other workers when a print lock exists
        lock = currmodule.printLock
        if lock:
            with lock:
                print(message)
        else:
            print(message)

    names = [algotuple[1] for algotuple in algos]
    xrec = {name: numpy.zeros((d, nvects)) for name in names}
    relerr = {name: numpy.zeros(nvects) for name in names}
    elapsed = {name: 0 for name in names}

    for iy in range(nvects):
        for algofunc, strname in algos:
            try:
                timestart = time.time()
                xrec[strname][:, iy] = algofunc(y[:, iy], M, Omega)
                elapsed[strname] = elapsed[strname] + (time.time() - timestart)
            except pyCSalgos.BP.l1eq_pd.l1eqNotImplementedError as e:
                # str(e) instead of e.message (deprecated, removed in Python 3)
                _locked_print("Caught exception when running algorithm %s  : %s" % (strname, e))
            relerr[strname][iy] = (numpy.linalg.norm(x0[:, iy] - xrec[strname][:, iy])
                                   / numpy.linalg.norm(x0[:, iy]))

    for name in names:
        _locked_print("%s  : avg relative error =  %s" % (name, numpy.mean(relerr[name])))

    # Score = fraction of exact reconstructions, not the average error
    mrelerr = {}
    for name in names:
        mrelerr[name] = float(numpy.count_nonzero(relerr[name] < exactthr)) / nvects
    return mrelerr, elapsed
335 | |
336 | |
def testMatlab(matfile="E:\\CS\\Ale mele\\Analysis_ExactRec\\temp.mat"):
    """
    Debug helper: run the std1 algorithms on one task stored in a .mat file.

    The file must contain the variables 'Omega', 'y', 'M' and 'x0' (the data
    arguments of run_once()).  The default path is the author's original
    hard-coded location.  Returns the ``(mrelerr, elapsed)`` result of
    run_once().
    """
    mdict = scipy.io.loadmat(matfile)
    algos = stdparams_exact.std1()[0]
    omega = mdict['Omega']
    # Swap the bytes and flip the dtype's byte-order flag: the values are
    # unchanged, only their memory representation is.  This is the
    # equivalent of ``.byteswap().newbyteorder()``; ndarray.newbyteorder()
    # was removed in NumPy 2.0, while the dtype-based view works everywhere.
    omega = omega.byteswap().view(omega.dtype.newbyteorder())
    return run_once(algos, omega, mdict['y'], mdict['M'], mdict['x0'])
341 | |
def generateFig():
    """
    Run the simulation configured by ``stdparams_exact.std1``.

    NOTE(review): run() indexes its argument like a dict (params['algos'],
    params['ncpus'], ...), while testMatlab() calls stdparams_exact.std1()
    as a function.  Passing the std1 object itself here therefore looks
    stale after the switch to dict-based parameters -- confirm against
    stdparams_exact before relying on this helper.
    """
    std_params = stdparams_exact.std1
    return run(std_params)
344 | |
# Script main
if __name__ == "__main__":
    # Smoke run of the "paramstest" configuration in a single process, then
    # plot what it saved.  ncpus is overridden directly on the shared
    # stdparams_exact.paramstest dict.
    testparams = stdparams_exact.paramstest
    testparams['ncpus'] = 1
    run(testparams)
    plot(testparams['savedataname'])