nikcleju@14
|
1 # -*- coding: utf-8 -*-
|
nikcleju@14
|
2 """
|
nikcleju@17
|
3 Old version, ignore it
|
nikcleju@14
|
4
|
nikcleju@17
|
5 Author: Nicolae Cleju
|
nikcleju@14
|
6 """
|
nikcleju@14
|
7
|
nikcleju@14
|
8 import numpy
|
nikcleju@14
|
9 import scipy.io
|
nikcleju@14
|
10 import math
|
nikcleju@14
|
11 import os
|
nikcleju@14
|
12 import time
|
nikcleju@14
|
13
|
nikcleju@14
|
14 import multiprocessing
|
nikcleju@14
|
15 import sys
|
nikcleju@14
|
_currmodule = sys.modules[__name__]
# Process-shared lock that keeps console output from pool workers from
# interleaving. It stays None until run_multi() starts a parallel run; worker
# processes receive it through _initProcess().
_printLock = None
|
nikcleju@14
|
19
|
nikcleju@14
|
20 import stdparams_exact
|
nikcleju@14
|
21 import AnalysisGenerate
|
nikcleju@14
|
22
|
nikcleju@14
|
23 # For exceptions
|
nikcleju@14
|
24 import pyCSalgos.BP.l1eq_pd
|
nikcleju@14
|
25 import pyCSalgos.NESTA.NESTA
|
nikcleju@14
|
26
|
nikcleju@14
|
27
|
nikcleju@14
|
28 def _initProcess(share, njobs, printLock):
|
nikcleju@14
|
29 """
|
nikcleju@14
|
30 Pool initializer function (multiprocessing)
|
nikcleju@14
|
31 Needed to pass the shared variable to the worker processes
|
nikcleju@14
|
32 The variables must be global in the module in order to be seen later in run_once_tuple()
|
nikcleju@14
|
33 see http://stackoverflow.com/questions/1675766/how-to-combine-pool-map-with-array-shared-memory-in-python-multiprocessing
|
nikcleju@14
|
34 """
|
nikcleju@14
|
35 currmodule = sys.modules[__name__]
|
nikcleju@14
|
36 currmodule.proccount = share
|
nikcleju@14
|
37 currmodule.njobs = njobs
|
nikcleju@14
|
38 currmodule._printLock = printLock
|
nikcleju@14
|
39
|
nikcleju@14
|
40 #==========================
|
nikcleju@14
|
41 # Interface run functions
|
nikcleju@14
|
42 #==========================
|
nikcleju@14
|
def run(std=stdparams_exact.std1, ncpus=None):
    """
    Run the phase-transition experiment described by a parameter function.

    std   -- zero-argument callable returning exactly the 13-tuple
             (algos, d, sigma, deltas, rhos, numvects, SNRdb,
              dosavedata, savedataname,
              doshowplot, dosaveplot, saveplotbase, saveplotexts)
    ncpus -- number of worker processes; None lets run_multi() pick a default
    """
    (algos, d, sigma, deltas, rhos, numvects, SNRdb,
     dosavedata, savedataname,
     doshowplot, dosaveplot, saveplotbase, saveplotexts) = std()

    run_multi(algos, d, sigma, deltas, rhos, numvects, SNRdb,
              ncpus=ncpus,
              doshowplot=doshowplot,
              dosaveplot=dosaveplot,
              saveplotbase=saveplotbase,
              saveplotexts=saveplotexts,
              dosavedata=dosavedata,
              savedataname=savedataname)
