annotate test_exact_old.py @ 21:d395461b92ae tip

Lots and lots of modifications. Approximate recovery script working.
author Nic Cleju <nikcleju@gmail.com>
date Mon, 23 Apr 2012 10:54:57 +0300
parents 7fdf964f4edd
children
rev   line source
nikcleju@14 1 # -*- coding: utf-8 -*-
nikcleju@14 2 """
nikcleju@17 3 Old version, ignore it
nikcleju@14 4
nikcleju@17 5 Author: Nicolae Cleju
nikcleju@14 6 """
nikcleju@14 7
nikcleju@14 8 import numpy
nikcleju@14 9 import scipy.io
nikcleju@14 10 import math
nikcleju@14 11 import os
nikcleju@14 12 import time
nikcleju@14 13
nikcleju@14 14 import multiprocessing
nikcleju@14 15 import sys
nikcleju@14 16 _currmodule = sys.modules[__name__]
nikcleju@14 17 # Lock for printing in a thread-safe way
nikcleju@14 18 _printLock = None
nikcleju@14 19
nikcleju@14 20 import stdparams_exact
nikcleju@14 21 import AnalysisGenerate
nikcleju@14 22
nikcleju@14 23 # For exceptions
nikcleju@14 24 import pyCSalgos.BP.l1eq_pd
nikcleju@14 25 import pyCSalgos.NESTA.NESTA
nikcleju@14 26
nikcleju@14 27
nikcleju@14 28 def _initProcess(share, njobs, printLock):
nikcleju@14 29 """
nikcleju@14 30 Pool initializer function (multiprocessing)
nikcleju@14 31 Needed to pass the shared variable to the worker processes
nikcleju@14 32 The variables must be global in the module in order to be seen later in run_once_tuple()
nikcleju@14 33 see http://stackoverflow.com/questions/1675766/how-to-combine-pool-map-with-array-shared-memory-in-python-multiprocessing
nikcleju@14 34 """
nikcleju@14 35 currmodule = sys.modules[__name__]
nikcleju@14 36 currmodule.proccount = share
nikcleju@14 37 currmodule.njobs = njobs
nikcleju@14 38 currmodule._printLock = printLock
nikcleju@14 39
nikcleju@14 40 #==========================
nikcleju@14 41 # Interface run functions
nikcleju@14 42 #==========================
def run(std=stdparams_exact.std1, ncpus=None):
    """
    Run a complete phase-transition experiment from a predefined parameter set.

    std   -- callable returning the 13-tuple of experiment settings unpacked
             below (the defaults come from the stdparams_exact module)
    ncpus -- number of worker processes, forwarded unchanged to run_multi()
    """
    settings = std()
    (algos, d, sigma, deltas, rhos, numvects, SNRdb,
     dosavedata, savedataname,
     doshowplot, dosaveplot, saveplotbase, saveplotexts) = settings
    run_multi(algos, d, sigma, deltas, rhos, numvects, SNRdb,
              ncpus=ncpus,
              doshowplot=doshowplot,
              dosaveplot=dosaveplot,
              saveplotbase=saveplotbase,
              saveplotexts=saveplotexts,
              dosavedata=dosavedata,
              savedataname=savedataname)
