annotate pyCSalgos/NESTA/NESTA.py @ 46:88f0ebe1667a

Finished NESTA implementation. Not tested yet
author nikcleju
date Tue, 29 Nov 2011 22:06:20 +0000
parents 7524d7749456
children e6a5f2173015
rev   line source
nikcleju@45 1 # -*- coding: utf-8 -*-
nikcleju@45 2 """
nikcleju@45 3 Created on Tue Nov 29 16:55:20 2011
nikcleju@45 4
nikcleju@45 5 @author: ncleju
nikcleju@45 6 """
nikcleju@45 7
nikcleju@45 8 import numpy
nikcleju@45 9 import math
nikcleju@45 10
nikcleju@45 11 #function [xk,niter,residuals,outputData,opts] =NESTA(A,At,b,muf,delta,opts)
def NESTA(A,At,b,muf,delta,opts=None):
    """Solve an L1 minimization problem under a quadratic constraint.

    Uses Nesterov's algorithm with continuation to solve

        min_x || U x ||_1   s.t.   || b - A x ||_2 <= delta

    Continuation applies the core solver ``MaxIntIter`` times with a
    geometrically decreasing smoothing parameter, from mu0 (derived from the
    initial guess) down to ``muf``, warm-starting each stage from the
    previous solution.

    This is a Python port of NESTA v1.1 (Jerome Bobin and Stephen Becker,
    Caltech, 2009).

    Parameters
    ----------
    A, At : measurement operator and its adjoint. Either callables, or ``A``
        is a 2-D array (m x n), in which case ``At`` is unused by this
        function. Unless ``opts['AAtinv']`` or ``opts['USV']`` is given,
        ``A`` must satisfy A A' = I (checked numerically on a random vector).
    b : observed data, array of length m.
    muf : value of mu at the last continuation step (smaller means more
        accurate).
    delta : l2 error bound; ``delta = 0`` enforces b = A x. Common heuristic
        from the original code: ``sqrt(m + 2*sqrt(2*m)) * sigma``.
    opts : dict of options, or None for all defaults. Keys read here:
        'Verbose'    - print a line per continuation step (default True).
        'MaxIntIter' - number of continuation steps (default 5, minimum 1).
        'TypeMin'    - 'L1' (default) or 'tv'; case insensitive.
                       'tv' is not implemented in this port.
        'tolvar'     - final stopping tolerance (default 1e-5).
        'U', 'Ut'    - analysis / synthesis operators (arrays or callables;
                       default identity).
        'normU'      - norm(U); computed here if 'U' is given without it.
        'xplug'      - initial guess; default is the least-squares solution
                       A' (A A')^-1 b.
        'AAtinv'     - inverse of A A' (array or callable); allows a
                       non-projection A, but only with delta == 0.
        'USV'        - dict with keys 'U', 'S', 'V' from svd(A); allows a
                       non-projection A with delta > 0. 'S' may be a vector
                       or a matrix.
        The whole dict is forwarded to ``Core_Nesterov``; options such as
        'maxiter', 'stopTest', 'errFcn' and 'outFcn' described in the
        original documentation are presumably consumed there (that routine
        is not verified here).

    Returns
    -------
    (xk, niter, residuals, outputData, opts)
        xk         - estimate of the solution.
        niter      - total number of inner iterations over all stages.
        residuals  - per-stage residual records from ``Core_Nesterov``,
                     stacked row-wise.
        outputData - per-stage ``outFcn`` records, stacked row-wise (empty
                     array when none are produced).
        opts       - copy of the options dict returned by the last stage.

    Raises
    ------
    ValueError
        If A is not a partial isometry and no 'AAtinv'/'USV' is supplied, if
        delta > 0 is combined with 'AAtinv' only, or if 'TypeMin' is unknown.
    TypeError
        If 'USV' is supplied but is not a dict.
    NotImplementedError
        For 'TypeMin' == 'tv'.
    """
    # opts=None used to reach setOpts and crash on opts.keys()
    if opts is None:
        opts = {}

    #---- Set defaults
    opts,Verbose,_ = setOpts(opts,'Verbose',True)
    opts,MaxIntIter,_ = setOpts(opts,'MaxIntIter',5,1)
    opts,TypeMin,_ = setOpts(opts,'TypeMin','L1')
    opts,TolVar,_ = setOpts(opts,'tolvar',1e-5)
    opts,U,U_userSet = setOpts(opts,'U', lambda x: x )
    # A callable U pairs with an identity default for Ut; a matrix U needs no
    # Ut at all because its transpose is used directly.
    if hasattr(U,'__call__'):
        opts,Ut,_ = setOpts(opts,'Ut', lambda x: x )
    else:
        opts,Ut,_ = setOpts(opts,'Ut',None)
    opts,xplug,_ = setOpts(opts,'xplug',None)
    opts,normU,_ = setOpts(opts,'normU',None)  # None means "not supplied"

    residuals = numpy.array([])
    outputData = numpy.array([])
    opts,AAtinv,_ = setOpts(opts,'AAtinv',None)
    opts,USV,_ = setOpts(opts,'USV',None)

    Q = None
    s = None
    if USV is not None:
        if not isinstance(USV, dict):
            raise TypeError('opts.USV must be a structure')
        if len(USV) == 0:
            # an empty structure is treated as "not supplied"
            USV = None
        else:
            # "Q" rather than "U": U already names the analysis operator
            Q = USV['U']
            S = numpy.asarray(USV['S'])
            if S.ndim == 1:
                s = S
            elif min(S.shape) == 1:
                # row or column vector stored as a 2-D array
                s = S.ravel()
            else:
                s = numpy.diag(S)

    # -- Non-projections can be handled only if (A A')^-1 is available
    #    (delta == 0), or the full SVD is available (any delta).
    if (AAtinv is None) and (USV is None):
        # Check that A is a partial isometry, i.e. A A' = I
        z = numpy.random.randn(*numpy.shape(b))
        if hasattr(A,'__call__'):
            AAtz = A(At(z))
        else:
            AAtz = numpy.dot(A, numpy.dot(A.T,z))
        if numpy.linalg.norm(AAtz - z) / numpy.linalg.norm(z) > 1e-8:
            raise ValueError("Measurement matrix A must be a partial isometry: AA'=I")

    # -- Find an initial guess if not already provided, using the
    #    least-squares solution x_ref = A' * inv(A*A') * b
    #    (trivial when A is a projection).
    if xplug is None or numpy.linalg.norm(xplug) < 1e-12:
        if USV is not None and AAtinv is None:
            # float exponent: integer arrays reject negative integer powers
            AAtinv = numpy.dot(Q, numpy.dot(numpy.diag(s ** -2.0), Q.T))
        if AAtinv is not None:
            if delta > 0 and USV is None:
                raise ValueError('delta must be zero for non-projections')
            if hasattr(AAtinv,'__call__'):
                x_ref = AAtinv(b)
            else:
                x_ref = numpy.dot(AAtinv, b)
        else:
            x_ref = b

        if hasattr(A,'__call__'):
            x_ref = At(x_ref)
        else:
            x_ref = numpy.dot(A.T, x_ref)

        # x_ref (not xplug) drives mu0 when xplug has a very small norm
        if xplug is None:
            xplug = x_ref
    else:
        x_ref = xplug

    # use x_ref, not xplug, to find mu_0
    if hasattr(U,'__call__'):
        Ux_ref = U(x_ref)
    else:
        Ux_ref = numpy.dot(U,x_ref)

    type_min = TypeMin.lower()
    if type_min == 'l1':
        mu0 = 0.9*numpy.max(numpy.abs(Ux_ref))
    elif type_min == 'tv':
        raise NotImplementedError('Nic: TODO: not implemented yet')
    else:
        # previously fell through and failed later with mu0 undefined
        raise ValueError("opts.TypeMin must be 'L1' or 'tv', got %r" % (TypeMin,))

    # -- If U was set by the user and normU not supplied, calculate norm(U)
    if U_userSet and normU is None:
        # simple case: U*U' = I or U'*U = I, in which case norm(U) = 1
        z = numpy.random.randn(*numpy.shape(xplug))
        if hasattr(U,'__call__'):
            UtUz = Ut(U(z))
        else:
            UtUz = numpy.dot(U.T, numpy.dot(U,z))

        if numpy.linalg.norm( UtUz - z )/numpy.linalg.norm(z) < 1e-8:
            normU = 1
        else:
            z = numpy.random.randn(*numpy.shape(Ux_ref))
            if hasattr(U,'__call__'):
                UUtz = U(Ut(z))
            else:
                UUtz = numpy.dot(U, numpy.dot(U.T,z))
            if numpy.linalg.norm( UUtz - z )/numpy.linalg.norm(z) < 1e-8:
                normU = 1

        if normU is None:
            # have to actually calculate the norm
            if hasattr(U,'__call__'):
                normU,cnt = my_normest(U,Ut,xplug.size,1e-3,30)
                if cnt == 30:
                    print('Warning: norm(U) may be inaccurate')
            else:
                mU,nU = U.shape
                # work with the smaller of the two Gram matrices
                if mU < nU:
                    UU = numpy.dot(U,U.T)
                else:
                    UU = numpy.dot(U.T,U)
                off_diag = UU - numpy.diag(numpy.diag(UU))
                if numpy.linalg.norm(off_diag,'fro') < 100*numpy.finfo(float).eps:
                    # the Gram matrix is diagonal, so the norm is easy
                    normU = math.sqrt( numpy.max(numpy.abs(numpy.diag(UU))) )
                # Nic: TODO: sparse not implemented
                else:
                    if min(U.shape) > 2000:
                        print('Warning: calculation of norm(U) may be slow')
                    # ord=2 is the spectral norm (MATLAB's norm(UU)); numpy's
                    # default for matrices is Frobenius, which overestimates.
                    normU = math.sqrt( numpy.linalg.norm(UU, 2) )
        opts['normU'] = normU

    niter = 0
    # geometric schedules: mu goes mu0 -> muf and TolVar goes 0.1 -> TolVar
    Gamma = (float(muf)/mu0)**(1.0/MaxIntIter)
    mu = mu0
    Gammat = (TolVar/0.1)**(1.0/MaxIntIter)
    TolVar = 0.1

    for _stage in range(MaxIntIter):

        mu = mu*Gamma
        TolVar = TolVar*Gammat
        opts['TolVar'] = TolVar
        opts['xplug'] = xplug
        if Verbose:
            print('\tBeginning %s Minimization; mu = %g' % (opts['TypeMin'], mu))

        xk,niter_int,res,out,optsOut = Core_Nesterov(A,At,b,mu,delta,opts)

        # warm start for the next continuation stage
        xplug = xk.copy()
        niter = niter_int + niter

        # MATLAB: residuals = [residuals; res]; outputData = [outputData; out]
        residuals = _append_rows(residuals, res)
        outputData = _append_rows(outputData, out)

    opts = optsOut.copy()

    return xk,niter,residuals,outputData,opts


