changeset 30:5f46ff51c7ff

Added console output detailing parameters Prepared for complete run 10.11.2011
author nikcleju
date Thu, 10 Nov 2011 22:43:11 +0000
parents bc2a96a03b0a
children 829bf04c92af
files pyCSalgos/RecomTST/RecommendedTST.py scripts/ABSapprox.py
diffstat 2 files changed, 97 insertions(+), 24 deletions(-) [+]
line wrap: on
line diff
--- a/pyCSalgos/RecomTST/RecommendedTST.py	Thu Nov 10 18:49:38 2011 +0000
+++ b/pyCSalgos/RecomTST/RecommendedTST.py	Thu Nov 10 22:43:11 2011 +0000
@@ -3,7 +3,7 @@
 import math
 
 #function beta = RecommendedTST(X,Y, nsweep,tol,xinitial,ro)
-def RecommendedTST(X, Y, nsweep=300, tol=0.00001, xinitial=None, ro=0):
+def RecommendedTST(X, Y, nsweep=300, tol=0.00001, xinitial=None, ro=None):
 
   # function beta=RecommendedTST(X,y, nsweep,tol,xinitial,ro)
   # This function gets the measurement matrix and the measurements and
@@ -111,7 +111,7 @@
   #  end
   if xinitial is None:
     xinitial = np.zeros(p)
-  if ro == 0:
+  if ro == None:
     ro = 0.044417*delta**2 + 0.34142*delta + 0.14844
   
   k1 = math.floor(ro*n)
--- a/scripts/ABSapprox.py	Thu Nov 10 18:49:38 2011 +0000
+++ b/scripts/ABSapprox.py	Thu Nov 10 22:43:11 2011 +0000
@@ -12,6 +12,8 @@
 import pyCSalgos
 import pyCSalgos.GAP.GAP
 import pyCSalgos.SL0.SL0_approx
+import pyCSalgos.OMP.omp_QR
+import pyCSalgos.RecomTST.RecommendedTST
 
 #==========================
 # Algorithm functions
@@ -56,31 +58,62 @@
   L = 10
   return pyCSalgos.SL0.SL0_approx.SL0_approx(aggD,aggy,epsilon,sigmamin,sigma_decrease_factor,mu_0,L)
 
+def run_ompeps(y,M,Omega,D,U,S,Vt,epsilon,lbd):
+  
+  N,n = Omega.shape
+  #D = np.linalg.pinv(Omega)
+  #U,S,Vt = np.linalg.svd(D)
+  aggDupper = np.dot(M,D)
+  aggDlower = Vt[-(N-n):,:]
+  aggD = np.concatenate((aggDupper, lbd * aggDlower))
+  aggy = np.concatenate((y, np.zeros(N-n)))
+  
+  opts = dict()
+  opts['stopCrit'] = 'mse'
+  opts['stopTol'] = epsilon**2 / aggy.size
+  return pyCSalgos.OMP.omp_QR.greed_omp_qr(aggy,aggD,aggD.shape[1],opts)[0]
+  
+def run_tst(y,M,Omega,D,U,S,Vt,epsilon,lbd):
+  
+  N,n = Omega.shape
+  #D = np.linalg.pinv(Omega)
+  #U,S,Vt = np.linalg.svd(D)
+  aggDupper = np.dot(M,D)
+  aggDlower = Vt[-(N-n):,:]
+  aggD = np.concatenate((aggDupper, lbd * aggDlower))
+  aggy = np.concatenate((y, np.zeros(N-n)))
+  
+  return pyCSalgos.RecomTST.RecommendedTST.RecommendedTST(aggD, aggy, nsweep=3000, tol=epsilon / np.linalg.norm(aggy))
+
 #==========================
 # Define tuples (algorithm function, name)
 #==========================
 gap = (run_gap, 'GAP')
-sl0 = (run_sl0, 'SL0_approx')
+sl0 = (run_sl0, 'SL0a')
 bp  = (run_bp, 'BP')
+ompeps = (run_ompeps, 'OMPeps')
+tst = (run_tst, 'TST')
 
 # Define which algorithms to run
 #  1. Algorithms not depending on lambda
 algosN = gap,   # tuple
 #  2. Algorithms depending on lambda (our ABS approach)
-algosL = sl0,   # tuple
+algosL = sl0,bp,ompeps,tst   # tuple
   
 #==========================
 # Interface functions
 #==========================
 def run_multiproc(ncpus=None):
-  d,sigma,deltas,rhos,lambdas,numvects,SNRdb,dosavedata,savedataname = standard_params()
+  d,sigma,deltas,rhos,lambdas,numvects,SNRdb,dosavedata,savedataname,doshowplot,dosaveplot,saveplotbase,saveplotexts = standard_params()
   run_multi(algosN, algosL, d,sigma,deltas,rhos,lambdas,numvects,SNRdb,dosavedata=dosavedata,savedataname=savedataname,\
-            doparallel=True, ncpus=ncpus)
+            doparallel=True, ncpus=ncpus,\
+            doshowplot=doshowplot,dosaveplot=dosaveplot,saveplotbase=saveplotbase,saveplotexts=saveplotexts)
 
 def run():