nikcleju@14 49
nikcleju@14 50 #==========================
nikcleju@14 51 # Main functions
nikcleju@14 52 #==========================
def run_multi(algos, d, sigma, deltas, rhos, numvects, SNRdb, ncpus=None,
              doshowplot=False, dosaveplot=False, saveplotbase=None, saveplotexts=None,
              dosavedata=False, savedataname=None):
    """
    Run the exact-recovery phase transition over a (delta, rho) grid.

    For every (delta, rho) pair a fresh operator and data set are produced by
    generateData(). Every algorithm in `algos` is then run on them through
    run_once_tuple(), either serially or in a multiprocessing Pool.

    algos        -- sequence of (function, name) tuples
    d, sigma     -- forwarded to generateData()
    deltas, rhos -- numpy arrays of grid values (presumably 1-D); the result
                    grid has shape (rhos.size, deltas.size)
    numvects     -- number of signals per grid point, forwarded to generateData()
    SNRdb        -- SNR in dB, forwarded to generateData()
    ncpus        -- None: use multiprocessing's default pool size (serial if
                    the machine has a single CPU); 1: serial; >1: that many
                    worker processes; anything else: print an error and return
    doshowplot, dosaveplot, saveplotbase, saveplotexts -- plotting control;
                    figures are saved as saveplotbase + algoname + '.' + ext
    dosavedata, savedataname -- if dosavedata, results are written with
                    scipy.io.savemat(savedataname, ...)

    Each grid cell holds the value run_once() reports for that algorithm
    (fraction of exactly recovered signals); negative or NaN values are
    clamped to 0.
    """
    print("This is analysis recovery ABS approximation script by Nic")
    print("Running phase transition ( run_multi() )")

    if ncpus is None:
        print(" Running in parallel with default %d threads using \"multiprocessing\" package"
              % multiprocessing.cpu_count())
        doparallel = multiprocessing.cpu_count() > 1
    elif ncpus > 1:
        print(" Running in parallel with %s threads using \"multiprocessing\" package" % (ncpus,))
        doparallel = True
    elif ncpus == 1:
        print("Running single thread")
        doparallel = False
    else:
        print("Wrong number of threads, exiting")
        return

    if dosaveplot or doshowplot:
        try:
            import matplotlib
            if doshowplot or os.name == 'nt':
                sys.stdout.write("Importing matplotlib with default (GUI) backend... ")
            else:
                sys.stdout.write("Importing matplotlib with \"Cairo\" backend... ")
                # Non-interactive backend: allows saving figures without a display
                matplotlib.use('Cairo')
            import matplotlib.pyplot as plt
            import matplotlib.cm as cm
            import matplotlib.colors as mcolors
            print("OK")
        except Exception as exc:
            # Backend problems surface as various exception types, not only
            # ImportError; still let KeyboardInterrupt/SystemExit through.
            print("FAIL")
            print("Importing matplotlib.pyplot failed (%s). No figures at all" % (exc,))
            print("Try selecting a different backend")
            doshowplot = False
            dosaveplot = False

    algonames = [algotuple[1] for algotuple in algos]

    # Print summary of parameters
    print("Parameters:")
    if doshowplot:
        print(" Showing figures")
    else:
        print(" Not showing figures")
    if dosaveplot:
        # %-formatting instead of '+' so a None saveplotbase cannot raise TypeError
        print(" Saving figures as %s* with extensions  %s" % (saveplotbase, saveplotexts))
    else:
        print(" Not saving figures")
    print(" Running algorithms %s" % (algonames,))

    meanmatrix = dict((name, numpy.zeros((rhos.size, deltas.size))) for name in algonames)
    elapsed = dict((name, 0) for name in algonames)

    # Generate all data first; jobs are ordered delta-major, rho-minor
    jobparams = []
    print(" (delta, rho) pairs to be run:")
    for delta in deltas:
        for rho in rhos:
            Omega, x0, y, M, _ = generateData(d, sigma, delta, rho, numvects, SNRdb)
            print(" delta =  %s  rho =  %s" % (delta, rho))
            jobparams.append((algos, Omega, y, M, x0))
    print("End of parameters")

    _currmodule.njobs = len(jobparams)
    # Process-shared counter of finished jobs ('I' = unsigned int)
    _currmodule.proccount = multiprocessing.Value('I', 0)

    # Run
    if doparallel:
        _currmodule._printLock = multiprocessing.Lock()
        pool = multiprocessing.Pool(ncpus, initializer=_initProcess,
                                    initargs=(_currmodule.proccount,
                                              _currmodule.njobs,
                                              _currmodule._printLock))
        try:
            jobresults = pool.map(run_once_tuple, jobparams)
            pool.close()
        except BaseException:
            # Do not leave worker processes running after a failure or Ctrl-C
            pool.terminate()
            raise
        finally:
            pool.join()
    else:
        jobresults = [run_once_tuple(jobparam) for jobparam in jobparams]

    # Read results; Pool.map preserves order, so this matches jobparams order
    jobiter = iter(jobresults)
    for idelta in range(deltas.size):
        for irho in range(rhos.size):
            mrelerr, addelapsed = next(jobiter)
            for name in algonames:
                cell = meanmatrix[name]
                cell[irho, idelta] = mrelerr[name]
                if cell[irho, idelta] < 0 or math.isnan(cell[irho, idelta]):
                    cell[irho, idelta] = 0
                elapsed[name] = elapsed[name] + addelapsed[name]

    # Save
    if dosavedata:
        tosave = dict()
        tosave['meanmatrix'] = meanmatrix
        tosave['elapsed'] = elapsed
        tosave['d'] = d
        tosave['sigma'] = sigma
        tosave['deltas'] = deltas
        tosave['rhos'] = rhos
        tosave['numvects'] = numvects
        tosave['SNRdb'] = SNRdb
        # Object array so savemat writes the algorithm names as a cell array.
        # Builtin `object` instead of the numpy.object alias (removed in NumPy 1.24).
        obj_arr = numpy.zeros((len(algonames),), dtype=object)
        for idx, name in enumerate(algonames):
            obj_arr[idx] = name
        tosave['algonames'] = obj_arr
        try:
            scipy.io.savemat(savedataname, tosave)
        except Exception as exc:
            print("Save error: %s" % (exc,))

    # Show
    if doshowplot or dosaveplot:
        for name in algonames:
            plt.figure()
            plt.imshow(meanmatrix[name], cmap=cm.gray, norm=mcolors.Normalize(0, 1),
                       interpolation='nearest', origin='lower')
            if dosaveplot:
                for ext in saveplotexts:
                    plt.savefig(saveplotbase + name + '.' + ext, bbox_inches='tight')
        if doshowplot:
            plt.show()

    print("Finished.")