|
nikcleju@14
|
49
|
nikcleju@14
|
50 #==========================
|
nikcleju@14
|
51 # Main functions
|
nikcleju@14
|
52 #==========================
|
nikcleju@14
|
def run_multi(algos, d, sigma, deltas, rhos, numvects, SNRdb, ncpus=None,
              doshowplot=False, dosaveplot=False, saveplotbase=None, saveplotexts=None,
              dosavedata=False, savedataname=None):
    """
    Run a phase-transition experiment over a grid of (delta, rho) values.

    For every (delta, rho) pair a problem instance is generated with
    generateData(), and all algorithms are run on it through run_once_tuple(),
    either sequentially or in a multiprocessing.Pool. For each algorithm, the
    per-point value returned by run_once() is stored in a
    (rhos.size, deltas.size) matrix. Negative or NaN values are clamped to 0.

    algos         -- sequence of (function, name) pairs
    d, sigma      -- problem size parameters forwarded to generateData()
    deltas, rhos  -- grid values (numpy arrays; `.size` is used)
    numvects      -- number of signals per grid point, forwarded to generateData()
    SNRdb         -- SNR in dB, forwarded to generateData()
    ncpus         -- worker processes: None = multiprocessing.cpu_count(),
                     1 = sequential, values below 1 abort with a message
    doshowplot    -- show one grayscale image per algorithm
    dosaveplot    -- save those images as saveplotbase + name + '.' + ext
                     for every ext in saveplotexts
    dosavedata    -- save the results with scipy.io.savemat(savedataname, ...)

    Returns None; results are printed, saved and/or plotted.
    """
    print("This is analysis recovery ABS approximation script by Nic")
    print("Running phase transition ( run_multi() )")

    # Decide between sequential and parallel execution
    if ncpus is None:
        if multiprocessing.cpu_count() == 1:
            print("Running single thread")
            doparallel = False
        else:
            print(" Running in parallel with default %s threads using \"multiprocessing\" package"
                  % multiprocessing.cpu_count())
            doparallel = True
    elif ncpus > 1:
        print(" Running in parallel with %s threads using \"multiprocessing\" package" % ncpus)
        doparallel = True
    elif ncpus == 1:
        print("Running single thread")
        doparallel = False
    else:
        print("Wrong number of threads, exiting")
        return

    if dosaveplot or doshowplot:
        try:
            import matplotlib
            if doshowplot or os.name == 'nt':
                sys.stdout.write("Importing matplotlib with default (GUI) backend... ")
            else:
                sys.stdout.write("Importing matplotlib with \"Cairo\" backend... ")
                # Figures are only saved here, so a non-GUI backend is enough
                matplotlib.use('Cairo')
            import matplotlib.pyplot as plt
            import matplotlib.cm as cm
            import matplotlib.colors as mcolors
            print("OK")
        except Exception as e:
            print("FAIL")
            print("Importing matplotlib.pyplot failed (%s). No figures at all" % e)
            print("Try selecting a different backend")
            doshowplot = False
            dosaveplot = False

    algonames = [algotuple[1] for algotuple in algos]

    # Print summary of parameters
    print("Parameters:")
    if doshowplot:
        print(" Showing figures")
    else:
        print(" Not showing figures")
    if dosaveplot:
        # %s formatting also copes with saveplotbase=None
        print(" Saving figures as %s* with extensions %s" % (saveplotbase, saveplotexts))
    else:
        print(" Not saving figures")
    print(" Running algorithms %s" % (algonames,))

    meanmatrix = dict()
    elapsed = dict()
    for name in algonames:
        meanmatrix[name] = numpy.zeros((rhos.size, deltas.size))
        elapsed[name] = 0

    # Generate all problem instances up front; they are the job parameters.
    # Order (delta outer, rho inner) must match the result-reading loop below.
    jobparams = []
    print(" (delta, rho) pairs to be run:")
    for delta in deltas:
        for rho in rhos:
            Omega, x0, y, M, _realnoise = generateData(d, sigma, delta, rho, numvects, SNRdb)
            print("   delta = %s  rho = %s" % (delta, rho))
            jobparams.append((algos, Omega, y, M, x0))
    print("End of parameters")

    _currmodule.njobs = len(jobparams)
    # Process-shared counter of finished jobs ('I' = unsigned int typecode)
    _currmodule.proccount = multiprocessing.Value('I', 0)

    if doparallel:
        _currmodule._printLock = multiprocessing.Lock()
        pool = multiprocessing.Pool(ncpus, initializer=_initProcess,
                                    initargs=(_currmodule.proccount,
                                              _currmodule.njobs,
                                              _currmodule._printLock))
        try:
            jobresults = pool.map(run_once_tuple, jobparams)
        except BaseException:
            # Do not leave workers running after a failure or Ctrl-C
            pool.terminate()
            raise
        else:
            pool.close()
        finally:
            pool.join()
    else:
        jobresults = [run_once_tuple(jobparam) for jobparam in jobparams]

    # Read results back in the same (delta outer, rho inner) order
    results = iter(jobresults)
    for idelta in range(deltas.size):
        for irho in range(rhos.size):
            mrelerr, addelapsed = next(results)
            for name in algonames:
                meanmatrix[name][irho, idelta] = mrelerr[name]
                value = meanmatrix[name][irho, idelta]
                if value < 0 or math.isnan(value):
                    meanmatrix[name][irho, idelta] = 0
                elapsed[name] = elapsed[name] + addelapsed[name]

    if dosavedata:
        tosave = dict()
        tosave['meanmatrix'] = meanmatrix
        tosave['elapsed'] = elapsed
        tosave['d'] = d
        tosave['sigma'] = sigma
        tosave['deltas'] = deltas
        tosave['rhos'] = rhos
        tosave['numvects'] = numvects
        tosave['SNRdb'] = SNRdb
        # Object array so the algorithm names are saved as a cell array
        # (numpy.object was removed in NumPy 1.24; the builtin is equivalent)
        obj_arr = numpy.zeros((len(algonames),), dtype=object)
        for idx, name in enumerate(algonames):
            obj_arr[idx] = name
        tosave['algonames'] = obj_arr
        try:
            scipy.io.savemat(savedataname, tosave)
        except Exception as e:
            print("Save error: %s" % e)

    if doshowplot or dosaveplot:
        for name in algonames:
            plt.figure()
            plt.imshow(meanmatrix[name], cmap=cm.gray, norm=mcolors.Normalize(0, 1),
                       interpolation='nearest', origin='lower')
            if dosaveplot:
                for ext in saveplotexts:
                    plt.savefig(saveplotbase + name + '.' + ext, bbox_inches='tight')
            if doshowplot:
                plt.show()

    print("Finished.")
