annotate test_exact.py @ 21:d395461b92ae tip

Lots and lots of modifications. Approximate recovery script working.
author Nic Cleju <nikcleju@gmail.com>
date Mon, 23 Apr 2012 10:54:57 +0300
parents eccc7a5b9ee3
children
rev   line source
# -*- coding: utf-8 -*-
"""
Exact-reconstruction experiments (main script).

For every (delta, rho) pair of a parameter grid, an analysis operator and a
set of test signals are generated, each configured algorithm reconstructs
the signals, and the fraction of signals recovered exactly is recorded and
saved to a .mat file.

Author: Nicolae Cleju
"""
__author__ = "Nicolae Cleju"
__license__ = "GPL"
__email__ = "nikcleju@gmail.com"
nikcleju@15 10
nikcleju@15 11 import numpy
nikcleju@15 12 import scipy.io
nikcleju@15 13 import math
nikcleju@15 14 import os
nikcleju@15 15 import time
nikcleju@17 16 import multiprocessing
nikcleju@17 17 import sys
nikcleju@15 18
# Import matplotlib if possible. On non-Windows systems the "Cairo" backend is
# selected before pyplot is imported (matplotlib.use() must come first); on
# Windows ('nt') the default GUI backend is kept.
try:
    import matplotlib
    if os.name == 'nt':
        print("Importing matplotlib with default (GUI) backend... ")
    else:
        print("Importing matplotlib with \"Cairo\" backend... ")
        matplotlib.use('Cairo')
    import matplotlib.pyplot as plt
    import matplotlib.cm as cm
    import matplotlib.colors as mcolors
except Exception as e:
    # Plotting is optional: report the failure and carry on without figures.
    # Catch Exception (not everything) so Ctrl-C / SystemExit still propagate.
    print("FAIL")
    print("Importing matplotlib.pyplot failed. No figures at all")
    print("Reason: %s" % e)
    print("Try selecting a different backend")
nikcleju@15 34
# Handle to this very module: initProcess() and the task wrappers keep
# process-shared state as attributes on it.
currmodule = sys.modules[__name__]
# Lock that serialises console output between processes; stays None until
# a parallel run() creates one.
printLock = None
# Counter of finished tasks, shareable between processes
# ('I' = unsigned int typecode, see multiprocessing / array docs).
proccount = multiprocessing.Value('I', 0)
nikcleju@15 39
nikcleju@17 40 # Contains pre-defined simulation parameters
nikcleju@15 41 import stdparams_exact
nikcleju@17 42
nikcleju@17 43 # Analysis operator and data generation functions
nikcleju@15 44 import AnalysisGenerate
nikcleju@15 45
nikcleju@15 46 # For exceptions
nikcleju@15 47 import pyCSalgos.BP.l1eq_pd
nikcleju@15 48 import pyCSalgos.NESTA.NESTA
nikcleju@15 49
nikcleju@19 50 # For plotting with right axes
nikcleju@19 51 import utils
nikcleju@19 52
nikcleju@15 53
nikcleju@17 54
def initProcess(share, ntasks, printLock):
    """
    Pool initializer (multiprocessing): expose shared objects to a worker.

    Stores the arguments as attributes of this module, where the task
    functions look them up:
      share     -> currmodule.proccount  (shared finished-task counter)
      ntasks    -> currmodule.ntasks     (total number of tasks)
      printLock -> currmodule.printLock  (lock for serialised printing)

    They must live at module level to be visible later in run_once_tuple(), see
    http://stackoverflow.com/questions/1675766/how-to-combine-pool-map-with-array-shared-memory-in-python-multiprocessing
    """
    currmodule = sys.modules[__name__]
    currmodule.proccount = share
    currmodule.ntasks = ntasks
    # The readers (run_once_tuple, run_once) use `printLock`. Previously only
    # `_printLock` was set, so a worker that re-imports this module saw the
    # import-time value None and printed without the lock.
    currmodule.printLock = printLock
    # Kept for backward compatibility with any code reading the old name.
    currmodule._printLock = printLock