def _append_rows(stacked, block):
    """Append `block` below `stacked`, treating None/empty arrays as neutral.

    Mirrors MATLAB's ``[stacked; block]``. Two 1-D inputs are concatenated
    end to end; otherwise the inputs are stacked row-wise.
    """
    if block is None:
        return stacked
    block = numpy.atleast_1d(numpy.asarray(block))
    if block.size == 0:
        return stacked
    if stacked.size == 0:
        return block.copy()
    if stacked.ndim == 1 and block.ndim == 1:
        return numpy.concatenate((stacked, block))
    return numpy.vstack((stacked, block))
nikcleju@45 527
nikcleju@45 528
nikcleju@45 529
nikcleju@45 530 #---- internal routine for setting defaults
nikcleju@45 531 #function [var,userSet] = setOpts(field,default,mn,mx)
def setOpts(opts,field,default,mn=None,mx=None):
    """Read one option from `opts`, falling back to `default`.

    If `field` is not present under its exact spelling, the keys of `opts`
    are searched case-insensitively; a match is renamed to `field` (the
    old key is removed). The resolved value is written back to
    ``opts[field]``, so `opts` is modified in place.

    Parameters
    ----------
    opts : dict of options, or None (treated as an empty dict).
    field : name of the option.
    default : value used when the option is missing or set to None.
    mn, mx : optional inclusive lower / upper bounds for the value.

    Returns
    -------
    (opts, var, userSet)
        opts    - the (updated) options dict.
        var     - the resolved value.
        userSet - True if the value came from `opts` rather than `default`.

    Raises
    ------
    ValueError
        If the resolved value is below `mn` or above `mx`.
    """
    if opts is None:
        opts = {}

    # Only when the exact name is absent do we look for a capitalization
    # variant. (The test used to be inverted, which deleted correctly
    # spelled user options and never matched differently-cased ones.)
    if field not in opts:
        wanted = field.lower()
        # iterate over a snapshot: the dict is modified inside the loop
        for key in list(opts):
            if hasattr(key, 'lower') and key.lower() == wanted:
                opts[field] = opts.pop(key)
                break

    value = opts.get(field)
    # a value of None counts as "not supplied" (MATLAB's isempty)
    userSet = value is not None
    var = value if userSet else default

    # perform bounds checking, if requested
    if mn is not None and var < mn:
        raise ValueError('Variable %s is %s, should be at least %s' % (field, var, mn))
    if mx is not None and var > mx:
        raise ValueError('Variable %s is %s, should be at most %s' % (field, var, mx))

    opts[field] = var

    return opts,var,userSet
nikcleju@45 584
nikcleju@45 585 # Nic: TODO: implement TV
nikcleju@45 586 #---- internal routine for setting mu0 in the tv minimization case
nikcleju@45 587 #function th=ValMUTv(x)
nikcleju@45 588 # #N = length(x);n = floor(sqrt(N));
nikcleju@45 589 # N = b.size
nikcleju@45 590 # n = floor(sqrt(N))
nikcleju@45 591 # Dv = spdiags([reshape([-ones(n-1,n); zeros(1,n)],N,1) ...
nikcleju@45 592 # reshape([zeros(1,n); ones(n-1,n)],N,1)], [0 1], N, N);
nikcleju@45 593 # Dh = spdiags([reshape([-ones(n,n-1) zeros(n,1)],N,1) ...
nikcleju@45 594 # reshape([zeros(n,1) ones(n,n-1)],N,1)], [0 n], N, N);
nikcleju@45 595 # D = sparse([Dh;Dv]);
nikcleju@45 596 #
nikcleju@45 597 #
nikcleju@45 598 # Dhx = Dh*x;
nikcleju@45 599 # Dvx = Dv*x;
nikcleju@45 600 #
nikcleju@45 601 # sk = sqrt(abs(Dhx).^2 + abs(Dvx).^2);
nikcleju@45 602 # th = max(sk);
nikcleju@45 603 #
nikcleju@45 604 #end
nikcleju@45 605
nikcleju@45 606 #end #-- end of NESTA function
nikcleju@45 607
nikcleju@45 608 ############ POWER METHOD TO ESTIMATE NORM ###############
nikcleju@45 609 # Copied from MATLAB's "normest" function, but allows function handles, not just sparse matrices
nikcleju@45 610 #function [e,cnt] = my_normest(S,St,n,tol, maxiter)
def my_normest(S,St,n,tol=1e-6, maxiter=20):
    """Estimate the 2-norm of a linear operator with the power method.

    Adapted from MATLAB's ``normest``, but works with callables instead of
    sparse matrices.

    Parameters
    ----------
    S : callable applying the operator to a length-`n` vector.
    St : callable applying the adjoint, or None to reuse `S` (i.e. the
        operator is assumed symmetric).
    n : length of the input vectors of `S`.
    tol : relative change of the estimate below which iteration stops.
    maxiter : maximum number of power iterations.

    Returns
    -------
    (e, cnt)
        e   - the norm estimate.
        cnt - number of iterations performed; ``cnt == maxiter`` suggests
              the estimate may not have converged.
    """
    # The fallback must test St (it used to test S, so it never applied).
    if St is None:
        St = S  # we assume the operator is symmetric
    x = numpy.ones(n)
    cnt = 0
    e = numpy.linalg.norm(x)
    if e == 0:
        # n == 0: nothing to iterate on
        return e,cnt
    x = x/e
    e0 = 0
    while abs(e-e0) > tol*e and cnt < maxiter:
        e0 = e
        Sx = S(x)
        if (Sx != 0).sum() == 0:
            # x landed in the null space; restart from a random vector
            Sx = numpy.random.rand(*numpy.shape(Sx))
        e = numpy.linalg.norm(Sx)
        x = St(Sx)
        x = x/numpy.linalg.norm(x)
        cnt = cnt+1
    # The caller unpacks two values; the function previously fell off the
    # end and returned None.
    return e,cnt
