changeset 33:116dcfacd1cc

Changed standard parameteres matplotlib: not loading Cairo if on Windows
author nikcleju
date Fri, 11 Nov 2011 16:12:17 +0000
parents e1da5140c9a5
children e8c4672e9de4
files scripts/ABSapprox.py
diffstat 1 files changed, 73 insertions(+), 46 deletions(-) [+]
line wrap: on
line diff
--- a/scripts/ABSapprox.py	Fri Nov 11 15:35:55 2011 +0000
+++ b/scripts/ABSapprox.py	Fri Nov 11 16:12:17 2011 +0000
@@ -8,6 +8,7 @@
 import numpy as np
 import scipy.io
 import math
+import os
 
 import pyCSalgos
 import pyCSalgos.GAP.GAP
@@ -96,13 +97,6 @@
 ompeps = (run_ompeps, 'OMPeps')
 tst = (run_tst, 'TST')
 
-# Define which algorithms to run
-#  1. Algorithms not depending on lambda
-algosN = gap,   # tuple
-#  2. Algorithms depending on lambda (our ABS approach)
-#algosL = sl0,bp,ompeps,tst   # tuple
-algosL = sl0,tst
-  
 #==========================
 # Pool initializer function (multiprocessing)
 # Needed to pass the shared variable to the worker processes
@@ -116,51 +110,84 @@
     currmodule.njobs = njobs
   
 #==========================
-# Interface functions
+# Standard parameters
 #==========================
-def run_multiproc(ncpus=None):
-  d,sigma,deltas,rhos,lambdas,numvects,SNRdb,dosavedata,savedataname,doshowplot,dosaveplot,saveplotbase,saveplotexts = standard_params()
+# Standard parameters 1
+# All algorithms
+# d=50, sigma = 2, delta and rho full resolution (0.05 step), lambdas = 0, 1e-4, 1e-2, 1, 100, 10000
+# Do save data, do save plots, don't show plots
+def std1():
+  # Define which algorithms to run
+  algosN = gap,                 # tuple of algorithms not depending on lambda
+  algosL = sl0,bp,ompeps,tst    # tuple of algorithms depending on lambda (our ABS approach)
+  
+  d = 50.0;
+  sigma = 2.0
+  deltas = np.arange(0.05,1.,0.05)
+  rhos = np.arange(0.05,1.,0.05)
+  #deltas = np.array([0.05, 0.45, 0.95])
+  #rhos = np.array([0.05, 0.45, 0.95])
+  numvects = 100; # Number of vectors to generate
+  SNRdb = 20.;    # This is norm(signal)/norm(noise), so power, not energy
+  # Values for lambda
+  #lambdas = [0 10.^linspace(-5, 4, 10)];
+  lambdas = np.array([0., 0.0001, 0.01, 1, 100, 10000])
+  
+  dosavedata = True
+  savedataname = 'approx_pt_std1.mat'
+  doshowplot = False
+  dosaveplot = True
+  saveplotbase = 'approx_pt_std1_'
+  saveplotexts = ('png','pdf','eps')
+
+  return algosN,algosL,d,sigma,deltas,rhos,lambdas,numvects,SNRdb,dosavedata,savedataname,\
+          doshowplot,dosaveplot,saveplotbase,saveplotexts
+          
+# Standard parameters 2
+# Algorithms: GAP, SL0 and BP
+# d=50, sigma = 2, delta and rho only 3 x 3, lambdas = 0, 1e-4, 1e-2, 1, 100, 10000
+# Do save data, do save plots, don't show plots
+# Useful for short testing 
+def std2():
+  # Define which algorithms to run
+  algosN = gap,      # tuple of algorithms not depending on lambda
+  algosL = sl0,bp    # tuple of algorithms depending on lambda (our ABS approach)
+  
+  d = 50.0
+  sigma = 2.0
+  deltas = np.array([0.05, 0.45, 0.95])
+  rhos = np.array([0.05, 0.45, 0.95])
+  numvects = 100; # Number of vectors to generate
+  SNRdb = 20.;    # This is norm(signal)/norm(noise), so power, not energy
+  # Values for lambda
+  #lambdas = [0 10.^linspace(-5, 4, 10)];
+  lambdas = np.array([0., 0.0001, 0.01, 1, 100, 10000])
+  
+  dosavedata = True
+  savedataname = 'approx_pt_std2.mat'
+  doshowplot = False
+  dosaveplot = True
+  saveplotbase = 'approx_pt_std2_'
+  saveplotexts = ('png','pdf','eps')
+
+  return algosN,algosL,d,sigma,deltas,rhos,lambdas,numvects,SNRdb,dosavedata,savedataname,\
+          doshowplot,dosaveplot,saveplotbase,saveplotexts          
+  
+#==========================
+# Interface run functions
+#==========================
+def run_mp(std=std2,ncpus=None):
+  
+  algosN,algosL,d,sigma,deltas,rhos,lambdas,numvects,SNRdb,dosavedata,savedataname,doshowplot,dosaveplot,saveplotbase,saveplotexts = std()
   run_multi(algosN, algosL, d,sigma,deltas,rhos,lambdas,numvects,SNRdb,dosavedata=dosavedata,savedataname=savedataname,\
             doparallel=True, ncpus=ncpus,\
             doshowplot=doshowplot,dosaveplot=dosaveplot,saveplotbase=saveplotbase,saveplotexts=saveplotexts)
 