nikcleju@15 66
nikcleju@15 67
def generateTaskParams(globalparams):
    """
    Build the list of independent tasks of a simulation (for parallel running).

    One task is generated per (delta, rho) grid point, with deltas in the
    outer loop and rhos in the inner loop. processResults() relies on this
    order. For each point a fresh Omega and data set are produced with
    AnalysisGenerate.

    Returns a list of (algos, Omega, y, M, x0) tuples, in the argument
    order expected by run_once().
    """
    SNRdb = globalparams['SNRdb']
    sigma = globalparams['sigma']
    d = globalparams['d']
    deltas = globalparams['deltas']
    rhos = globalparams['rhos']
    numvects = globalparams['numvects']
    algos = globalparams['algos']

    # Convert the SNR given in dB to a linear ratio
    noiselevel = 1.0 / (10.0 ** (SNRdb / 10.0))
    # p depends only on sigma and d, so it is the same for every task
    p = round(sigma * d)

    tasks = []
    for delta in deltas:
        # m depends only on delta
        m = round(delta * d)
        for rho in rhos:
            l = round(d - rho * m)
            Omega = AnalysisGenerate.Generate_Analysis_Operator(d, p)
            x0, y, M, _Lambda, _realnoise = AnalysisGenerate.Generate_Data_Known_Omega(
                Omega, d, p, m, l, noiselevel, numvects, 'l0')
            tasks.append((algos, Omega, y, M, x0))
    return tasks
nikcleju@15 105
def processResults(params, taskresults):
    """
    Gather the raw per-task results into per-algorithm matrices.

    taskresults must be ordered the way generateTaskParams() builds tasks
    (deltas outer, rhos inner). Each entry is the (mrelerr, elapsed) pair
    returned by run_once().

    Returns a dict with:
      'meanmatrix': {algoname: array of shape (rhos.size, deltas.size)};
                    negative or NaN entries are replaced by 0
      'elapsed':    {algoname: elapsed times summed over all tasks}
    """
    deltas = params['deltas']
    rhos = params['rhos']
    algos = params['algos']
    names = [algotuple[1] for algotuple in algos]

    meanmatrix = dict((name, numpy.zeros((rhos.size, deltas.size))) for name in names)
    elapsed = dict((name, 0) for name in names)

    # Grid cells listed in the same order the tasks were generated
    cells = [(irho, idelta)
             for idelta in range(deltas.size)
             for irho in range(rhos.size)]

    for idx, (irho, idelta) in enumerate(cells):
        mrelerr, addelapsed = taskresults[idx]
        for name in names:
            grid = meanmatrix[name]
            grid[irho, idelta] = mrelerr[name]
            stored = grid[irho, idelta]
            # Invalid results (negative / NaN) are plotted as 0
            if stored < 0 or math.isnan(stored):
                grid[irho, idelta] = 0
            elapsed[name] = elapsed[name] + addelapsed[name]

    return {'meanmatrix': meanmatrix, 'elapsed': elapsed}
nikcleju@15 137
def saveSim(params, procresults):
    """
    Save a simulation (parameters + processed results) to a .mat file.

    The output file is params['savedataname']. The algorithm names (second
    element of each tuple in params['algos']) are stored as a MATLAB cell
    array under 'algonames'.

    Saving is best-effort: a failure is reported on stdout, not raised.
    """
    tosave = dict()
    tosave['meanmatrix'] = procresults['meanmatrix']
    tosave['elapsed'] = procresults['elapsed']
    for key in ('d', 'sigma', 'deltas', 'rhos', 'numvects', 'SNRdb',
                'saveplotbase', 'saveplotexts'):
        tosave[key] = params[key]

    # A 1-D object array is written by savemat as a cell array. Use the
    # builtin `object`: the `numpy.object` alias was removed in NumPy 1.24.
    algonames = numpy.zeros((len(params['algos']),), dtype=object)
    for idx, algotuple in enumerate(params['algos']):
        algonames[idx] = algotuple[1]
    tosave['algonames'] = algonames

    try:
        scipy.io.savemat(params['savedataname'], tosave)
    except Exception as e:
        # Best-effort: report the cause instead of silently swallowing it
        print("Save error: %s" % e)