-  d,sigma,deltas,rhos,lambdas,numvects,SNRdb,dosavedata,savedataname = standard_params()
+  d,sigma,deltas,rhos,lambdas,numvects,SNRdb,dosavedata,savedataname,doshowplot,dosaveplot,saveplotbase,saveplotexts = standard_params()
   run_multi(algosN, algosL, d,sigma,deltas,rhos,lambdas,numvects,SNRdb,dosavedata=dosavedata,savedataname=savedataname,\
-            doparallel=False)
+            doparallel=False,\
+            doshowplot=doshowplot,dosaveplot=dosaveplot,saveplotbase=saveplotbase,saveplotexts=saveplotexts)
   
 def standard_params():
   #Set up standard experiment parameters
@@ -88,10 +121,10 @@
   sigma = 2.0
   #deltas = np.arange(0.05,1.,0.05)
   #rhos = np.arange(0.05,1.,0.05)
-  deltas = np.array([0.05, 0.45, 0.95])
-  rhos = np.array([0.05, 0.45, 0.95])
-  #deltas = np.array([0.05])
-  #rhos = np.array([0.05])
+  #deltas = np.array([0.05, 0.45, 0.95])
+  #rhos = np.array([0.05, 0.45, 0.95])
+  deltas = np.array([0.05])
+  rhos = np.array([0.05])
   #delta = 0.8;
   #rho   = 0.15;
   numvects = 100; # Number of vectors to generate
@@ -102,10 +135,16 @@
   lambdas = np.array([0., 0.0001, 0.01, 1, 100, 10000])
   
   dosavedata = True
-  savedataname = 'ABSapprox.mat'
+  savedataname = 'approx_pt_std1.mat'
+  
+  doshowplot = False
+  dosaveplot = True
+  saveplotbase = 'approx_pt_std1_'
+  saveplotexts = ('png','pdf','eps')
     
   
-  return d,sigma,deltas,rhos,lambdas,numvects,SNRdb,dosavedata,savedataname
+  return d,sigma,deltas,rhos,lambdas,numvects,SNRdb,dosavedata,savedataname,\
+          doshowplot,dosaveplot,saveplotbase,saveplotexts
   
 #==========================
 # Main functions  
@@ -114,18 +153,49 @@
             doparallel=False, ncpus=None,\
             doshowplot=False, dosaveplot=False, saveplotbase=None, saveplotexts=None,\
             dosavedata=False, savedataname=None):
+
+  print "This is analysis recovery ABS approximation script by Nic"
+  print "Running phase transition ( run_multi() )"
   
   if doparallel:
     from multiprocessing import Pool
     
-  # TODO: load different engine for matplotlib that allows saving without showing
-  try: 
-    import matplotlib.pyplot as plt
-  except:
-    dosaveplot = False
-    doshowplot = False
-  if dosaveplot and doshowplot:  
-    import matplotlib.cm as cm    
+  if dosaveplot or doshowplot:
+    try:
+      import matplotlib
+      if doshowplot:
+        print "Importing matplotlib with default (GUI) backend... ",
+      else:
+        print "Importing matplotlib with \"Cairo\" backend... ",
+        matplotlib.use('Cairo')
+      import matplotlib.pyplot as plt
+      import matplotlib.cm as cm
+      print "OK"        
+    except:
+      print "FAIL"
+      print "Importing matplotlib.pyplot failed. No figures at all"
+      print "Try selecting a different backend"
+      doshowplot = False
+      dosaveplot = False
+  
+  # Print summary of parameters
+  print "Parameters:"
+  if doparallel:
+    if ncpus is None:
+      print "  Running in parallel with default threads using \"multiprocessing\" package"
+    else:
+      print "  Running in parallel with",ncpus,"threads using \"multiprocessing\" package"
+  else:
+    print "Running single thread"
+  if doshowplot:
+    print "  Showing figures"
+  else:
+    print "  Not showing figures"
+  if dosaveplot:
+    print "  Saving figures as "+saveplotbase+"* with extensions ",saveplotexts
+  else:
+    print "  Not saving figures"
+  print "  Running algorithms",[algotuple[1] for algotuple in algosN],[algotuple[1] for algotuple in algosL]
   
   nalgosN = len(algosN)  
   nalgosL = len(algosL)
@@ -138,6 +208,7 @@
   
   # Prepare parameters
   jobparams = []
+  print "  (delta, rho) pairs to be run:"
   for idelta,delta in zip(np.arange(deltas.size),deltas):
     for irho,rho in zip(np.arange(rhos.size),rhos):
       
@@ -145,9 +216,11 @@
       Omega,x0,y,M,realnoise = generateData(d,sigma,delta,rho,numvects,SNRdb)
       
       #Save the parameters, and run after
-      print "***** delta = ",delta," rho = ",rho
+      print "    delta = ",delta," rho = ",rho
       jobparams.append((algosN,algosL, Omega,y,lambdas,realnoise,M,x0))
 
+  print "End of parameters"
+  
   # Run
   jobresults = []
   if doparallel:
@@ -211,7 +284,7 @@
         plt.imshow(meanmatrix[algoname][ilbd], cmap=cm.gray, interpolation='nearest',origin='lower')
         if dosaveplot:
           for ext in saveplotexts:
-            plt.savefig(saveplotbase + algoname + lambdas[ilbd] + '.' + ext)
+            plt.savefig(saveplotbase + algoname + ('_lbd%.0e' % lambdas[ilbd]) + '.' + ext)
     if doshowplot:
       plt.show()