nikcleju@46 640
nikcleju@46 641
nikcleju@46 642
nikcleju@46 643 #function [xk,niter,residuals,outputData,opts] = Core_Nesterov(A,At,b,mu,delta,opts)
nikcleju@46 644 def Core_Nesterov(A,At,b,mu,delta,opts):
nikcleju@46 645 # [xk,niter,residuals,outputData,opts] =Core_Nesterov(A,At,b,mu,delta,opts)
nikcleju@46 646 #
nikcleju@46 647 # Solves a L1 minimization problem under a quadratic constraint using the
nikcleju@46 648 # Nesterov algorithm, without continuation:
nikcleju@46 649 #
nikcleju@46 650 # min_x || U x ||_1 s.t. ||y - Ax||_2 <= delta
nikcleju@46 651 #
nikcleju@46 652 # If continuation is desired, see the function NESTA.m
nikcleju@46 653 #
nikcleju@46 654 # The primal prox-function is also adapted by accounting for a first guess
nikcleju@46 655 # xplug that also tends towards x_muf
nikcleju@46 656 #
nikcleju@46 657 # The observation matrix A is a projector
nikcleju@46 658 #
nikcleju@46 659 # Inputs: A and At - measurement matrix and adjoint (either a matrix, in which
nikcleju@46 660 # case At is unused, or function handles). m x n dimensions.
nikcleju@46 661 # b - Observed data, a m x 1 array
nikcleju@46 662 # muf - The desired value of mu at the last continuation step.
nikcleju@46 663 # A smaller mu leads to higher accuracy.
nikcleju@46 664 # delta - l2 error bound. This enforces how close the variable
nikcleju@46 665 # must fit the observations b, i.e. || y - Ax ||_2 <= delta
nikcleju@46 666 # If delta = 0, enforces y = Ax
nikcleju@46 667 # Common heuristic: delta = sqrt(m + 2*sqrt(2*m))*sigma;
nikcleju@46 668 # where sigma=std(noise).
nikcleju@46 669 # opts -
nikcleju@46 670 # This is a structure that contains additional options,
nikcleju@46 671 # some of which are optional.
nikcleju@46 672 # The fieldnames are case insensitive. Below
nikcleju@46 673 # are the possible fieldnames:
nikcleju@46 674 #
nikcleju@46 675 # opts.xplug - the first guess for the primal prox-function, and
nikcleju@46 676 # also the initial point for xk. By default, xplug = At(b)
nikcleju@46 677 # opts.U and opts.Ut - Analysis/Synthesis operators
nikcleju@46 678 # (either matrices or function handles).
nikcleju@46 679 # opts.normU - if opts.U is provided, this should be norm(U)
nikcleju@46 680 # opts.maxiter - max number of iterations in an inner loop.
nikcleju@46 681 # default is 10,000
nikcleju@46 682 # opts.TolVar - tolerance for the stopping criteria
nikcleju@46 683 # opts.stopTest - which stopping criteria to apply
nikcleju@46 684 # opts.stopTest == 1 : stop when the relative
nikcleju@46 685 # change in the objective function is less than
nikcleju@46 686 # TolVar
nikcleju@46 687 # opts.stopTest == 2 : stop with the l_infinity norm
nikcleju@46 688 # of difference in the xk variable is less
nikcleju@46 689 # than TolVar
nikcleju@46 690 # opts.TypeMin - if this is 'L1' (default), then
nikcleju@46 691 # minimizes a smoothed version of the l_1 norm.
nikcleju@46 692 # If this is 'tv', then minimizes a smoothed
nikcleju@46 693 # version of the total-variation norm.
nikcleju@46 694 # The string is case insensitive.
nikcleju@46 695 # opts.Verbose - if this is 0 or false, then very
nikcleju@46 696 # little output is displayed. If this is 1 or true,
nikcleju@46 697 # then output every iteration is displayed.
nikcleju@46 698 # If this is a number p greater than 1, then
nikcleju@46 699 # output is displayed every pth iteration.
nikcleju@46 700 # opts.fid - if this is 1 (default), the display is
nikcleju@46 701 # the usual Matlab screen. If this is the file-id
nikcleju@46 702 # of a file opened with fopen, then the display
nikcleju@46 703 # will be redirected to this file.
nikcleju@46 704 # opts.errFcn - if this is a function handle,
nikcleju@46 705 # then the program will evaluate opts.errFcn(xk)
nikcleju@46 706 # at every iteration and display the result.
nikcleju@46 707 # ex. opts.errFcn = @(x) norm( x - x_true )
nikcleju@46 708 # opts.outFcn - if this is a function handle,
nikcleju@46 709 # then the program will evaluate opts.outFcn(xk)
nikcleju@46 710 # at every iteration and save the results in outputData.
nikcleju@46 711 # If the result is a vector (as opposed to a scalar),
nikcleju@46 712 # it should be a row vector and not a column vector.
nikcleju@46 713 # ex. opts.outFcn = @(x) [norm( x - xtrue, 'inf' ),...
nikcleju@46 714 # norm( x - xtrue) / norm(xtrue)]
nikcleju@46 715 # opts.AAtinv - this is an experimental new option. AAtinv
nikcleju@46 716 # is the inverse of AA^*. This allows the use of a
nikcleju@46 717 # matrix A which is not a projection, but only
nikcleju@46 718 # for the noiseless (i.e. delta = 0) case.
nikcleju@46 719 # If the SVD of A is U*S*V', then AAtinv = U*(S^{-2})*U'.
nikcleju@46 720 # opts.USV - another experimental option. This supersedes
nikcleju@46 721 # the AAtinv option, so it is recommended that you
nikcleju@46 722 # do not define AAtinv. This allows the use of a matrix
nikcleju@46 723 # A which is not a projection, and works for the
nikcleju@46 724 # noisy ( i.e. delta > 0 ) case.
nikcleju@46 725 # opts.USV should contain three fields:
nikcleju@46 726 # opts.USV.U is the U from [U,S,V] = svd(A)
nikcleju@46 727 # likewise, opts.USV.S and opts.USV.V are S and V
nikcleju@46 728 # from svd(A). S may be a matrix or a vector.
nikcleju@46 729 # Outputs:
nikcleju@46 730 # xk - estimate of the solution x
nikcleju@46 731 # niter - number of iterations
nikcleju@46 732 # residuals - first column is the residual at every step,
nikcleju@46 733 # second column is the value of f_mu at every step
nikcleju@46 734 # outputData - a matrix, where each row r is the output
nikcleju@46 735 # from opts.outFcn, if supplied.
nikcleju@46 736 # opts - the structure containing the options that were used
nikcleju@46 737 #
nikcleju@46 738 # Written by: Jerome Bobin, Caltech
nikcleju@46 739 # Email: bobin@acm.caltech.edu
nikcleju@46 740 # Created: February 2009
nikcleju@46 741 # Modified: May 2009, Jerome Bobin and Stephen Becker, Caltech
nikcleju@46 742 # Modified: Nov 2009, Stephen Becker
nikcleju@46 743 #
nikcleju@46 744 # NESTA Version 1.1
nikcleju@46 745 # See also NESTA
nikcleju@46 746
nikcleju@46 747 #---- Set defaults
nikcleju@46 748 # opts = [];
nikcleju@46 749
nikcleju@46 750 #---------------------
nikcleju@46 751 # Original Matlab code:
nikcleju@46 752
nikcleju@46 753 #fid = setOpts('fid',1);
nikcleju@46 754 #function printf(varargin), fprintf(fid,varargin{:}); end
nikcleju@46 755 #maxiter = setOpts('maxiter',10000,0);
nikcleju@46 756 #TolVar = setOpts('TolVar',1e-5);
nikcleju@46 757 #TypeMin = setOpts('TypeMin','L1');
nikcleju@46 758 #Verbose = setOpts('Verbose',true);
nikcleju@46 759 #errFcn = setOpts('errFcn',[]);
nikcleju@46 760 #outFcn = setOpts('outFcn',[]);
nikcleju@46 761 #stopTest = setOpts('stopTest',1,1,2);
nikcleju@46 762 #U = setOpts('U', @(x) x );
nikcleju@46 763 #if ~isa(U,'function_handle')
nikcleju@46 764 # Ut = setOpts('Ut',[]);
nikcleju@46 765 #else
nikcleju@46 766 # Ut = setOpts('Ut', @(x) x );
nikcleju@46 767 #end
nikcleju@46 768 #xplug = setOpts('xplug',[]);
nikcleju@46 769 #normU = setOpts('normU',1);
nikcleju@46 770 #
nikcleju@46 771 #if delta < 0, error('delta must be greater or equal to zero'); end
nikcleju@46 772 #
nikcleju@46 773 #if isa(A,'function_handle')
nikcleju@46 774 # Atfun = At;
nikcleju@46 775 # Afun = A;
nikcleju@46 776 #else
nikcleju@46 777 # Atfun = @(x) A'*x;
nikcleju@46 778 # Afun = @(x) A*x;
nikcleju@46 779 #end
nikcleju@46 780 #Atb = Atfun(b);
nikcleju@46 781 #
nikcleju@46 782 #AAtinv = setOpts('AAtinv',[]);
nikcleju@46 783 #USV = setOpts('USV',[]);
nikcleju@46 784 #if ~isempty(USV)
nikcleju@46 785 # if isstruct(USV)
nikcleju@46 786 # Q = USV.U; # we can't use "U" as the variable name
nikcleju@46 787 # # since "U" already refers to the analysis operator
nikcleju@46 788 # S = USV.S;
nikcleju@46 789 # if isvector(S), s = S; S = diag(s);
nikcleju@46 790 # else s = diag(S); end
nikcleju@46 791 # V = USV.V;
nikcleju@46 792 # else
nikcleju@46 793 # error('opts.USV must be a structure');
nikcleju@46 794 # end
nikcleju@46 795 # if isempty(AAtinv)
nikcleju@46 796 # AAtinv = Q*diag( s.^(-2) )*Q';
nikcleju@46 797 # end
nikcleju@46 798 #end
nikcleju@46 799 ## --- for A not a projection (experimental)
nikcleju@46 800 #if ~isempty(AAtinv)
nikcleju@46 801 # if isa(AAtinv,'function_handle')
nikcleju@46 802 # AAtinv_fun = AAtinv;
nikcleju@46 803 # else
nikcleju@46 804 # AAtinv_fun = @(x) AAtinv * x;
nikcleju@46 805 # end
nikcleju@46 806 #
nikcleju@46 807 # AtAAtb = Atfun( AAtinv_fun(b) );
nikcleju@46 808 #
nikcleju@46 809 #else
nikcleju@46 810 # # We assume it's a projection
nikcleju@46 811 # AtAAtb = Atb;
nikcleju@46 812 # AAtinv_fun = @(x) x;
nikcleju@46 813 #end
nikcleju@46 814 #
nikcleju@46 815 #if isempty(xplug)
nikcleju@46 816 # xplug = AtAAtb;
nikcleju@46 817 #end
nikcleju@46 818 #
nikcleju@46 819 ##---- Initialization
nikcleju@46 820 #N = length(xplug);
nikcleju@46 821 #wk = zeros(N,1);
nikcleju@46 822 #xk = xplug;
nikcleju@46 823 #
nikcleju@46 824 #
nikcleju@46 825 ##---- Init Variables
nikcleju@46 826 #Ak= 0;
nikcleju@46 827 #Lmu = normU/mu;
nikcleju@46 828 #yk = xk;
nikcleju@46 829 #zk = xk;
nikcleju@46 830 #fmean = realmin/10;
nikcleju@46 831 #OK = 0;
nikcleju@46 832 #n = floor(sqrt(N));
nikcleju@46 833 #
nikcleju@46 834 ##---- Computing Atb
nikcleju@46 835 #Atb = Atfun(b);
nikcleju@46 836 #Axk = Afun(xk);# only needed if you want to see the residuals
nikcleju@46 837 ## Axplug = Axk;
nikcleju@46 838 #
nikcleju@46 839 #
nikcleju@46 840 ##---- TV Minimization
nikcleju@46 841 #if strcmpi(TypeMin,'TV')
nikcleju@46 842 # Lmu = 8*Lmu;
nikcleju@46 843 # Dv = spdiags([reshape([-ones(n-1,n); zeros(1,n)],N,1) ...
nikcleju@46 844 # reshape([zeros(1,n); ones(n-1,n)],N,1)], [0 1], N, N);
nikcleju@46 845 # Dh = spdiags([reshape([-ones(n,n-1) zeros(n,1)],N,1) ...
nikcleju@46 846 # reshape([zeros(n,1) ones(n,n-1)],N,1)], [0 n], N, N);
nikcleju@46 847 # D = sparse([Dh;Dv]);
nikcleju@46 848 #end
nikcleju@46 849 #
nikcleju@46 850 #
nikcleju@46 851 #Lmu1 = 1/Lmu;
nikcleju@46 852 ## SLmu = sqrt(Lmu);
nikcleju@46 853 ## SLmu1 = 1/sqrt(Lmu);
nikcleju@46 854 #lambdaY = 0;
nikcleju@46 855 #lambdaZ = 0;
nikcleju@46 856 #
nikcleju@46 857 ##---- setup data storage variables
nikcleju@46 858 #[DISPLAY_ERROR, RECORD_DATA] = deal(false);
nikcleju@46 859 #outputData = deal([]);
nikcleju@46 860 #residuals = zeros(maxiter,2);
nikcleju@46 861 #if ~isempty(errFcn), DISPLAY_ERROR = true; end
nikcleju@46 862 #if ~isempty(outFcn) && nargout >= 4
nikcleju@46 863 # RECORD_DATA = true;
nikcleju@46 864 # outputData = zeros(maxiter, size(outFcn(xplug),2) );
nikcleju@46 865 #end
nikcleju@46 866 #
nikcleju@46 867 #for k = 0:maxiter-1,
nikcleju@46 868 #
nikcleju@46 869 # #---- Dual problem
nikcleju@46 870 #
nikcleju@46 871 # if strcmpi(TypeMin,'L1') [df,fx,val,uk] = Perform_L1_Constraint(xk,mu,U,Ut);end
nikcleju@46 872 #
nikcleju@46 873 # if strcmpi(TypeMin,'TV') [df,fx] = Perform_TV_Constraint(xk,mu,Dv,Dh,D,U,Ut);end
nikcleju@46 874 #
nikcleju@46 875 # #---- Primal Problem
nikcleju@46 876 #
nikcleju@46 877 # #---- Updating yk
nikcleju@46 878 #
nikcleju@46 879 # #
nikcleju@46 880 # # yk = Argmin_x Lmu/2 ||x - xk||_l2^2 + <df,x-xk> s.t. ||b-Ax||_l2 < delta
nikcleju@46 881 # # Let xp be sqrt(Lmu) (x-xk), dfp be df/sqrt(Lmu), bp be sqrt(Lmu)(b- Axk) and deltap be sqrt(Lmu)delta
nikcleju@46 882 # # yk = xk + 1/sqrt(Lmu) Argmin_xp 1/2 || xp ||_2^2 + <dfp,xp> s.t. || bp - Axp ||_2 < deltap
nikcleju@46 883 # #
nikcleju@46 884 #
nikcleju@46 885 #
nikcleju@46 886 # cp = xk - 1/Lmu*df; # this is "q" in eq. (3.7) in the paper
nikcleju@46 887 #
nikcleju@46 888 # Acp = Afun( cp );
nikcleju@46 889 # if ~isempty(AAtinv) && isempty(USV)
nikcleju@46 890 # AtAcp = Atfun( AAtinv_fun( Acp ) );
nikcleju@46 891 # else
nikcleju@46 892 # AtAcp = Atfun( Acp );
nikcleju@46 893 # end
nikcleju@46 894 #
nikcleju@46 895 # residuals(k+1,1) = norm( b-Axk); # the residual
nikcleju@46 896 # residuals(k+1,2) = fx; # the value of the objective
nikcleju@46 897 # #--- if user has supplied a function, apply it to the iterate
nikcleju@46 898 # if RECORD_DATA
nikcleju@46 899 # outputData(k+1,:) = outFcn(xk);
nikcleju@46 900 # end
nikcleju@46 901 #
nikcleju@46 902 # if delta > 0
nikcleju@46 903 # if ~isempty(USV)
nikcleju@46 904 # # there are more efficient methods, but we're assuming
nikcleju@46 905 # # that A is negligible compared to U and Ut.
nikcleju@46 906 # # Here we make the change of variables x <-- x - xk
nikcleju@46 907 # # and df <-- df/L
nikcleju@46 908 # dfp = -Lmu1*df; Adfp = -(Axk - Acp);
nikcleju@46 909 # bp = b - Axk;
nikcleju@46 910 # deltap = delta;
nikcleju@46 911 # # Check if we even need to project:
nikcleju@46 912 # if norm( Adfp - bp ) < deltap
nikcleju@46 913 # lambdaY = 0; projIter = 0;
nikcleju@46 914 # # i.e. projection = dfp;
nikcleju@46 915 # yk = xk + dfp;
nikcleju@46 916 # Ayk = Axk + Adfp;
nikcleju@46 917 # else
nikcleju@46 918 # lambdaY_old = lambdaY;
nikcleju@46 919 # [projection,projIter,lambdaY] = fastProjection(Q,S,V,dfp,bp,...
nikcleju@46 920 # deltap, .999*lambdaY_old );
nikcleju@46 921 # if lambdaY > 0, disp('lambda is positive!'); keyboard; end
nikcleju@46 922 # yk = xk + projection;
nikcleju@46 923 # Ayk = Afun(yk);
nikcleju@46 924 # # DEBUGGING
nikcleju@46 925 ## if projIter == 50
nikcleju@46 926 ## fprintf('\n Maxed out iterations at y\n');
nikcleju@46 927 ## keyboard
nikcleju@46 928 ## end
nikcleju@46 929 # end
nikcleju@46 930 # else
nikcleju@46 931 # lambda = max(0,Lmu*(norm(b-Acp)/delta - 1));gamma = lambda/(lambda + Lmu);
nikcleju@46 932 # yk = lambda/Lmu*(1-gamma)*Atb + cp - gamma*AtAcp;
nikcleju@46 933 # # for calculating the residual, we'll avoid calling A()
nikcleju@46 934 # # by storing A(yk) here (using A'*A = I):
nikcleju@46 935 # Ayk = lambda/Lmu*(1-gamma)*b + Acp - gamma*Acp;
nikcleju@46 936 # end
nikcleju@46 937 # else
nikcleju@46 938 # # if delta is 0, the projection is simplified:
nikcleju@46 939 # yk = AtAAtb + cp - AtAcp;
nikcleju@46 940 # Ayk = b;
nikcleju@46 941 # end
nikcleju@46 942 #
nikcleju@46 943 # # DEBUGGING
nikcleju@46 944 ## if norm( Ayk - b ) > (1.05)*delta
nikcleju@46 945 ## fprintf('\nAyk failed projection test\n');
nikcleju@46 946 ## keyboard;
nikcleju@46 947 ## end
nikcleju@46 948 #
nikcleju@46 949 # #--- Stopping criterion
nikcleju@46 950 # qp = abs(fx - mean(fmean))/mean(fmean);
nikcleju@46 951 #
nikcleju@46 952 # switch stopTest
nikcleju@46 953 # case 1
nikcleju@46 954 # # look at the relative change in function value
nikcleju@46 955 # if qp <= TolVar && OK; break;end
nikcleju@46 956 # if qp <= TolVar && ~OK; OK=1; end
nikcleju@46 957 # case 2
nikcleju@46 958 # # look at the l_inf change from previous iterate
nikcleju@46 959 # if k >= 1 && norm( xk - xold, 'inf' ) <= TolVar
nikcleju@46 960 # break
nikcleju@46 961 # end
nikcleju@46 962 # end
nikcleju@46 963 # fmean = [fx,fmean];
nikcleju@46 964 # if (length(fmean) > 10) fmean = fmean(1:10);end
nikcleju@46 965 #
nikcleju@46 966 #
nikcleju@46 967 #
nikcleju@46 968 # #--- Updating zk
nikcleju@46 969 #
nikcleju@46 970 # apk =0.5*(k+1);
nikcleju@46 971 # Ak = Ak + apk;
nikcleju@46 972 # tauk = 2/(k+3);
nikcleju@46 973 #
nikcleju@46 974 # wk = apk*df + wk;
nikcleju@46 975 #
nikcleju@46 976 # #
nikcleju@46 977 # # zk = Argmin_x Lmu/2 ||b - Ax||_l2^2 + Lmu/2||x - xplug ||_2^2+ <wk,x-xk>
nikcleju@46 978 # # s.t. ||b-Ax||_l2 < delta
nikcleju@46 979 # #
nikcleju@46 980 #
nikcleju@46 981 # cp = xplug - 1/Lmu*wk;
nikcleju@46 982 #
nikcleju@46 983 # Acp = Afun( cp );
nikcleju@46 984 # if ~isempty( AAtinv ) && isempty(USV)
nikcleju@46 985 # AtAcp = Atfun( AAtinv_fun( Acp ) );
nikcleju@46 986 # else
nikcleju@46 987 # AtAcp = Atfun( Acp );
nikcleju@46 988 # end
nikcleju@46 989 #
nikcleju@46 990 # if delta > 0
nikcleju@46 991 # if ~isempty(USV)
nikcleju@46 992 # # Make the substitution wk <-- wk/K
nikcleju@46 993 #
nikcleju@46 994 ## dfp = (xplug - Lmu1*wk); # = cp
nikcleju@46 995 ## Adfp= (Axplug - Acp);
nikcleju@46 996 # dfp = cp; Adfp = Acp;
nikcleju@46 997 # bp = b;
nikcleju@46 998 # deltap = delta;
nikcleju@46 999 ## dfp = SLmu*xplug - SLmu1*wk;
nikcleju@46 1000 ## bp = SLmu*b;
nikcleju@46 1001 ## deltap = SLmu*delta;
nikcleju@46 1002 #
nikcleju@46 1003 # # See if we even need to project:
nikcleju@46 1004 # if norm( Adfp - bp ) < deltap
nikcleju@46 1005 # zk = dfp;
nikcleju@46 1006 # Azk = Adfp;
nikcleju@46 1007 # else
nikcleju@46 1008 # [projection,projIter,lambdaZ] = fastProjection(Q,S,V,dfp,bp,...
nikcleju@46 1009 # deltap, .999*lambdaZ );
nikcleju@46 1010 # if lambdaZ > 0, disp('lambda is positive!'); keyboard; end
nikcleju@46 1011 # zk = projection;
nikcleju@46 1012 # # zk = SLmu1*projection;
nikcleju@46 1013 # Azk = Afun(zk);
nikcleju@46 1014 #
nikcleju@46 1015 # # DEBUGGING:
nikcleju@46 1016 ## if projIter == 50
nikcleju@46 1017 ## fprintf('\n Maxed out iterations at z\n');
nikcleju@46 1018 ## keyboard
nikcleju@46 1019 ## end
nikcleju@46 1020 # end
nikcleju@46 1021 # else
nikcleju@46 1022 # lambda = max(0,Lmu*(norm(b-Acp)/delta - 1));gamma = lambda/(lambda + Lmu);
nikcleju@46 1023 # zk = lambda/Lmu*(1-gamma)*Atb + cp - gamma*AtAcp;
nikcleju@46 1024 # # for calculating the residual, we'll avoid calling A()
nikcleju@46 1025 # # by storing A(zk) here (using A'*A = I):
nikcleju@46 1026 # Azk = lambda/Lmu*(1-gamma)*b + Acp - gamma*Acp;
nikcleju@46 1027 # end
nikcleju@46 1028 # else
nikcleju@46 1029 # # if delta is 0, this is simplified:
nikcleju@46 1030 # zk = AtAAtb + cp - AtAcp;
nikcleju@46 1031 # Azk = b;
nikcleju@46 1032 # end
nikcleju@46 1033 #
nikcleju@46 1034 # # DEBUGGING
nikcleju@46 1035 ## if norm( Ayk - b ) > (1.05)*delta
nikcleju@46 1036 ## fprintf('\nAzk failed projection test\n');
nikcleju@46 1037 ## keyboard;
nikcleju@46 1038 ## end
nikcleju@46 1039 #
nikcleju@46 1040 # #--- Updating xk
nikcleju@46 1041 #
nikcleju@46 1042 # xkp = tauk*zk + (1-tauk)*yk;
nikcleju@46 1043 # xold = xk;
nikcleju@46 1044 # xk=xkp;
nikcleju@46 1045 # Axk = tauk*Azk + (1-tauk)*Ayk;
nikcleju@46 1046 #
nikcleju@46 1047 # if ~mod(k,10), Axk = Afun(xk); end # otherwise slowly lose precision
nikcleju@46 1048 # # DEBUG
nikcleju@46 1049 ## if norm(Axk - Afun(xk) ) > 1e-6, disp('error with Axk'); keyboard; end
nikcleju@46 1050 #
nikcleju@46 1051 # #--- display progress if desired
nikcleju@46 1052 # if ~mod(k+1,Verbose )
nikcleju@46 1053 # printf('Iter: #3d ~ fmu: #.3e ~ Rel. Variation of fmu: #.2e ~ Residual: #.2e',...
nikcleju@46 1054 # k+1,fx,qp,residuals(k+1,1) );
nikcleju@46 1055 # #--- if user has supplied a function to calculate the error,
nikcleju@46 1056 # # apply it to the current iterate and dislay the output:
nikcleju@46 1057 # if DISPLAY_ERROR, printf(' ~ Error: #.2e',errFcn(xk)); end
nikcleju@46 1058 # printf('\n');
nikcleju@46 1059 # end
nikcleju@46 1060 # if abs(fx)>1e20 || abs(residuals(k+1,1)) >1e20 || isnan(fx)
nikcleju@46 1061 # error('Nesta: possible divergence or NaN. Bad estimate of ||A''A||?');
nikcleju@46 1062 # end
nikcleju@46 1063 #
nikcleju@46 1064 #end
nikcleju@46 1065 #
nikcleju@46 1066 #niter = k+1;
nikcleju@46 1067 #
nikcleju@46 1068 ##-- truncate output vectors
nikcleju@46 1069 #residuals = residuals(1:niter,:);
nikcleju@46 1070 #if RECORD_DATA, outputData = outputData(1:niter,:); end
nikcleju@46 1071
nikcleju@46 1072 # End of original Matlab code
nikcleju@46 1073 #---------------------
nikcleju@46 1074
nikcleju@46 1075 #fid = setOpts('fid',1);
nikcleju@46 1076 #function printf(varargin), fprintf(fid,varargin{:}); end
nikcleju@46 1077 opts,maxiter,userSet = setOpts(opts,'maxiter',10000,0);
nikcleju@46 1078 opts,TolVar,userSet = setOpts(opts,'TolVar',1e-5);
nikcleju@46 1079 opts,TypeMin,userSet = setOpts(opts,'TypeMin','L1');
nikcleju@46 1080 opts,Verbose,userSet = setOpts(opts,'Verbose',True);
nikcleju@46 1081 opts,errFcn,userSet = setOpts(opts,'errFcn',None);
nikcleju@46 1082 opts,outFcn,userSet = setOpts(opts,'outFcn',None);
nikcleju@46 1083 opts,stopTest,userSet = setOpts(opts,'stopTest',1,1,2);
nikcleju@46 1084 opts,U,userSet = setOpts(opts,'U',lambda x: x );
nikcleju@46 1085 #if ~isa(U,'function_handle')
nikcleju@46 1086 if hasattr(U,'__call__'):
nikcleju@46 1087 opts,Ut,userSet = setOpts(opts,'Ut',None);
nikcleju@46 1088 else:
nikcleju@46 1089 opts,Ut,userSet = setOpts(opts,'Ut', lambda x: x );
nikcleju@46 1090 #end
nikcleju@46 1091 opts,xplug,userSet = setOpts(opts,'xplug',None);
nikcleju@46 1092 opts,normU,userSet = setOpts(opts,'normU',1);
nikcleju@46 1093
nikcleju@46 1094 if delta < 0:
nikcleju@46 1095 print 'delta must be greater or equal to zero'
nikcleju@46 1096 raise
nikcleju@46 1097
nikcleju@46 1098 if hasattr(A,'__call__'):
nikcleju@46 1099 Atfun = At;
nikcleju@46 1100 Afun = A;
nikcleju@46 1101 else:
nikcleju@46 1102 Atfun = lambda x: numpy.dot(A.T,x)
nikcleju@46 1103 Afun = lambda x: numpy.dot(A,x)
nikcleju@46 1104 #end
nikcleju@46 1105 Atb = Atfun(b);
nikcleju@46 1106
nikcleju@46 1107 opts,AAtinv,userSet = setOpts(opts,'AAtinv',None);
nikcleju@46 1108 opts,USV,userSet = setOpts(opts,'USV',None);
nikcleju@46 1109 if USV is not None:
nikcleju@46 1110 #if isstruct(USV)
nikcleju@46 1111 Q = USV['U']; # we can't use "U" as the variable name
nikcleju@46 1112 # since "U" already refers to the analysis operator
nikcleju@46 1113 S = USV['S'];
nikcleju@46 1114 #if isvector(S), s = S; S = diag(s);
nikcleju@46 1115 #else s = diag(S); end
nikcleju@46 1116 if S.ndim is 1:
nikcleju@46 1117 s = S
nikcleju@46 1118 else:
nikcleju@46 1119 s = numpy.diag(S)
nikcleju@46 1120
nikcleju@46 1121 V = USV['V'];
nikcleju@46 1122 #else
nikcleju@46 1123 # error('opts.USV must be a structure');
nikcleju@46 1124 #end
nikcleju@46 1125 #if isempty(AAtinv)
nikcleju@46 1126 if AAtinv is None:
nikcleju@46 1127 #AAtinv = Q*diag( s.^(-2) )*Q';
nikcleju@46 1128 AAtinv = numpy.dot(Q, numpy.dot(numpy.diag(s ** -2), Q.T))
nikcleju@46 1129 #end
nikcleju@46 1130 #end
nikcleju@46 1131 # --- for A not a projection (experimental)
nikcleju@46 1132 #if ~isempty(AAtinv)
nikcleju@46 1133 if AAtinv is not None:
nikcleju@46 1134 #if isa(AAtinv,'function_handle')
nikcleju@46 1135 if hasattr(AAtinv, '__call__'):
nikcleju@46 1136 AAtinv_fun = AAtinv;
nikcleju@46 1137 else:
nikcleju@46 1138 AAtinv_fun = lambda x: numpy.dot(AAtinv,x)
nikcleju@46 1139 #end
nikcleju@46 1140
nikcleju@46 1141 AtAAtb = Atfun( AAtinv_fun(b) );
nikcleju@46 1142
nikcleju@46 1143 else:
nikcleju@46 1144 # We assume it's a projection
nikcleju@46 1145 AtAAtb = Atb;
nikcleju@46 1146 AAtinv_fun = lambda x: x;
nikcleju@46 1147 #end
nikcleju@46 1148
nikcleju@46 1149 if xplug == None:
nikcleju@46 1150 xplug = AtAAtb.copy();
nikcleju@46 1151 #end
nikcleju@46 1152
nikcleju@46 1153 #---- Initialization
nikcleju@46 1154 #N = length(xplug);
nikcleju@46 1155 N = len(xplug)
nikcleju@46 1156 #wk = zeros(N,1);
nikcleju@46 1157 wk = numpy.zeros(N)
nikcleju@46 1158 xk = xplug.copy()
nikcleju@46 1159
nikcleju@46 1160
nikcleju@46 1161 #---- Init Variables
nikcleju@46 1162 Ak = 0.0;
nikcleju@46 1163 Lmu = normU/mu;
nikcleju@46 1164 yk = xk.copy();
nikcleju@46 1165 zk = xk.copy();
nikcleju@46 1166 fmean = numpy.finfo(float).tiny/10.0;
nikcleju@46 1167 OK = 0;
nikcleju@46 1168 n = math.floor(math.sqrt(N));
nikcleju@46 1169
nikcleju@46 1170 #---- Computing Atb
nikcleju@46 1171 Atb = Atfun(b);
nikcleju@46 1172 Axk = Afun(xk);# only needed if you want to see the residuals
nikcleju@46 1173 # Axplug = Axk;
nikcleju@46 1174
nikcleju@46 1175
nikcleju@46 1176 #---- TV Minimization
nikcleju@46 1177 if TypeMin == 'TV':
nikcleju@46 1178 print 'Nic:TODO: TV minimization not yet implemented!'
nikcleju@46 1179 raise
nikcleju@46 1180 #if strcmpi(TypeMin,'TV')
nikcleju@46 1181 # Lmu = 8*Lmu;
nikcleju@46 1182 # Dv = spdiags([reshape([-ones(n-1,n); zeros(1,n)],N,1) ...
nikcleju@46 1183 # reshape([zeros(1,n); ones(n-1,n)],N,1)], [0 1], N, N);
nikcleju@46 1184 # Dh = spdiags([reshape([-ones(n,n-1) zeros(n,1)],N,1) ...
nikcleju@46 1185 # reshape([zeros(n,1) ones(n,n-1)],N,1)], [0 n], N, N);
nikcleju@46 1186 # D = sparse([Dh;Dv]);
nikcleju@46 1187 #end
nikcleju@46 1188
nikcleju@46 1189
nikcleju@46 1190 Lmu1 = 1.0/Lmu;
nikcleju@46 1191 # SLmu = sqrt(Lmu);
nikcleju@46 1192 # SLmu1 = 1/sqrt(Lmu);
nikcleju@46 1193 lambdaY = 0.;
nikcleju@46 1194 lambdaZ = 0.;
nikcleju@46 1195
nikcleju@46 1196 #---- setup data storage variables
nikcleju@46 1197 #[DISPLAY_ERROR, RECORD_DATA] = deal(false);
nikcleju@46 1198 DISPLAY_ERROR = False
nikcleju@46 1199 RECORD_DATA = False
nikcleju@46 1200 #outputData = deal([]);
nikcleju@46 1201 outputData = None
nikcleju@46 1202 residuals = numpy.zeros((maxiter,2))
nikcleju@46 1203 #if ~isempty(errFcn), DISPLAY_ERROR = true; end
nikcleju@46 1204 if errFcn is not None:
nikcleju@46 1205 DISPLAY_ERROR = True
nikcleju@46 1206 #if ~isempty(outFcn) && nargout >= 4
nikcleju@46 1207 if outFcn is not None: # Output max number of arguments
nikcleju@46 1208 RECORD_DATA = True
nikcleju@46 1209 outputData = numpy.zeros(maxiter, outFcn(xplug).shape[1]);
nikcleju@46 1210 #end
nikcleju@46 1211
nikcleju@46 1212 #for k = 0:maxiter-1,
nikcleju@46 1213 for k in numpy.arange(maxiter):
nikcleju@46 1214
nikcleju@46 1215 #---- Dual problem
nikcleju@46 1216
nikcleju@46 1217 #if strcmpi(TypeMin,'L1') [df,fx,val,uk] = Perform_L1_Constraint(xk,mu,U,Ut);end
nikcleju@46 1218 if TypeMin == 'L1':
nikcleju@46 1219 df,fx,val,uk = Perform_L1_Constraint(xk,mu,U,Ut)
nikcleju@46 1220
nikcleju@46 1221 # Nic: TODO: TV not implemented yet !
nikcleju@46 1222 #if strcmpi(TypeMin,'TV') [df,fx] = Perform_TV_Constraint(xk,mu,Dv,Dh,D,U,Ut);end
nikcleju@46 1223
nikcleju@46 1224 #---- Primal Problem
nikcleju@46 1225
nikcleju@46 1226 #---- Updating yk
nikcleju@46 1227
nikcleju@46 1228 #
nikcleju@46 1229 # yk = Argmin_x Lmu/2 ||x - xk||_l2^2 + <df,x-xk> s.t. ||b-Ax||_l2 < delta
nikcleju@46 1230 # Let xp be sqrt(Lmu) (x-xk), dfp be df/sqrt(Lmu), bp be sqrt(Lmu)(b- Axk) and deltap be sqrt(Lmu)delta
nikcleju@46 1231 # yk = xk + 1/sqrt(Lmu) Argmin_xp 1/2 || xp ||_2^2 + <dfp,xp> s.t. || bp - Axp ||_2 < deltap
nikcleju@46 1232 #
nikcleju@46 1233
nikcleju@46 1234
nikcleju@46 1235 cp = xk - 1./Lmu*df; # this is "q" in eq. (3.7) in the paper
nikcleju@46 1236
nikcleju@46 1237 Acp = Afun( cp );
nikcleju@46 1238 #if ~isempty(AAtinv) && isempty(USV)
nikcleju@46 1239 if AAtinv is not None and USV is None:
nikcleju@46 1240 AtAcp = Atfun( AAtinv_fun( Acp ) );
nikcleju@46 1241 else:
nikcleju@46 1242 AtAcp = Atfun( Acp );
nikcleju@46 1243 #end
nikcleju@46 1244
nikcleju@46 1245 #residuals(k+1,1) = norm( b-Axk); # the residual
nikcleju@46 1246 residuals[k,0] = numpy.linalg.norm(b-Axk)
nikcleju@46 1247 #residuals(k+1,2) = fx; # the value of the objective
nikcleju@46 1248 residuals[k,1] = fx
nikcleju@46 1249 #--- if user has supplied a function, apply it to the iterate
nikcleju@46 1250 if RECORD_DATA:
nikcleju@46 1251 outputData[k+1,:] = outFcn(xk);
nikcleju@46 1252 #end
nikcleju@46 1253
nikcleju@46 1254 if delta > 0:
nikcleju@46 1255 #if ~isempty(USV)
nikcleju@46 1256 if USV is not None:
nikcleju@46 1257 # there are more efficient methods, but we're assuming
nikcleju@46 1258 # that A is negligible compared to U and Ut.
nikcleju@46 1259 # Here we make the change of variables x <-- x - xk
nikcleju@46 1260 # and df <-- df/L
nikcleju@46 1261 dfp = -Lmu1*df;
nikcleju@46 1262 Adfp = -(Axk - Acp);
nikcleju@46 1263 bp = b - Axk;
nikcleju@46 1264 deltap = delta;
nikcleju@46 1265 # Check if we even need to project:
nikcleju@46 1266 if numpy.linalg.norm( Adfp - bp ) < deltap:
nikcleju@46 1267 lambdaY = 0.
nikcleju@46 1268 projIter = 0;
nikcleju@46 1269 # i.e. projection = dfp;
nikcleju@46 1270 yk = xk + dfp;
nikcleju@46 1271 Ayk = Axk + Adfp;
nikcleju@46 1272 else:
nikcleju@46 1273 lambdaY_old = lambdaY.copy();
nikcleju@46 1274 #[projection,projIter,lambdaY] = fastProjection(Q,S,V,dfp,bp,deltap, .999*lambdaY_old );
nikcleju@46 1275 projection,projIter,lambdaY = fastProjection(Q,S,V,dfp,bp,deltap, .999*lambdaY_old )
nikcleju@46 1276 #if lambdaY > 0, disp('lambda is positive!'); keyboard; end
nikcleju@46 1277 if lambdaY > 0:
nikcleju@46 1278 print 'lambda is positive!'
nikcleju@46 1279 raise
nikcleju@46 1280 yk = xk + projection;
nikcleju@46 1281 Ayk = Afun(yk);
nikcleju@46 1282 # DEBUGGING
nikcleju@46 1283 # if projIter == 50
nikcleju@46 1284 # fprintf('\n Maxed out iterations at y\n');
nikcleju@46 1285 # keyboard
nikcleju@46 1286 # end
nikcleju@46 1287 #end
nikcleju@46 1288 else:
nikcleju@46 1289 lambdaa = max(0,Lmu*(numpy.linalg.norm(b-Acp)/delta - 1))
nikcleju@46 1290 gamma = lambdaa/(lambdaa + Lmu);
nikcleju@46 1291 yk = lambdaa/Lmu*(1-gamma)*Atb + cp - gamma*AtAcp;
nikcleju@46 1292 # for calculating the residual, we'll avoid calling A()
nikcleju@46 1293 # by storing A(yk) here (using A'*A = I):
nikcleju@46 1294 Ayk = lambdaa/Lmu*(1-gamma)*b + Acp - gamma*Acp;
nikcleju@46 1295 #end
nikcleju@46 1296 else:
nikcleju@46 1297 # if delta is 0, the projection is simplified:
nikcleju@46 1298 yk = AtAAtb + cp - AtAcp;
nikcleju@46 1299 Ayk = b.copy();
nikcleju@46 1300 #end
nikcleju@46 1301
nikcleju@46 1302 # DEBUGGING
nikcleju@46 1303 # if norm( Ayk - b ) > (1.05)*delta
nikcleju@46 1304 # fprintf('\nAyk failed projection test\n');
nikcleju@46 1305 # keyboard;
nikcleju@46 1306 # end
nikcleju@46 1307
nikcleju@46 1308 #--- Stopping criterion
nikcleju@46 1309 qp = abs(fx - numpy.mean(fmean))/numpy.mean(fmean);
nikcleju@46 1310
nikcleju@46 1311 #switch stopTest
nikcleju@46 1312 # case 1
nikcleju@46 1313 if stopTest == 1:
nikcleju@46 1314 # look at the relative change in function value
nikcleju@46 1315 #if qp <= TolVar && OK; break;end
nikcleju@46 1316 if qp <= TolVar and OK:
nikcleju@46 1317 break
nikcleju@46 1318 #if qp <= TolVar && ~OK; OK=1; end
nikcleju@46 1319 if qp <= TolVar and not OK:
nikcleju@46 1320 OK = 1
nikcleju@46 1321 elif stopTest == 2:
nikcleju@46 1322 # look at the l_inf change from previous iterate
nikcleju@46 1323 if k >= 1 and numpy.linalg.norm( xk - xold, 'inf' ) <= TolVar:
nikcleju@46 1324 break
nikcleju@46 1325 #end
nikcleju@46 1326 #end
nikcleju@46 1327 #fmean = [fx,fmean];
nikcleju@46 1328 fmean = numpy.hstack((fx,fmean));
nikcleju@46 1329 if (len(fmean) > 10):
nikcleju@46 1330 fmean = fmean[:10]
nikcleju@46 1331
nikcleju@46 1332
nikcleju@46 1333
nikcleju@46 1334 #--- Updating zk
nikcleju@46 1335
nikcleju@46 1336 apk = 0.5*(k+1);
nikcleju@46 1337 Ak = Ak + apk;
nikcleju@46 1338 tauk = 2/(k+3);
nikcleju@46 1339
nikcleju@46 1340 wk = apk*df + wk;
nikcleju@46 1341
nikcleju@46 1342 #
nikcleju@46 1343 # zk = Argmin_x Lmu/2 ||b - Ax||_l2^2 + Lmu/2||x - xplug ||_2^2+ <wk,x-xk>
nikcleju@46 1344 # s.t. ||b-Ax||_l2 < delta
nikcleju@46 1345 #
nikcleju@46 1346
nikcleju@46 1347 cp = xplug - 1.0/Lmu*wk;
nikcleju@46 1348
nikcleju@46 1349 Acp = Afun( cp );
nikcleju@46 1350 #if ~isempty( AAtinv ) && isempty(USV)
nikcleju@46 1351 if AAtinv is not None and USV is None:
nikcleju@46 1352 AtAcp = Atfun( AAtinv_fun( Acp ) );
nikcleju@46 1353 else:
nikcleju@46 1354 AtAcp = Atfun( Acp );
nikcleju@46 1355 #end
nikcleju@46 1356
nikcleju@46 1357 if delta > 0:
nikcleju@46 1358 #if ~isempty(USV)
nikcleju@46 1359 if USV is not None:
nikcleju@46 1360 # Make the substitution wk <-- wk/K
nikcleju@46 1361
nikcleju@46 1362 # dfp = (xplug - Lmu1*wk); # = cp
nikcleju@46 1363 # Adfp= (Axplug - Acp);
nikcleju@46 1364 dfp = cp.copy()
nikcleju@46 1365 Adfp = Acp.copy()
nikcleju@46 1366 bp = b.copy();
nikcleju@46 1367 deltap = delta;
nikcleju@46 1368 # dfp = SLmu*xplug - SLmu1*wk;
nikcleju@46 1369 # bp = SLmu*b;
nikcleju@46 1370 # deltap = SLmu*delta;
nikcleju@46 1371
nikcleju@46 1372 # See if we even need to project:
nikcleju@46 1373 if numpy.linalg.norm( Adfp - bp ) < deltap:
nikcleju@46 1374 zk = dfp.copy();
nikcleju@46 1375 Azk = Adfp.copy();
nikcleju@46 1376 else:
nikcleju@46 1377 projection,projIter,lambdaZ = fastProjection(Q,S,V,dfp,bp,deltap, .999*lambdaZ )
nikcleju@46 1378 if lambdaZ > 0:
nikcleju@46 1379 print 'lambda is positive!'
nikcleju@46 1380 raise
nikcleju@46 1381 zk = projection.copy();
nikcleju@46 1382 # zk = SLmu1*projection;
nikcleju@46 1383 Azk = Afun(zk);
nikcleju@46 1384
nikcleju@46 1385 # DEBUGGING:
nikcleju@46 1386 # if projIter == 50
nikcleju@46 1387 # fprintf('\n Maxed out iterations at z\n');
nikcleju@46 1388 # keyboard
nikcleju@46 1389 # end
nikcleju@46 1390 #end
nikcleju@46 1391 else:
nikcleju@46 1392 lambdaa = max(0,Lmu*(numpy.linalg.norm(b-Acp)/delta - 1));
nikcleju@46 1393 gamma = lambdaa/(lambdaa + Lmu);
nikcleju@46 1394 zk = lambdaa/Lmu*(1-gamma)*Atb + cp - gamma*AtAcp;
nikcleju@46 1395 # for calculating the residual, we'll avoid calling A()
nikcleju@46 1396 # by storing A(zk) here (using A'*A = I):
nikcleju@46 1397 Azk = lambdaa/Lmu*(1-gamma)*b + Acp - gamma*Acp;
nikcleju@46 1398 #end
nikcleju@46 1399 else:
nikcleju@46 1400 # if delta is 0, this is simplified:
nikcleju@46 1401 zk = AtAAtb + cp - AtAcp;
nikcleju@46 1402 Azk = b;
nikcleju@46 1403 #end
nikcleju@46 1404
nikcleju@46 1405 # DEBUGGING
nikcleju@46 1406 # if norm( Ayk - b ) > (1.05)*delta
nikcleju@46 1407 # fprintf('\nAzk failed projection test\n');
nikcleju@46 1408 # keyboard;
nikcleju@46 1409 # end
nikcleju@46 1410
nikcleju@46 1411 #--- Updating xk
nikcleju@46 1412
nikcleju@46 1413 xkp = tauk*zk + (1-tauk)*yk;
nikcleju@46 1414 xold = xk.copy();
nikcleju@46 1415 xk = xkp.copy();
nikcleju@46 1416 Axk = tauk*Azk + (1-tauk)*Ayk;
nikcleju@46 1417
nikcleju@46 1418 #if ~mod(k,10), Axk = Afun(xk); end # otherwise slowly lose precision
nikcleju@46 1419 if not numpy.mod(k,10):
nikcleju@46 1420 Axk = Afun(xk)
nikcleju@46 1421 # DEBUG
nikcleju@46 1422 # if norm(Axk - Afun(xk) ) > 1e-6, disp('error with Axk'); keyboard; end
nikcleju@46 1423
nikcleju@46 1424 #--- display progress if desired
nikcleju@46 1425 #if ~mod(k+1,Verbose )
nikcleju@46 1426 if not numpy.mod(k+1,Verbose):
nikcleju@46 1427 #printf('Iter: #3d ~ fmu: #.3e ~ Rel. Variation of fmu: #.2e ~ Residual: #.2e',k+1,fx,qp,residuals(k+1,1) );
nikcleju@46 1428 print 'Iter: ',k+1,' ~ fmu: ',fx,' ~ Rel. Variation of fmu: ',qp,' ~ Residual:',residuals[k+1,0]
nikcleju@46 1429 #--- if user has supplied a function to calculate the error,
nikcleju@46 1430 # apply it to the current iterate and display the output:
nikcleju@46 1431 #if DISPLAY_ERROR, printf(' ~ Error: #.2e',errFcn(xk)); end
nikcleju@46 1432 if DISPLAY_ERROR:
nikcleju@46 1433 print ' ~ Error:',errFcn(xk)
nikcleju@46 1434 #end
nikcleju@46 1435 if abs(fx)>1e20 or abs(residuals[k,0]) >1e20 or numpy.isnan(fx):
nikcleju@46 1436 #error('Nesta: possible divergence or NaN. Bad estimate of ||A''A||?');
nikcleju@46 1437 print 'Nesta: possible divergence or NaN. Bad estimate of ||A''A||?'
nikcleju@46 1438 raise
nikcleju@46 1439 #end
nikcleju@46 1440
nikcleju@46 1441 #end
nikcleju@46 1442
nikcleju@46 1443 niter = k+1;
nikcleju@46 1444
nikcleju@46 1445 #-- truncate output vectors
nikcleju@46 1446 residuals = residuals[:niter,:]
nikcleju@46 1447 #if RECORD_DATA, outputData = outputData(1:niter,:); end
nikcleju@46 1448 if RECORD_DATA:
nikcleju@46 1449 outputData = outputData[:niter,:]
nikcleju@46 1450
nikcleju@46 1451 return xk,niter,residuals,outputData,opts
nikcleju@46 1452
nikcleju@46 1453
nikcleju@46 1454 ############ PERFORM THE L1 CONSTRAINT ##################
nikcleju@46 1455
nikcleju@46 1456 #function [df,fx,val,uk] = Perform_L1_Constraint(xk,mu,U,Ut)
def Perform_L1_Constraint(xk,mu,U,Ut):
    """Evaluate the smoothed l1 objective of ``U xk`` and its gradient.

    Port of the MATLAB helper
    ``[df,fx,val,uk] = Perform_L1_Constraint(xk,mu,U,Ut)``.

    Parameters
    ----------
    xk : numpy array
        Current iterate.
    mu : scalar
        Smoothing parameter; entries of ``U xk`` whose magnitude is below
        ``mu`` are divided by ``mu`` instead of by their own magnitude.
    U : callable or matrix
        Analysis operator, applied as ``U(xk)`` or ``numpy.dot(U, xk)``.
    Ut : callable or None
        Adjoint of ``U``.  It is used when callable; otherwise the
        conjugate transpose of the matrix ``U`` is used (MATLAB ``U'``).

    Returns
    -------
    df : numpy array
        ``Ut(uk)`` (or ``U^H uk``), the gradient of the smoothed objective.
    fx : float
        ``real(uk^H (U xk)) - mu/2 * ||uk||_2^2``.
    val : float
        ``real(uk^H (U xk))``.
    uk : numpy array
        ``U xk`` divided element-wise by ``max(mu, |U xk|)``.

    Raises
    ------
    TypeError
        If ``U`` is callable but ``Ut`` is not, since no adjoint can be
        derived from a function.
    """
    U_is_operator = callable(U)

    #if isa(U,'function_handle')
    if U_is_operator:
        Uxk = numpy.asarray(U(xk))
    else:
        Uxk = numpy.dot(U, xk)

    #uk = uk./max(mu,abs(uk));
    # The maximum must be element-wise: the builtin max() would try to
    # evaluate the truth value of an array comparison and raise.
    uk = Uxk / numpy.maximum(mu, numpy.abs(Uxk))

    #val = real(uk'*fx);
    # numpy.vdot conjugates its first argument, matching MATLAB's uk'.
    val = numpy.real(numpy.vdot(uk, Uxk))
    #fx = real(uk'*fx - mu/2*norm(uk)^2);
    fx = val - mu / 2. * numpy.linalg.norm(uk) ** 2

    #if isa(Ut,'function_handle')
    if callable(Ut):
        df = Ut(uk)
    elif not U_is_operator:
        #df = U'*uk;   (MATLAB ' is the conjugate transpose)
        df = numpy.dot(numpy.conj(U).T, uk)
    else:
        raise TypeError('Perform_L1_Constraint: U is a function, so Ut '
                        'must be a function implementing its adjoint')

    return df,fx,val,uk