nikcleju@15 167
def loadSim(savedataname):
    """
    Load a simulation saved by saveSim() from a .mat file.

    Returns (params, procresults):
      params      - 'd', 'sigma', 'deltas', 'rhos', 'numvects', 'SNRdb',
                    'saveplotexts' (as returned by loadmat), 'saveplotbase'
                    (a string) and 'algonames' (list of algorithm names)
      procresults - 'meanmatrix' (the unwrapped struct, indexable by
                    algorithm name) and 'elapsed' (as returned by loadmat)
    """
    mdict = scipy.io.loadmat(savedataname)

    params = dict()
    procresults = dict()

    # A saved dict comes back as a (1, 1) struct array; unwrap its element
    procresults['meanmatrix'] = mdict['meanmatrix'][0, 0]
    procresults['elapsed'] = mdict['elapsed']
    for key in ('d', 'sigma', 'deltas', 'rhos', 'numvects', 'SNRdb',
                'saveplotexts'):
        params[key] = mdict[key]
    # A saved string comes back as a 1-element character array
    params['saveplotbase'] = mdict['saveplotbase'][0]

    # The algonames cell array is stored as a row (1, N) or a column (N, 1)
    # depending on the scipy version / `oned_as` used when saving. ravel()
    # handles both; indexing [:, 0] kept only the first name of a row.
    params['algonames'] = [algoname[0] for algoname in mdict['algonames'].ravel()]

    return params, procresults
nikcleju@15 194
def plot(savedataname):
    """
    Plot the results stored in a .mat file written by saveSim().

    For every stored algorithm name, the algorithm's mean-matrix is drawn as
    a grayscale image (color scale fixed to [0, 1], nearest-neighbour
    interpolation, origin at the lower left). It is saved as
    saveplotbase + algoname + '.' + ext for every stored extension.
    saveplotbase is used as given, so a relative base resolves against the
    current working directory.

    Requires the module-level matplotlib imports (plt, cm, mcolors).
    """
    params, procresults = loadSim(savedataname)
    meanmatrix = procresults['meanmatrix']
    saveplotbase = params['saveplotbase']
    saveplotexts = params['saveplotexts']
    algonames = params['algonames']

    for algoname in algonames:
        fig = plt.figure()
        try:
            plt.imshow(meanmatrix[algoname], cmap=cm.gray,
                       norm=mcolors.Normalize(0, 1),
                       interpolation='nearest', origin='lower')
            for ext in saveplotexts:
                plt.savefig(saveplotbase + algoname + '.' + ext, bbox_inches='tight')
        finally:
            # Figures are only saved, never shown: close each one so they do
            # not accumulate in memory (matplotlib keeps them all alive).
            plt.close(fig)