-def run():
-  d,sigma,deltas,rhos,lambdas,numvects,SNRdb,dosavedata,savedataname,doshowplot,dosaveplot,saveplotbase,saveplotexts = standard_params()
+def run(std=std2):
+  algosN,algosL,d,sigma,deltas,rhos,lambdas,numvects,SNRdb,dosavedata,savedataname,doshowplot,dosaveplot,saveplotbase,saveplotexts = std()
   run_multi(algosN, algosL, d,sigma,deltas,rhos,lambdas,numvects,SNRdb,dosavedata=dosavedata,savedataname=savedataname,\
             doparallel=False,\
-            doshowplot=doshowplot,dosaveplot=dosaveplot,saveplotbase=saveplotbase,saveplotexts=saveplotexts)
-  
-def standard_params():
-  #Set up standard experiment parameters
-  d = 50.0;
-  sigma = 2.0
-  #deltas = np.arange(0.05,1.,0.05)
-  #rhos = np.arange(0.05,1.,0.05)
-  deltas = np.array([0.05, 0.45, 0.95])
-  rhos = np.array([0.05, 0.45, 0.95])
-  #deltas = np.array([0.05])
-  #rhos = np.array([0.05])
-  #delta = 0.8;
-  #rho   = 0.15;
-  numvects = 100; # Number of vectors to generate
-  SNRdb = 20.;    # This is norm(signal)/norm(noise), so power, not energy
-  # Values for lambda
-  #lambdas = [0 10.^linspace(-5, 4, 10)];
-  #lambdas = np.concatenate((np.array([0]), 10**np.linspace(-5, 4, 10)))
-  lambdas = np.array([0., 0.0001, 0.01, 1, 100, 10000])
-  
-  dosavedata = True
-  savedataname = 'approx_pt_std1.mat'
-  
-  doshowplot = False
-  dosaveplot = True
-  saveplotbase = 'approx_pt_std1_'
-  saveplotexts = ('png','pdf','eps')
-    
-  
-  return d,sigma,deltas,rhos,lambdas,numvects,SNRdb,dosavedata,savedataname,\
-          doshowplot,dosaveplot,saveplotbase,saveplotexts
-  
+            doshowplot=doshowplot,dosaveplot=dosaveplot,saveplotbase=saveplotbase,saveplotexts=saveplotexts)  
 #==========================
 # Main functions  
 #==========================
@@ -183,7 +210,7 @@
   if dosaveplot or doshowplot:
     try:
       import matplotlib
-      if doshowplot:
+      if doshowplot or os.name == 'nt':
         print "Importing matplotlib with default (GUI) backend... ",
       else:
         print "Importing matplotlib with \"Cairo\" backend... ",