|
nikcleju@14
|
192
|
nikcleju@14
|
def run_once_tuple(t):
    """
    Pool-friendly wrapper around run_once().

    Unpacks the argument tuple, runs the job, then increments the shared
    finished-jobs counter and prints a progress report. When a print lock is
    installed (parallel runs), the counter update and the report happen under
    it, and the lock is always released, even if printing fails. Without a
    lock (sequential runs), progress is still counted and reported.

    t -- tuple (algos, Omega, y, M, x0), as built by run_multi()

    Returns whatever run_once() returns.
    """
    results = run_once(*t)

    lock = _currmodule._printLock
    if lock:
        lock.acquire()
    try:
        # In parallel runs this read-modify-write is serialized by the lock
        _currmodule.proccount.value = _currmodule.proccount.value + 1
        ndone = _currmodule.proccount.value
        ntotal = _currmodule.njobs
        print("================================")
        print("Finished job %s / %s jobs remaining %s / %s"
              % (ndone, ntotal, ntotal - ndone, ntotal))
        print("================================")
    finally:
        if lock:
            lock.release()

    return results
|
nikcleju@14
|
207
|
nikcleju@14
|
def run_once(algos, Omega, y, M, x0, exactthr=1e-6):
    """
    Run every algorithm on every signal of one problem instance.

    algos    -- sequence of (function, name) pairs; each function is called as
                function(y[:, i], M, Omega) and its result is stored as the
                reconstruction of column i
    Omega    -- operator whose column count is the signal length
    y        -- one input signal per column (presumably measurements)
    M        -- passed through unchanged to every algorithm (presumably the
                measurement matrix)
    x0       -- reference signals, one per column, same count as y
    exactthr -- a reconstruction counts as exact when its relative error
                ||x0 - xrec|| / ||x0|| is strictly below this threshold

    Returns (mrelerr, elapsed), two dicts keyed by algorithm name.
    mrelerr[name] is the FRACTION of signals recovered exactly; despite its
    name, it is not a mean relative error. elapsed[name] is the summed wall
    time in seconds of the successful calls.

    If an algorithm raises l1eqNotImplementedError, its reconstruction for
    that signal stays all zeros.
    """
    lock = _currmodule._printLock

    def say(text):
        # Use the shared print lock when one is installed (parallel runs) so
        # lines from different workers do not interleave; release it even if
        # printing fails.
        if lock:
            lock.acquire()
        try:
            print(text)
        finally:
            if lock:
                lock.release()

    d = Omega.shape[1]
    nsignals = y.shape[1]

    # Per-algorithm storage
    xrec = dict()
    relerr = dict()
    elapsed = dict()
    for _algofunc, strname in algos:
        xrec[strname] = numpy.zeros((d, nsignals))
        relerr[strname] = numpy.zeros(nsignals)
        elapsed[strname] = 0

    for iy in range(nsignals):
        x0norm = numpy.linalg.norm(x0[:, iy])
        for algofunc, strname in algos:
            try:
                timestart = time.time()
                xrec[strname][:, iy] = algofunc(y[:, iy], M, Omega)
                elapsed[strname] = elapsed[strname] + (time.time() - timestart)
            except pyCSalgos.BP.l1eq_pd.l1eqNotImplementedError as e:
                # .message only exists on some (Python 2) exceptions
                msg = getattr(e, 'message', '') or str(e)
                say("Caught exception when running algorithm %s  : %s" % (strname, msg))
            err = numpy.linalg.norm(x0[:, iy] - xrec[strname][:, iy])
            relerr[strname][iy] = err / x0norm

    for _algofunc, strname in algos:
        say("%s  : avg relative error =  %s" % (strname, numpy.mean(relerr[strname])))

    # Fraction of signals recovered exactly, per algorithm
    mrelerr = dict()
    for _algofunc, strname in algos:
        nexact = numpy.count_nonzero(relerr[strname] < exactthr)
        mrelerr[strname] = float(nexact) / nsignals
    return mrelerr, elapsed