nikcleju@15 212
#==========================
# Main function
#==========================
def run(params):
    """
    Run a complete simulation with the given parameters and save the results.

    params (dict) must contain:
      'algos'                  - sequence of (algofunc, name) pairs
      'd', 'sigma', 'deltas', 'rhos', 'numvects', 'SNRdb'
                               - used by generateTaskParams()
      'savedataname', 'saveplotbase', 'saveplotexts'
                               - used by saveSim()
    Optional 'ncpus':
      missing / None - multiprocessing pool with the default number of
                       workers (serial if only one CPU is detected)
      > 1            - multiprocessing pool with that many workers
      == 1           - serial run in this process
      otherwise      - an error is printed and nothing is run

    Raises KeyError before any work is done if a required key is missing,
    rather than only when saving at the end of a long run.
    """
    print("This is analysis recovery ABS exact script by Nic")
    print("Running simulation")

    # Fail fast on missing keys instead of after the (long) simulation
    required = ('algos', 'd', 'sigma', 'deltas', 'rhos', 'numvects', 'SNRdb',
                'savedataname', 'saveplotbase', 'saveplotexts')
    for key in required:
        if key not in params:
            raise KeyError(key)

    algos = params['algos']
    ncpus = params.get('ncpus')

    if ncpus is None:
        ndefault = multiprocessing.cpu_count()
        doparallel = ndefault > 1
        if doparallel:
            print(' Running in parallel with default %s threads using "multiprocessing" package' % ndefault)
        else:
            print("Running single thread")
    elif ncpus > 1:
        print(' Running in parallel with %s threads using "multiprocessing" package' % ncpus)
        doparallel = True
    elif ncpus == 1:
        print("Running single thread")
        doparallel = False
    else:
        print("Wrong number of threads, exiting")
        return

    # Print summary of parameters
    print("Parameters:")
    print(" Running algorithms %s" % ([algotuple[1] for algotuple in algos],))

    print("Generating task parameters...")
    taskparams = generateTaskParams(params)

    # Shared state read by run_once_tuple() to report progress
    currmodule.ntasks = len(taskparams)
    # Restart the finished-task counter, otherwise a second run() in the same
    # process would continue counting from the previous one
    currmodule.proccount.value = 0

    print("Running...")
    if doparallel:
        currmodule.printLock = multiprocessing.Lock()
        pool = multiprocessing.Pool(ncpus, initializer=initProcess,
                                    initargs=(currmodule.proccount,
                                              currmodule.ntasks,
                                              currmodule.printLock))
        try:
            taskresults = pool.map(run_once_tuple, taskparams)
            pool.close()
        except BaseException:
            # Don't leave worker processes behind on error / Ctrl-C
            pool.terminate()
            raise
        finally:
            pool.join()
    else:
        taskresults = [run_once_tuple(taskparam) for taskparam in taskparams]

    procresults = processResults(params, taskresults)
    saveSim(params, procresults)

    print("Finished.")
nikcleju@15 282
def run_once_tuple(t):
    """
    Run one task and report progress.

    Unpacks t as the arguments of run_once(algos, Omega, y, M, x0). It then
    increments the shared finished-task counter and prints how many tasks
    are finished and remaining.

    Returns whatever run_once() returns.
    """
    results = run_once(*t)

    # "value += 1" on a shared Value is a read-modify-write; do it under the
    # Value's own lock so concurrent workers cannot lose increments.
    counter = currmodule.proccount
    with counter.get_lock():
        counter.value += 1
        done = counter.value
    ntasks = currmodule.ntasks

    # printLock is None in serial runs: lock only if a real lock exists, and
    # always release it, even if printing fails.
    lock = currmodule.printLock
    if lock is not None:
        lock.acquire()
    try:
        print("================================")
        print("Finished task %s / %s tasks remaining %s / %s"
              % (done, ntasks, ntasks - done, ntasks))
        print("================================")
    finally:
        if lock is not None:
            lock.release()

    return results
