comparison scripts/ABSapprox.py @ 33:116dcfacd1cc

Changed standard parameters; matplotlib: not loading Cairo if on Windows
author nikcleju
date Fri, 11 Nov 2011 16:12:17 +0000
parents e1da5140c9a5
children e8c4672e9de4
comparison
equal deleted inserted replaced
32:e1da5140c9a5 33:116dcfacd1cc
6 """ 6 """
7 7
8 import numpy as np 8 import numpy as np
9 import scipy.io 9 import scipy.io
10 import math 10 import math
11 import os
11 12
12 import pyCSalgos 13 import pyCSalgos
13 import pyCSalgos.GAP.GAP 14 import pyCSalgos.GAP.GAP
14 import pyCSalgos.SL0.SL0_approx 15 import pyCSalgos.SL0.SL0_approx
15 import pyCSalgos.OMP.omp_QR 16 import pyCSalgos.OMP.omp_QR
94 sl0 = (run_sl0, 'SL0a') 95 sl0 = (run_sl0, 'SL0a')
95 bp = (run_bp, 'BP') 96 bp = (run_bp, 'BP')
96 ompeps = (run_ompeps, 'OMPeps') 97 ompeps = (run_ompeps, 'OMPeps')
97 tst = (run_tst, 'TST') 98 tst = (run_tst, 'TST')
98 99
99 # Define which algorithms to run
100 # 1. Algorithms not depending on lambda
101 algosN = gap, # tuple
102 # 2. Algorithms depending on lambda (our ABS approach)
103 #algosL = sl0,bp,ompeps,tst # tuple
104 algosL = sl0,tst
105
106 #========================== 100 #==========================
107 # Pool initializer function (multiprocessing) 101 # Pool initializer function (multiprocessing)
108 # Needed to pass the shared variable to the worker processes 102 # Needed to pass the shared variable to the worker processes
109 # The variables must be global in the module in order to be seen later in run_once_tuple() 103 # The variables must be global in the module in order to be seen later in run_once_tuple()
110 # see http://stackoverflow.com/questions/1675766/how-to-combine-pool-map-with-array-shared-memory-in-python-multiprocessing 104 # see http://stackoverflow.com/questions/1675766/how-to-combine-pool-map-with-array-shared-memory-in-python-multiprocessing
114 currmodule = sys.modules[__name__] 108 currmodule = sys.modules[__name__]
115 currmodule.proccount = share 109 currmodule.proccount = share
116 currmodule.njobs = njobs 110 currmodule.njobs = njobs
117 111
118 #========================== 112 #==========================
119 # Interface functions 113 # Standard parameters
120 #========================== 114 #==========================
121 def run_multiproc(ncpus=None): 115 # Standard parameters 1
122 d,sigma,deltas,rhos,lambdas,numvects,SNRdb,dosavedata,savedataname,doshowplot,dosaveplot,saveplotbase,saveplotexts = standard_params() 116 # All algorithms
123 run_multi(algosN, algosL, d,sigma,deltas,rhos,lambdas,numvects,SNRdb,dosavedata=dosavedata,savedataname=savedataname,\ 117 # d=50, sigma = 2, delta and rho full resolution (0.05 step), lambdas = 0, 1e-4, 1e-2, 1, 100, 10000
124 doparallel=True, ncpus=ncpus,\ 118 # Do save data, do save plots, don't show plots
125 doshowplot=doshowplot,dosaveplot=dosaveplot,saveplotbase=saveplotbase,saveplotexts=saveplotexts) 119 def std1():
126 120 # Define which algorithms to run
127 def run(): 121 algosN = gap, # tuple of algorithms not depending on lambda
128 d,sigma,deltas,rhos,lambdas,numvects,SNRdb,dosavedata,savedataname,doshowplot,dosaveplot,saveplotbase,saveplotexts = standard_params() 122 algosL = sl0,bp,ompeps,tst # tuple of algorithms depending on lambda (our ABS approach)
129 run_multi(algosN, algosL, d,sigma,deltas,rhos,lambdas,numvects,SNRdb,dosavedata=dosavedata,savedataname=savedataname,\ 123
130 doparallel=False,\
131 doshowplot=doshowplot,dosaveplot=dosaveplot,saveplotbase=saveplotbase,saveplotexts=saveplotexts)
132
133 def standard_params():
134 #Set up standard experiment parameters
135 d = 50.0; 124 d = 50.0;
136 sigma = 2.0 125 sigma = 2.0
137 #deltas = np.arange(0.05,1.,0.05) 126 deltas = np.arange(0.05,1.,0.05)
138 #rhos = np.arange(0.05,1.,0.05) 127 rhos = np.arange(0.05,1.,0.05)
139 deltas = np.array([0.05, 0.45, 0.95]) 128 #deltas = np.array([0.05, 0.45, 0.95])
140 rhos = np.array([0.05, 0.45, 0.95]) 129 #rhos = np.array([0.05, 0.45, 0.95])
141 #deltas = np.array([0.05])
142 #rhos = np.array([0.05])
143 #delta = 0.8;
144 #rho = 0.15;
145 numvects = 100; # Number of vectors to generate 130 numvects = 100; # Number of vectors to generate
146 SNRdb = 20.; # This is norm(signal)/norm(noise), so power, not energy 131 SNRdb = 20.; # This is norm(signal)/norm(noise), so power, not energy
147 # Values for lambda 132 # Values for lambda
148 #lambdas = [0 10.^linspace(-5, 4, 10)]; 133 #lambdas = [0 10.^linspace(-5, 4, 10)];
149 #lambdas = np.concatenate((np.array([0]), 10**np.linspace(-5, 4, 10)))
150 lambdas = np.array([0., 0.0001, 0.01, 1, 100, 10000]) 134 lambdas = np.array([0., 0.0001, 0.01, 1, 100, 10000])
151 135
152 dosavedata = True 136 dosavedata = True
153 savedataname = 'approx_pt_std1.mat' 137 savedataname = 'approx_pt_std1.mat'
154
155 doshowplot = False 138 doshowplot = False
156 dosaveplot = True 139 dosaveplot = True
157 saveplotbase = 'approx_pt_std1_' 140 saveplotbase = 'approx_pt_std1_'
158 saveplotexts = ('png','pdf','eps') 141 saveplotexts = ('png','pdf','eps')
159 142
160 143 return algosN,algosL,d,sigma,deltas,rhos,lambdas,numvects,SNRdb,dosavedata,savedataname,\
161 return d,sigma,deltas,rhos,lambdas,numvects,SNRdb,dosavedata,savedataname,\
162 doshowplot,dosaveplot,saveplotbase,saveplotexts 144 doshowplot,dosaveplot,saveplotbase,saveplotexts
163 145
146 # Standard parameters 2
147 # Algorithms: GAP, SL0 and BP
148 # d=50, sigma = 2, delta and rho only 3 x 3, lambdas = 0, 1e-4, 1e-2, 1, 100, 10000
149 # Do save data, do save plots, don't show plots
150 # Useful for short testing
151 def std2():
152 # Define which algorithms to run
153 algosN = gap, # tuple of algorithms not depending on lambda
154 algosL = sl0,bp # tuple of algorithms depending on lambda (our ABS approach)
155
156 d = 50.0
157 sigma = 2.0
158 deltas = np.array([0.05, 0.45, 0.95])
159 rhos = np.array([0.05, 0.45, 0.95])
160 numvects = 100; # Number of vectors to generate
161 SNRdb = 20.; # This is norm(signal)/norm(noise), so power, not energy
162 # Values for lambda
163 #lambdas = [0 10.^linspace(-5, 4, 10)];
164 lambdas = np.array([0., 0.0001, 0.01, 1, 100, 10000])
165
166 dosavedata = True
167 savedataname = 'approx_pt_std2.mat'
168 doshowplot = False
169 dosaveplot = True
170 saveplotbase = 'approx_pt_std2_'
171 saveplotexts = ('png','pdf','eps')
172
173 return algosN,algosL,d,sigma,deltas,rhos,lambdas,numvects,SNRdb,dosavedata,savedataname,\
174 doshowplot,dosaveplot,saveplotbase,saveplotexts
175
176 #==========================
177 # Interface run functions
178 #==========================
179 def run_mp(std=std2,ncpus=None):
180
181 algosN,algosL,d,sigma,deltas,rhos,lambdas,numvects,SNRdb,dosavedata,savedataname,doshowplot,dosaveplot,saveplotbase,saveplotexts = std()
182 run_multi(algosN, algosL, d,sigma,deltas,rhos,lambdas,numvects,SNRdb,dosavedata=dosavedata,savedataname=savedataname,\
183 doparallel=True, ncpus=ncpus,\
184 doshowplot=doshowplot,dosaveplot=dosaveplot,saveplotbase=saveplotbase,saveplotexts=saveplotexts)
185
186 def run(std=std2):
187 algosN,algosL,d,sigma,deltas,rhos,lambdas,numvects,SNRdb,dosavedata,savedataname,doshowplot,dosaveplot,saveplotbase,saveplotexts = std()
188 run_multi(algosN, algosL, d,sigma,deltas,rhos,lambdas,numvects,SNRdb,dosavedata=dosavedata,savedataname=savedataname,\
189 doparallel=False,\
190 doshowplot=doshowplot,dosaveplot=dosaveplot,saveplotbase=saveplotbase,saveplotexts=saveplotexts)
164 #========================== 191 #==========================
165 # Main functions 192 # Main functions
166 #========================== 193 #==========================
167 def run_multi(algosN, algosL, d, sigma, deltas, rhos, lambdas, numvects, SNRdb, 194 def run_multi(algosN, algosL, d, sigma, deltas, rhos, lambdas, numvects, SNRdb,
168 doparallel=False, ncpus=None,\ 195 doparallel=False, ncpus=None,\
181 currmodule.proccount = multiprocessing.Value('I', 0) # 'I' = unsigned int, see docs (multiprocessing, array) 208 currmodule.proccount = multiprocessing.Value('I', 0) # 'I' = unsigned int, see docs (multiprocessing, array)
182 209
183 if dosaveplot or doshowplot: 210 if dosaveplot or doshowplot:
184 try: 211 try:
185 import matplotlib 212 import matplotlib
186 if doshowplot: 213 if doshowplot or os.name == 'nt':
187 print "Importing matplotlib with default (GUI) backend... ", 214 print "Importing matplotlib with default (GUI) backend... ",
188 else: 215 else:
189 print "Importing matplotlib with \"Cairo\" backend... ", 216 print "Importing matplotlib with \"Cairo\" backend... ",
190 matplotlib.use('Cairo') 217 matplotlib.use('Cairo')
191 import matplotlib.pyplot as plt 218 import matplotlib.pyplot as plt