|
nikcleju@14
|
258
|
nikcleju@14
|
259
|
nikcleju@14
|
def generateData(d, sigma, delta, rho, numvects, SNRdb):
    """
    Generate one random problem instance through AnalysisGenerate.

    Derived sizes, passed on to AnalysisGenerate:
      p = round(sigma * d)    -- presumably the number of rows of Omega
      m = round(delta * d)    -- presumably the number of measurements
      l = round(d - rho * m)  -- presumably the cosparsity
    noise level = 10 ** (-SNRdb / 10)

    Note: under Python 2, round() returns floats, so p, m and l are floats.

    Returns (Omega, x0, y, M, realnoise) as produced by
    AnalysisGenerate.Generate_Data_Known_Omega() (its Lambda output is dropped).
    """
    p = round(sigma * d)
    m = round(delta * d)
    l = round(d - rho * m)
    noiselevel = 1.0 / (10.0 ** (SNRdb / 10.0))

    Omega = AnalysisGenerate.Generate_Analysis_Operator(d, p)
    x0, y, M, _Lambda, realnoise = AnalysisGenerate.Generate_Data_Known_Omega(
        Omega, d, p, m, l, noiselevel, numvects, 'l0')

    return Omega, x0, y, M, realnoise
|
nikcleju@14
|
280
|
nikcleju@14
|
281
|
nikcleju@14
|
def testMatlab(matfile="E:\\CS\\Ale mele\\Analysis_ExactRec\\temp.mat"):
    """
    Debug helper: rerun run_once() on a problem instance saved from Matlab.

    matfile -- .mat file holding the variables Omega, y, M and x0
               (defaults to the author's original local path)

    Runs the algorithms of stdparams_exact.std1 and returns the
    (mrelerr, elapsed) pair produced by run_once().
    """
    mdict = scipy.io.loadmat(matfile)
    algos = stdparams_exact.std1()[0]
    Omega = mdict['Omega']
    # Swap the bytes and flip the dtype's byte order together. The values are
    # unchanged but the byte order is flipped, as the old
    # byteswap().newbyteorder() did. ndarray.newbyteorder() was removed in
    # NumPy 2.0, so go through dtype.newbyteorder() + view().
    Omega = Omega.byteswap().view(Omega.dtype.newbyteorder())
    return run_once(algos, Omega, mdict['y'], mdict['M'], mdict['x0'])
|
nikcleju@14
|
286
|
nikcleju@14
|
def generateFig():
    """Run the experiment with the stdparams_exact.std1 parameter set."""
    run(std=stdparams_exact.std1)
|
nikcleju@14
|
289
|
nikcleju@14
|
# Script entry point
if __name__ == "__main__":
    # Other entry points used during development:
    #   testMatlab()
    #   run(stdparams_exact.stdtest, ncpus=1)
    run(std=stdparams_exact.std1, ncpus=3)
|