nikcleju@15 303
def run_once(algos, Omega, y, M, x0, exactthr=1e-6):
    """
    Run a single task: reconstruct every column of y with every algorithm.

    Parameters
    ----------
    algos    : sequence of (algofunc, name) pairs; algofunc(y[:, i], M, Omega)
               must return the reconstruction of column i
    Omega    : 2-D array; its number of columns is the reconstruction length
    y        : 2-D array, one measurement vector per column
    M        : passed unchanged to every algofunc
    x0       : 2-D array of reference signals, one per column of y
    exactthr : relative error below which a reconstruction counts as exact
               (default 1e-6, the value previously hard-coded)

    Returns
    -------
    (mrelerr, elapsed) : dicts keyed by algorithm name.
        mrelerr[name] - fraction of columns with relative error < exactthr
        elapsed[name] - total wall time (s) of the successful calls

    If an algorithm raises l1eqNotImplementedError or ValueError on a column,
    the error is printed and that column's reconstruction stays all zeros.
    """
    d = Omega.shape[1]
    nsignals = y.shape[1]
    names = [strname for _algofunc, strname in algos]

    # Same object as `currmodule.printLock`; None (or absent) means no locking
    lock = getattr(sys.modules[__name__], 'printLock', None)

    def report(msg):
        # Print one message, serialised across processes when a lock exists
        if lock is not None:
            lock.acquire()
        try:
            print(msg)
        finally:
            if lock is not None:
                lock.release()

    xrec = dict((name, numpy.zeros((d, nsignals))) for name in names)
    relerr = dict((name, numpy.zeros(nsignals)) for name in names)
    elapsed = dict((name, 0) for name in names)

    for iy in range(nsignals):
        for algofunc, strname in algos:
            try:
                timestart = time.time()
                xrec[strname][:, iy] = algofunc(y[:, iy], M, Omega)
                elapsed[strname] = elapsed[strname] + (time.time() - timestart)
            except pyCSalgos.BP.l1eq_pd.l1eqNotImplementedError as e:
                # str(e) instead of e.message: .message is Python-2-only
                # (deprecated since 2.6) and raises AttributeError on Python 3
                report("Caught exception when running algorithm %s  : %s" % (strname, e))
            except ValueError as e:
                report("Caught ValueError exception when running algorithm %s  : %s" % (strname, e))
            relerr[strname][iy] = (numpy.linalg.norm(x0[:, iy] - xrec[strname][:, iy])
                                   / numpy.linalg.norm(x0[:, iy]))

    for name in names:
        report("%s  : avg relative error =  %s" % (name, numpy.mean(relerr[name])))

    # Result: fraction of exactly recovered signals per algorithm
    mrelerr = dict(
        (name, float(numpy.count_nonzero(relerr[name] < exactthr)) / nsignals)
        for name in names)
    return mrelerr, elapsed
nikcleju@15 361
nikcleju@15 362
def testMatlab(matfile="E:\\CS\\Ale mele\\Analysis_ExactRec\\temp.mat"):
    """
    For debugging only: run a single task on data saved by Matlab.

    matfile must contain the variables 'Omega', 'y', 'M' and 'x0'. The
    default is the original author's hard-coded path. The algorithms are
    taken from stdparams_exact.std1()[0].

    Returns the (mrelerr, elapsed) pair produced by run_once().
    """
    mdict = scipy.io.loadmat(matfile)
    algos = stdparams_exact.std1()[0]
    Omega = mdict['Omega']
    # Swap both the bytes and the dtype's byte order: values stay the same,
    # only the in-memory representation flips (presumably to turn a
    # non-native-endian array from loadmat into a native one -- confirm).
    # ndarray.newbyteorder() was removed in NumPy 2.0; the dtype view works
    # in all versions.
    Omega = Omega.byteswap().view(Omega.dtype.newbyteorder())
    return run_once(algos, Omega, mdict['y'], mdict['M'], mdict['x0'])
nikcleju@15 371
def generateFig():
    """
    Reproduce the exact-recovery figures of the paper
    "Analysis-based sparse reconstruction with synthesis-based solvers".

    Runs the simulation described by stdparams_exact.params1, which saves the
    data to its 'savedataname'. It then redraws the figures from that file
    with utils.replot_exact. The figures are saved (not shown) under the
    configured 'saveplotbase', once per extension in 'saveplotexts'.
    """
    cfg = stdparams_exact.params1
    run(cfg)
    utils.replot_exact(cfg['savedataname'],
                       algonames=None,  # read the names back from the .mat file
                       doshow=False,
                       dosave=True,
                       saveplotbase=cfg['saveplotbase'],
                       saveplotexts=cfg['saveplotexts'])
nikcleju@15 385
# Script entry point
if __name__ == "__main__":
    # To limit the number of worker processes for parallel running, set e.g.
    #     stdparams_exact.params1['ncpus'] = 1
    # before the call below (leave it unset to use all available CPUs).
    #
    # For a quick test run instead of the full figures, use:
    #     run(stdparams_exact.paramstest)
    #     plot(stdparams_exact.paramstest['savedataname'])
    generateFig()
nikcleju@17 399