nikcleju@14 192
def run_once_tuple(t):
    """
    Run one job via run_once() and report progress.

    Used as the Pool.map() target in parallel mode and called directly in
    serial mode. It increments the shared finished-job counter and prints a
    progress banner. When a print lock is installed (parallel mode), the
    counter update and the banner happen while holding it.

    t -- tuple (algos, Omega, y, M, x0) as built by run_multi()
    Returns the (mrelerr, elapsed) pair produced by run_once().
    """
    results = run_once(*t)

    lock = _currmodule._printLock
    if lock:
        lock.acquire()
    try:
        _currmodule.proccount.value = _currmodule.proccount.value + 1
        done = _currmodule.proccount.value
        total = _currmodule.njobs
        print("================================")
        print("Finished job %s / %s jobs remaining %s / %s" % (done, total, total - done, total))
        print("================================")
    finally:
        # Release only a lock we actually acquired: in serial mode _printLock is
        # None, and an error above must not leave other workers blocked.
        if lock:
            lock.release()

    return results
nikcleju@14 207
def run_once(algos, Omega, y, M, x0, exactthr=1e-6):
    """
    Run every algorithm on every signal of one data set and score exact recovery.

    algos    -- sequence of (function, name) tuples; each function is called as
                function(y[:, i], M, Omega) and its result is stored as the
                reconstruction of column i
    Omega    -- operator; its column count is used as the signal dimension
    y        -- one signal's measurements per column
    M        -- passed unchanged to each algorithm (presumably the measurement matrix)
    x0       -- original signals, one per column, used to compute the errors
    exactthr -- relative-error threshold below which a recovery counts as exact

    Returns (mrelerr, elapsed), two dicts keyed by algorithm name:
      mrelerr[name] -- fraction of signals with relative error < exactthr
      elapsed[name] -- total seconds spent in successful algorithm calls
    If an algorithm raises l1eqNotImplementedError on a signal, that
    reconstruction stays at zero (relative error 1 for a nonzero x0) and its
    time is not counted.
    """
    d = Omega.shape[1]
    nvects = y.shape[1]
    names = [strname for _, strname in algos]

    # Storage for every algorithm
    xrec = dict((name, numpy.zeros((d, nvects))) for name in names)
    err = dict((name, numpy.zeros(nvects)) for name in names)
    relerr = dict((name, numpy.zeros(nvects)) for name in names)
    elapsed = dict((name, 0) for name in names)

    # None in serial mode; a multiprocessing.Lock in pool workers
    lock = _currmodule._printLock

    # Run algorithms
    for iy in range(nvects):
        for algofunc, strname in algos:
            try:
                timestart = time.time()
                xrec[strname][:, iy] = algofunc(y[:, iy], M, Omega)
                elapsed[strname] = elapsed[strname] + (time.time() - timestart)
            except pyCSalgos.BP.l1eq_pd.l1eqNotImplementedError as e:
                if lock:
                    lock.acquire()
                try:
                    # str(e) instead of e.message (deprecated by PEP 352, gone in Python 3)
                    print("Caught exception when running algorithm %s  : %s" % (strname, e))
                finally:
                    if lock:
                        lock.release()
            err[strname][iy] = numpy.linalg.norm(x0[:, iy] - xrec[strname][:, iy])
            relerr[strname][iy] = err[strname][iy] / numpy.linalg.norm(x0[:, iy])

    if lock:
        lock.acquire()
    try:
        for strname in names:
            print("%s  : avg relative error =  %s" % (strname, numpy.mean(relerr[strname])))
    finally:
        if lock:
            lock.release()

    # Exact-recovery criterion: report the fraction of reconstructions whose
    # relative error is below the threshold, not the average error.
    # NaN errors compare False and therefore count as failures.
    mrelerr = dict()
    for strname in names:
        mrelerr[strname] = float(numpy.count_nonzero(relerr[strname] < exactthr)) / nvects
    return mrelerr, elapsed