nikcleju@46 1482 #end
nikcleju@46 1483
nikcleju@46 1484 # Nic: TODO: TV not implemented yet!
nikcleju@46 1485 ############ PERFORM THE TV CONSTRAINT ##################
nikcleju@46 1486 #function [df,fx] = Perform_TV_Constraint(xk,mu,Dv,Dh,D,U,Ut)
nikcleju@46 1487 # if isa(U,'function_handle')
nikcleju@46 1488 # x = U(xk);
nikcleju@46 1489 # else
nikcleju@46 1490 # x = U*xk;
nikcleju@46 1491 # end
nikcleju@46 1492 # df = zeros(size(x));
nikcleju@46 1493 #
nikcleju@46 1494 # Dhx = Dh*x;
nikcleju@46 1495 # Dvx = Dv*x;
nikcleju@46 1496 #
nikcleju@46 1497 # tvx = sum(sqrt(abs(Dhx).^2+abs(Dvx).^2));
nikcleju@46 1498 # w = max(mu,sqrt(abs(Dhx).^2 + abs(Dvx).^2));
nikcleju@46 1499 # uh = Dhx ./ w;
nikcleju@46 1500 # uv = Dvx ./ w;
nikcleju@46 1501 # u = [uh;uv];
nikcleju@46 1502 # fx = real(u'*D*x - mu/2 * 1/numel(u)*sum(u'*u));
nikcleju@46 1503 # if isa(Ut,'function_handle')
nikcleju@46 1504 # df = Ut(D'*u);
nikcleju@46 1505 # else
nikcleju@46 1506 # df = U'*(D'*u);
nikcleju@46 1507 # end
nikcleju@46 1508 #end
nikcleju@46 1509
nikcleju@46 1510
nikcleju@46 1511 #function [x,k,l] = fastProjection( U, S, V, y, b, epsilon, lambda0, DISP )
def fastProjection( U, S, V, y, b, epsilon, lambda0=0, DISP=False ):
    """Project y onto the set { x : || A x - b ||_2 <= epsilon }, A = U S V'.

    Port of fastProjection.m (NESTA Version 1.1, written by Stephen Becker,
    September 2009):

        [x,niter,lambda] = fastProjection(U, S, V, y, b, epsilon, [lambda0], [DISP])

    minimizes || x - y ||  such that  || A x - b || <= epsilon,
    where U S V' = A is the SVD of A. The Lagrange parameter is found with
    Newton's method on
        f(lambda) = sum( |bt|^2 / (1 - lambda*s^2)^2 ) - epsilon^2,
    where s are the singular values and bt = U'b - S V'y.

    Warning: for speed, A*y is not computed up front to check whether x = y
    is already feasible.

    Parameters
    ----------
    U, V : numpy arrays
        Singular vectors of A. Extra columns beyond the number of singular
        values are dropped.
    S : numpy array
        Singular values, either as a (possibly rectangular) diagonal matrix
        or as a 1-D vector.
    y : numpy array
        Point to project (1-D, or a column vector).
    b : numpy array
        Observations (1-D, or a column vector).
    epsilon : float
        Radius of the l2 constraint.
    lambda0 : number, optional
        Initial guess for the Lagrange parameter (default 0).
    DISP : bool, optional
        Print the Newton iterations.

    Returns
    -------
    x : numpy array
        The projection, with the same shape as y (a copy of y when no
        correction is applied).
    k : int
        Zero-based index of the last Newton iteration executed.
    l : number
        Final Lagrange parameter (0 when x = y is returned as is).

    Raises
    ------
    ValueError
        If the part of b outside range(U) is already larger than epsilon,
        so that no x can satisfy the constraint.
    """
    # -- Parameters for Newton's method --
    MAXIT = 70
    TOL = 1e-8 * epsilon

    # Singular values: accept a diagonal matrix (square or rectangular)
    # or directly a vector.
    S = numpy.asarray(S)
    if S.ndim == 1:
        s = S
    else:
        s = numpy.diag(S)
    r = s.shape[0]
    # MATLAB: if size(U,2) > r, U = U(:,1:r); end  (same for V)
    if U.shape[1] > r:
        U = U[:, :r]
    if V.shape[1] > r:
        V = V[:, :r]
    s2 = s**2

    # MATLAB ' is the conjugate transpose (matters for complex data)
    Uh = numpy.conj(U).T
    Vh = numpy.conj(V).T

    # If A doesn't have full row rank, then b may not be in the range:
    # the out-of-range part of the residual cannot be reduced, so it is
    # removed from the error budget.
    if U.shape[0] > U.shape[1]:
        bNull = b - numpy.dot(U, numpy.dot(Uh, b))
        eps_sq = epsilon**2 - numpy.linalg.norm(bNull)**2
        if eps_sq < 0:
            raise ValueError('fastProjection: the component of b outside '
                             'the range of U exceeds epsilon; the '
                             'constraint is infeasible')
        epsilon = math.sqrt(eps_sq)

    # MATLAB: b = U'*b - S*(V'*y);  S is diagonal, so multiply by s.
    # ravel() lets b and y be given either as 1-D arrays or as columns.
    bt = numpy.ravel(numpy.dot(Uh, b)) - s * numpy.ravel(numpy.dot(Vh, y))

    b2 = numpy.abs(bt)**2       # for complex data
    bs2 = b2 * s2               # MATLAB: b2.*s2 (element-wise product)
    epsilon2 = epsilon**2
    f_at_zero = b2.sum() - epsilon2     # f(0), used when lambda is reset

    lam = lambda0
    oldff = 0
    alpha = 1                   # take full Newton steps
    for k in range(MAXIT):
        # f(lambda) and f'(lambda), sharing the common factors
        ls = 1.0 / (1.0 - lam * s2)
        ls2 = ls**2
        ls3 = ls2 * ls          # MATLAB: ls2.*ls
        ff = numpy.dot(b2, ls2) - epsilon2
        fpl = 2 * numpy.dot(bs2, ls3)
        if DISP:
            print("%2d, lambda is %5.2f, f(lambda) is %.2e, f'(lambda) is %.2e"
                  % (k, lam, ff, fpl))
        if abs(ff) < TOL:
            break               # stopping criterion
        if fpl == 0:
            # f does not depend on lambda: no correction along V can change
            # the residual, so y is returned unchanged.
            lam = 0
            break
        d = -ff / fpl
        lam_old = lam
        # MATLAB tests k>2 with 1-based k, i.e. from the third iteration on
        if k >= 2 and abs(ff) > 10 * abs(oldff + 100):
            lam = 0
            alpha = 1.0 / 2.0
            oldff = f_at_zero
            if DISP:
                print('restarting')
        else:
            if alpha < 1:
                alpha = (alpha + 1.0) / 2.0
            lam = lam + alpha * d
            oldff = ff
            if lam > 0:
                lam = 0         # shouldn't be positive
                oldff = f_at_zero
        if lam_old == lam and lam == 0:
            if DISP:
                print('Making no progress; x = y is probably feasible')
            break

    if lam < 0:
        # MATLAB: xhat = -l*s.*b./( 1 - l*s2 );  x = V*xhat + y;
        xhat = -lam * s * bt / (1. - lam * s2)
        x = y + numpy.reshape(numpy.dot(V, xhat), numpy.shape(y))
    else:
        # y is already feasible, so no need to project
        lam = 0
        x = y.copy()
    return x,k,lam