annotate pyCSalgos/GAP/GAP.py @ 52:768b57e446ab

commit:   Created Analysis.py and working
author:   nikcleju
date:     Thu, 08 Dec 2011 09:05:04 +0000
parents:  4f3bc35195ce
children: a115c982a0fd
# -*- coding: utf-8 -*-
"""
Created on Thu Oct 13 14:05:22 2011

@author: ncleju
"""

import numpy
import numpy.linalg
import scipy as sp
import scipy.stats   # "import scipy as sp" alone does not load the stats submodule used below

import math


#function [xhat, arepr, lagmult] = ArgminOperL2Constrained(y, M, MH, Omega, OmegaH, Lambdahat, xinit, ilagmult, params)
def ArgminOperL2Constrained(y, M, MH, Omega, OmegaH, Lambdahat, xinit, ilagmult, params):

    #
    # This function aims to compute
    #    xhat = argmin || Omega(Lambdahat, :) * x ||_2   subject to   || y - M*x ||_2 <= epsilon.
    # arepr is the analysis representation corresponding to Lambdahat, i.e.,
    #    arepr = Omega(Lambdahat, :) * xhat.
    # The function also returns the Lagrange multiplier used in the process of computing xhat
    # (see the illustrative sketch after this function).
    #
    # Inputs:
    #    y : observation/measurements of an unknown vector x0. It is equal to M*x0 + noise.
    #    M : measurement matrix
    #    MH : M', the conjugate transpose of M
    #    Omega : analysis operator
    #    OmegaH : Omega', the conjugate transpose of Omega. Also, the synthesis operator.
    #    Lambdahat : an index set indicating some rows of Omega.
    #    xinit : initial estimate that will be used for the conjugate gradient algorithm.
    #    ilagmult : initial Lagrange multiplier to be used in the iterative search below.
    #    params : parameters
    #        params.noise_level : this corresponds to epsilon above.
    #        params.max_inner_iteration : maximum number of iterations in the conjugate gradient method.
    #        params.l2_accuracy : the l2 accuracy parameter used in the conjugate gradient method.
    #        params.l2solver : if the value is 'pseudoinverse', then direct matrix computation (not the conjugate gradient method) is used. Otherwise, the conjugate gradient method is used.
    #

    #d = length(xinit)
    d = xinit.size
    lagmultmax = 1e5
    lagmultmin = 1e-4
    lagmultfactor = 2.0
    accuracy_adjustment_exponent = 4/5.
    lagmult = max(min(ilagmult, lagmultmax), lagmultmin)
    was_infeasible = False
    was_feasible = False

    #######################################################################
    ## Computation done by direct matrix computation (no conjugate gradient method).
    #######################################################################
    #if strcmp(params.l2solver, 'pseudoinverse')
    if params['l2solver'] == 'pseudoinverse':
        #if strcmp(class(M), 'double') && strcmp(class(Omega), 'double')
        if M.dtype == 'float64' and Omega.dtype == 'float64':
            while True:
                alpha = math.sqrt(lagmult)
                xhat = numpy.linalg.lstsq(numpy.concatenate((M, alpha*Omega[Lambdahat,:])), numpy.concatenate((y, numpy.zeros(Lambdahat.size))))[0]
                temp = numpy.linalg.norm(y - numpy.dot(M,xhat), 2)
                #disp(['fidelity error=', num2str(temp), ' lagmult=', num2str(lagmult)]);
                if temp <= params['noise_level']:
                    was_feasible = True
                    if was_infeasible:
                        break
                    else:
                        lagmult = lagmult*lagmultfactor
                elif temp > params['noise_level']:
                    was_infeasible = True
                    if was_feasible:
                        xhat = xprev.copy()
                        break
                    lagmult = lagmult/lagmultfactor
                if lagmult < lagmultmin or lagmult > lagmultmax:
                    break
                xprev = xhat.copy()
            arepr = numpy.dot(Omega[Lambdahat, :], xhat)
            return xhat,arepr,lagmult

    ########################################################################
    ## Computation using the conjugate gradient method.
    ########################################################################
    #if strcmp(class(MH),'function_handle')
    if hasattr(MH, '__call__'):
        b = MH(y)
    else:
        b = numpy.dot(MH, y)

    norm_b = numpy.linalg.norm(b, 2)
    xhat = xinit.copy()
    xprev = xinit.copy()
    residual = TheHermitianMatrix(xhat, M, MH, Omega, OmegaH, Lambdahat, lagmult) - b
    direction = -residual
    iter = 0

    while iter < params['max_inner_iteration']:
        iter = iter + 1
        alpha = numpy.linalg.norm(residual,2)**2 / numpy.dot(direction.T, TheHermitianMatrix(direction, M, MH, Omega, OmegaH, Lambdahat, lagmult))
        xhat = xhat + alpha*direction
        prev_residual = residual.copy()
        residual = TheHermitianMatrix(xhat, M, MH, Omega, OmegaH, Lambdahat, lagmult) - b
        beta = numpy.linalg.norm(residual,2)**2 / numpy.linalg.norm(prev_residual,2)**2
        direction = -residual + beta*direction

        if numpy.linalg.norm(residual,2)/norm_b < params['l2_accuracy']*(lagmult**(accuracy_adjustment_exponent)) or iter == params['max_inner_iteration']:
            #if strcmp(class(M), 'function_handle')
            if hasattr(M, '__call__'):
                temp = numpy.linalg.norm(y-M(xhat), 2)
            else:
                temp = numpy.linalg.norm(y-numpy.dot(M,xhat), 2)

            #if strcmp(class(Omega), 'function_handle')
            if hasattr(Omega, '__call__'):
                u = Omega(xhat)
                u = math.sqrt(lagmult)*numpy.linalg.norm(u[Lambdahat], 2)
            else:
                u = math.sqrt(lagmult)*numpy.linalg.norm(numpy.dot(Omega[Lambdahat,:], xhat), 2)


            #disp(['residual=', num2str(norm(residual,2)), ' norm_b=', num2str(norm_b), ' omegapart=', num2str(u), ' fidelity error=', num2str(temp), ' lagmult=', num2str(lagmult), ' iter=', num2str(iter)]);

            if temp <= params['noise_level']:
                was_feasible = True
                if was_infeasible:
                    break
                else:
                    lagmult = lagmultfactor*lagmult
                    residual = TheHermitianMatrix(xhat, M, MH, Omega, OmegaH, Lambdahat, lagmult) - b
                    direction = -residual
                    iter = 0
            elif temp > params['noise_level']:
                lagmult = lagmult/lagmultfactor
                if was_feasible:
                    xhat = xprev.copy()
                    break
                was_infeasible = True
                residual = TheHermitianMatrix(xhat, M, MH, Omega, OmegaH, Lambdahat, lagmult) - b
                direction = -residual
                iter = 0
            if lagmult > lagmultmax or lagmult < lagmultmin:
                break
            xprev = xhat.copy()
        #elseif norm(xprev-xhat)/norm(xhat) < 1e-2
        #    disp(['rel_change=', num2str(norm(xprev-xhat)/norm(xhat))]);
        #    if strcmp(class(M), 'function_handle')
        #        temp = norm(y-M(xhat), 2);
        #    else
        #        temp = norm(y-M*xhat, 2);
        #    end
        #
        #    if temp > 1.2*params.noise_level
        #        was_infeasible = 1;
        #        lagmult = lagmult/lagmultfactor;
        #        xprev = xhat;
        #    end

    #disp(['fidelity_error=', num2str(temp)]);
    print('fidelity_error=', temp)
    #if iter == params['max_inner_iteration']:
    #    disp('max_inner_iteration reached. l2_accuracy not achieved.');

    ##
    # Compute the analysis representation for xhat
    ##
    #if strcmp(class(Omega),'function_handle')
    if hasattr(Omega, '__call__'):
        temp = Omega(xhat)
        arepr = temp[Lambdahat]
    else: ## here Omega is assumed to be a matrix
        arepr = numpy.dot(Omega[Lambdahat, :], xhat)

    return xhat,arepr,lagmult

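
# ----------------------------------------------------------------------
# Illustrative sketch (added for exposition, not part of the original
# file; the helper name _demo_stacked_lstsq is hypothetical). The
# pseudoinverse branch above solves the penalized problem
#     argmin_x ||y - M*x||_2^2 + lagmult*||Omega(Lambdahat,:)*x||_2^2
# by stacking M on top of sqrt(lagmult)*Omega(Lambdahat,:) and calling
# lstsq; this check confirms that the stacked solve matches the normal
# equations of the penalized problem.
def _demo_stacked_lstsq():
    numpy.random.seed(0)
    m, d, p = 6, 10, 12
    M = numpy.random.randn(m, d)
    Omega = numpy.random.randn(p, d)
    y = numpy.random.randn(m)
    Lambdahat = numpy.arange(8)
    lagmult = 0.5
    OmegaL = Omega[Lambdahat, :]
    # Stacked least-squares system, exactly as in the pseudoinverse branch
    A = numpy.concatenate((M, math.sqrt(lagmult)*OmegaL))
    rhs = numpy.concatenate((y, numpy.zeros(Lambdahat.size)))
    x_stacked = numpy.linalg.lstsq(A, rhs)[0]
    # Normal equations of the penalized problem:
    #     (M'*M + lagmult*OmegaL'*OmegaL) * x = M'*y
    H = numpy.dot(M.T, M) + lagmult*numpy.dot(OmegaL.T, OmegaL)
    x_normal = numpy.linalg.solve(H, numpy.dot(M.T, y))
    assert numpy.allclose(x_stacked, x_normal)
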

##
# This function computes (M'*M + lm*Omega(L,:)'*Omega(L,:)) * x.
##
#function w = TheHermitianMatrix(x, M, MH, Omega, OmegaH, L, lm)
def TheHermitianMatrix(x, M, MH, Omega, OmegaH, L, lm):
    #if strcmp(class(M), 'function_handle')
    if hasattr(M, '__call__'):
        w = MH(M(x))
    else: ## M and MH are matrices
        w = numpy.dot(numpy.dot(MH, M), x)

    if hasattr(Omega, '__call__'):
        v = Omega(x)
        vt = numpy.zeros(v.size)
        vt[L] = v[L].copy()
        w = w + lm*OmegaH(vt)
    else: ## Omega is assumed to be a matrix and OmegaH is its conjugate transpose
        w = w + lm*numpy.dot(numpy.dot(OmegaH[:, L], Omega[L, :]), x)

    return w

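# Illustrative check (added for exposition, not part of the original
# file; _check_hermitian_operator is a hypothetical name). It verifies
# that the matrix path and the function-handle path of TheHermitianMatrix
# agree, and that both compute (M'*M + lm*Omega(L,:)'*Omega(L,:)) * x.
def _check_hermitian_operator():
    numpy.random.seed(1)
    m, d, p = 5, 8, 11
    M = numpy.random.randn(m, d)
    Omega = numpy.random.randn(p, d)
    x = numpy.random.randn(d)
    L = numpy.array([0, 2, 3, 7])
    lm = 0.25
    # Matrix path: M, MH, Omega, OmegaH passed as arrays
    w_mat = TheHermitianMatrix(x, M, M.T, Omega, Omega.T, L, lm)
    # Function-handle path: the same operators passed as callables
    w_fun = TheHermitianMatrix(x,
                               lambda t: numpy.dot(M, t),
                               lambda u: numpy.dot(M.T, u),
                               lambda t: numpy.dot(Omega, t),
                               lambda v: numpy.dot(Omega.T, v),
                               L, lm)
    # Explicit reference computation
    w_ref = numpy.dot(numpy.dot(M.T, M), x) \
            + lm*numpy.dot(numpy.dot(Omega[L, :].T, Omega[L, :]), x)
    assert numpy.allclose(w_mat, w_ref) and numpy.allclose(w_fun, w_ref)
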
def GAP(y, M, MH, Omega, OmegaH, params, xinit):
    #function [xhat, Lambdahat] = GAP(y, M, MH, Omega, OmegaH, params, xinit)

    ##
    # [xhat, Lambdahat] = GAP(y, M, MH, Omega, OmegaH, params, xinit)
    #
    # Greedy Analysis Pursuit Algorithm
    # This aims to find an approximate (sometimes exact) solution of
    #    xhat = argmin || Omega * x ||_0   subject to   || y - M * x ||_2 <= epsilon.
    #
    # Outputs:
    #    xhat : estimate of the target cosparse vector x0.
    #    Lambdahat : estimate of the cosupport of x0.
    #
    # Inputs:
    #    y : observation/measurement vector of a target cosparse solution x0,
    #        given by the relation y = M * x0 + noise.
    #    M : measurement matrix. This should be given either as a matrix or as a function
    #        handle which implements the linear transformation.
    #    MH : conjugate transpose of M.
    #    Omega : analysis operator. Like M, this should be given either as a matrix or as a
    #        function handle which implements the linear transformation.
    #    OmegaH : conjugate transpose of Omega.
    #    params : parameters that govern the behavior of the algorithm (mostly).
    #        params.num_iteration : GAP performs this number of iterations.
    #        params.greedy_level : determines how many rows of Omega GAP eliminates at each iteration.
    #            if the value is < 1, then the rows to be eliminated are determined by
    #                j : |omega_j * xhat| > greedy_level * max_i |omega_i * xhat|.
    #            if the value is >= 1, then greedy_level is the number of rows to be
    #            eliminated at each iteration.
    #        params.stopping_coefficient_size : when the maximum analysis coefficient is smaller
    #            than this, GAP terminates.
    #        params.l2solver : legitimate values are 'pseudoinverse' or 'cg'. determines which
    #            method is used to compute
    #                argmin || Omega_Lambdahat * x ||_2   subject to   || y - M * x ||_2 <= epsilon.
    #        params.l2_accuracy : when l2solver is 'cg', this determines how accurately the above
    #            problem is solved.
    #        params.noise_level : this corresponds to epsilon above.
    #    xinit : initial estimate of x0 that GAP will start with. can be zeros(d, 1).
    #
    # Examples:
    #
    # Not particularly interesting:
    # >> d = 100; p = 110; m = 60;
    # >> M = randn(m, d);
    # >> Omega = randn(p, d);
    # >> y = M * x0 + noise;
    # >> params.num_iteration = 40;
    # >> params.greedy_level = 0.9;
    # >> params.stopping_coefficient_size = 1e-4;
    # >> params.l2solver = 'pseudoinverse';
    # >> [xhat, Lambdahat] = GAP(y, M, M', Omega, Omega', params, zeros(d, 1));
    #
    # Assuming that FourierSampling.m, FourierSamplingH.m, FDAnalysis.m, etc. exist:
    # >> n = 128;
    # >> M = @(t) FourierSampling(t, n);
    # >> MH = @(u) FourierSamplingH(u, n);
    # >> Omega = @(t) FDAnalysis(t, n);
    # >> OmegaH = @(u) FDSynthesis(t, n);
    # >> params.num_iteration = 1000;
    # >> params.greedy_level = 50;
    # >> params.stopping_coefficient_size = 1e-5;
    # >> params.l2solver = 'cg'; # in fact, 'pseudoinverse' does not even make sense.
    # >> [xhat, Lambdahat] = GAP(y, M, MH, Omega, OmegaH, params, zeros(d, 1));
    #
    # Above: FourierSampling and FourierSamplingH are conjugate transposes of each other.
    #        FDAnalysis and FDSynthesis are conjugate transposes of each other.
    #        These routines are problem specific and need to be implemented by the user.
    # (A Python counterpart of the first example appears after this function.)

    #d = length(xinit(:));
    d = xinit.size

    #if strcmp(class(Omega), 'function_handle')
    #    p = length(Omega(zeros(d,1)));
    #else ## Omega is a matrix
    #    p = size(Omega, 1);
    #end
    if hasattr(Omega, '__call__'):
        p = Omega(numpy.zeros((d,1))).size
    else:
        p = Omega.shape[0]

    iter = 0
    lagmult = 1e-4
    #Lambdahat = 1:p;
    Lambdahat = numpy.arange(p)
    #while iter < params.num_iteration
    while iter < params["num_iteration"]:
        iter = iter + 1
        #[xhat, analysis_repr, lagmult] = ArgminOperL2Constrained(y, M, MH, Omega, OmegaH, Lambdahat, xinit, lagmult, params);
        xhat,analysis_repr,lagmult = ArgminOperL2Constrained(y, M, MH, Omega, OmegaH, Lambdahat, xinit, lagmult, params)
        #[to_be_removed, maxcoef] = FindRowsToRemove(analysis_repr, params.greedy_level);
        to_be_removed,maxcoef = FindRowsToRemove(analysis_repr, params["greedy_level"])
        #disp(['** maxcoef=', num2str(maxcoef), ' target=', num2str(params.stopping_coefficient_size), ' rows_remaining=', num2str(length(Lambdahat)), ' lagmult=', num2str(lagmult)]);
        #print('** maxcoef=', maxcoef, ' target=', params['stopping_coefficient_size'], ' rows_remaining=', Lambdahat.size, ' lagmult=', lagmult)
        if check_stopping_criteria(xhat, xinit, maxcoef, lagmult, Lambdahat, params):
            break

        xinit = xhat.copy()
        #Lambdahat[to_be_removed] = []
        Lambdahat = numpy.delete(Lambdahat.squeeze(), to_be_removed)

    #n = sqrt(d);
    #figure(9);
    #RR = zeros(2*n, n-1);
    #RR(Lambdahat) = 1;
    #XD = ones(n, n);
    #XD(:, 2:end) = XD(:, 2:end) .* RR(1:n, :);
    #XD(:, 1:(end-1)) = XD(:, 1:(end-1)) .* RR(1:n, :);
    #XD(2:end, :) = XD(2:end, :) .* RR((n+1):(2*n), :)';
    #XD(1:(end-1), :) = XD(1:(end-1), :) .* RR((n+1):(2*n), :)';
    #XD = FD2DiagnosisPlot(n, Lambdahat);
    #imshow(XD);
    #figure(10);
    #imshow(reshape(real(xhat), n, n));

    #return;
    return xhat,Lambdahat

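# Python counterpart of the first MATLAB example in the doc comment of
# GAP (added for exposition, not part of the original file; the helper
# name _example_gap and the cosparse construction of x0 are one possible
# choice, not the original authors' test setup).
def _example_gap():
    numpy.random.seed(2)
    d, p, m = 100, 110, 60
    M = numpy.random.randn(m, d)
    Omega = numpy.random.randn(p, d)
    # Make x0 cosparse: project a random vector onto the null space of
    # l rows of Omega, so that Omega[cosupport, :] * x0 == 0.
    l = 90
    cosupport = numpy.arange(l)
    U, S, Vt = numpy.linalg.svd(Omega[cosupport, :])
    nullbasis = Vt[l:, :].T                  # d x (d-l) null-space basis
    x0 = numpy.dot(nullbasis, numpy.random.randn(d - l))
    noise = 0.01*numpy.random.randn(m)
    y = numpy.dot(M, x0) + noise
    params = {'num_iteration': 40,
              'greedy_level': 0.9,
              'stopping_coefficient_size': 1e-4,
              'l2solver': 'pseudoinverse',
              'noise_level': numpy.linalg.norm(noise, 2)}
    xhat, Lambdahat = GAP(y, M, M.T, Omega, Omega.T, params, numpy.zeros(d))
    return xhat, Lambdahat
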
def FindRowsToRemove(analysis_repr, greedy_level):
    #function [to_be_removed, maxcoef] = FindRowsToRemove(analysis_repr, greedy_level)

    #abscoef = abs(analysis_repr(:));
    abscoef = numpy.abs(analysis_repr)
    #n = length(abscoef);
    n = abscoef.size
    #maxcoef = max(abscoef);
    maxcoef = abscoef.max()
    if greedy_level >= 1:
        #qq = quantile(abscoef, 1.0-greedy_level/n);
        # float() guards against integer division under Python 2
        qq = sp.stats.mstats.mquantiles(abscoef, 1.0 - float(greedy_level)/n, 0.5, 0.5)
    else:
        qq = maxcoef*greedy_level

    #to_be_removed = find(abscoef >= qq);
    # [0] needed because nonzero() returns a tuple of arrays!
    to_be_removed = numpy.nonzero(abscoef >= qq)[0]
    #return;
    return to_be_removed,maxcoef

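# Brief illustration (added for exposition, not part of the original
# file; _demo_find_rows is a hypothetical name) of the two greedy_level
# regimes handled above: a value < 1 thresholds at a fraction of the
# largest coefficient, while a value >= 1 uses a quantile so that about
# greedy_level rows are selected for removal.
def _demo_find_rows():
    coeffs = numpy.array([0.1, 0.95, 0.2, 1.0, 0.05])
    # greedy_level = 0.9: threshold is 0.9*max, selecting rows 1 and 3
    rows, maxcoef = FindRowsToRemove(coeffs, 0.9)
    assert set(rows) == set([1, 3]) and maxcoef == 1.0
    # greedy_level = 2: quantile-based, selecting the ~2 largest rows
    rows2, _ = FindRowsToRemove(coeffs, 2)
    assert rows2.size == 2
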
def check_stopping_criteria(xhat, xinit, maxcoef, lagmult, Lambdahat, params):
    #function r = check_stopping_criteria(xhat, xinit, maxcoef, lagmult, Lambdahat, params)

    #if isfield(params, 'stopping_coefficient_size') && maxcoef < params.stopping_coefficient_size
    if ('stopping_coefficient_size' in params) and maxcoef < params['stopping_coefficient_size']:
        return 1

    #if isfield(params, 'stopping_lagrange_multiplier_size') && lagmult > params.stopping_lagrange_multiplier_size
    if ('stopping_lagrange_multiplier_size' in params) and lagmult > params['stopping_lagrange_multiplier_size']:
        return 1

    #if isfield(params, 'stopping_relative_solution_change') && norm(xhat-xinit)/norm(xhat) < params.stopping_relative_solution_change
    if ('stopping_relative_solution_change' in params) and numpy.linalg.norm(xhat-xinit)/numpy.linalg.norm(xhat) < params['stopping_relative_solution_change']:
        return 1

    #if isfield(params, 'stopping_cosparsity') && length(Lambdahat) < params.stopping_cosparsity
    if ('stopping_cosparsity' in params) and Lambdahat.size < params['stopping_cosparsity']:
        return 1

    return 0
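
# For reference (added for exposition, not part of the original file):
# a params dictionary exercising every key this module recognizes. The
# num_iteration, greedy_level and stopping_coefficient_size values come
# from the example in the GAP doc comment; the remaining values are
# illustrative placeholders, not recommended settings.
EXAMPLE_PARAMS = {
    'num_iteration': 40,                       # outer GAP iterations
    'greedy_level': 0.9,                       # row-elimination rule (see FindRowsToRemove)
    'l2solver': 'pseudoinverse',               # or 'cg'
    'l2_accuracy': 1e-6,                       # used when l2solver == 'cg'
    'max_inner_iteration': 100,                # conjugate gradient iteration cap
    'noise_level': 0.01,                       # epsilon in ||y - M*x||_2 <= epsilon
    'stopping_coefficient_size': 1e-4,         # optional stopping criteria:
    'stopping_lagrange_multiplier_size': 1e4,  # each key may be omitted,
    'stopping_relative_solution_change': 1e-6, # in which case that test
    'stopping_cosparsity': 10,                 # is skipped
}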