nikcleju@14 258
nikcleju@14 259
def generateData(d, sigma, delta, rho, numvects, SNRdb):
    """
    Build an analysis operator and a batch of test data for one grid point.

    d        -- signal dimension
    sigma    -- p = round(sigma*d) is passed to Generate_Analysis_Operator
                (presumably the number of rows of Omega)
    delta    -- m = round(delta*d) (presumably the number of measurements)
    rho      -- l = round(d - rho*m) (presumably the cosparsity)
    numvects -- number of signals requested from AnalysisGenerate
    SNRdb    -- converted to a noise level of 1 / 10**(SNRdb/10)

    Returns (Omega, x0, y, M, realnoise) as produced by AnalysisGenerate.
    """
    p = round(sigma * d)
    m = round(delta * d)
    l = round(d - rho * m)
    noiselevel = 1.0 / (10.0 ** (SNRdb / 10.0))

    Omega = AnalysisGenerate.Generate_Analysis_Operator(d, p)
    # A disabled experiment once rescaled the singular values at this point to
    # change coherence (of D rather than Omega); see the version history.

    x0, y, M, _, realnoise = AnalysisGenerate.Generate_Data_Known_Omega(
        Omega, d, p, m, l, noiselevel, numvects, 'l0')

    return Omega, x0, y, M, realnoise
nikcleju@14 280
nikcleju@14 281
def testMatlab(matfile="E:\\CS\\Ale mele\\Analysis_ExactRec\\temp.mat"):
    """
    Debug helper: run the std1 algorithm set on data saved from MATLAB.

    matfile -- path to a .mat file containing the variables 'Omega', 'y', 'M'
               and 'x0'; the default is the original author's local path
    Returns the (mrelerr, elapsed) result of run_once(); it was previously
    computed and discarded.
    """
    mdict = scipy.io.loadmat(matfile)
    algos = stdparams_exact.std1()[0]
    Omega = mdict['Omega']
    # Flip the byte order of Omega's in-memory representation without changing
    # its values (presumably so the solvers get a native-endian array).
    # Equivalent to ndarray.newbyteorder(), which NumPy 2.0 removed.
    Omega = Omega.byteswap().view(Omega.dtype.newbyteorder())
    return run_once(algos, Omega, mdict['y'], mdict['M'], mdict['x0'])
nikcleju@14 286
def generateFig():
    """Produce the phase-transition figures using the std1 parameter set."""
    run(std=stdparams_exact.std1)
nikcleju@14 289
nikcleju@14 290 # Script main
if __name__ == "__main__":
    # Other entry points used during development:
    #   run(stdparams_exact.stdtest, ncpus=1)  -- quick single-process run
    #   testMatlab()                           -- compare against MATLAB-generated data
    run(stdparams_exact.std1, ncpus=3)