annotate pyCSalgos/OMP/omp_QR.py @ 62:e684f76c1969

Added l1eq_pd(), l1 minimizer with equality constraints from l1-magic
author Nic Cleju <nikcleju@gmail.com>
date Mon, 12 Mar 2012 11:29:32 +0200
parents e1da5140c9a5
children
rev   line source
nikcleju@2 1 import numpy as np
nikcleju@2 2 import scipy.linalg
nikcleju@2 3 import time
nikcleju@2 4 import math
nikcleju@2 5
nikcleju@2 6
def _greed_omp_qr_append(Q, R, k, new_element):
    """Append new_element as column k of the thin QR factorisation (Q, R).

    One classical Gram-Schmidt step: the component of new_element that is
    orthogonal to Q[:, :k] is normalised and stored in Q[:, k]; column k of
    the upper-triangular R receives the projection coefficients (rows 0..k-1)
    and the norm of the orthogonal component (row k). Q and R are modified in
    place and the new orthonormal column is returned.
    """
    if k > 0:
        qP = np.dot(Q[:, :k].T, new_element)
        q = new_element - np.dot(Q[:, :k], qP)
        R[:k, k] = qP
    else:
        q = new_element
    nq = np.linalg.norm(q)
    q = q / nq
    R[k, k] = nq
    Q[:, k] = q
    return q


def greed_omp_qr(x, A, m, opts=None):
    """Orthogonal Matching Pursuit based on an incremental QR factorisation.

    Python translation (Nic Cleju, 19.10.2011) of the Matlab function
    ``[s, err_mse, iter_time] = greed_omp_qr(x, P, m, varargin)`` by
    Thomas Blumensath.

    In each iteration the dictionary column with the largest absolute inner
    product with the current residual is selected, orthogonalised against the
    previously selected columns (QR update) and x is projected onto the span
    of the selected columns. The non-zero entries of s are finally obtained
    by back-substitution with R. The Matlab original warned when the
    dictionary did not appear to have unit-norm columns; that check is not
    performed here.

    Parameters
    ----------
    x : 1-D numpy array
        Observation vector to be decomposed (length n).
    A : 2-D array of shape (n, m), or callable
        Dictionary. If callable, it is applied to length-m vectors and must
        return length-n vectors; opts['P_trans'] must then give its adjoint.
    m : int
        Length of the solution vector s.
    opts : dict, optional
        Options (the caller's dict is never modified):
        'stopCrit'  : 'M' (default), 'corr', 'mse' or 'mse_change'
            M          - extract exactly stopTol elements;
            corr       - stop when max |P^T residual| < stopTol;
            mse        - stop when the residual mean squared error < stopTol;
            mse_change - stop when the MSE decrease, relative to the signal
                         power x'x/n, falls below stopTol.
        'stopTol'   : number, default ceil(n/4)
        'P_trans'   : callable adjoint of A (required when A is callable)
        'maxIter'   : maximum number of iterations, default n
        'verbose'   : bool, default False
        'start_val' : length-m vector; its non-zero entries form the initial
                      support from which the algorithm continues.
        'nargout'   : 1, 2 or 3 (default 3); selects the returned values.

    Returns
    -------
    s                          if nargout == 1
    (s, err_mse)               if nargout == 2
    (s, err_mse, iter_time)    if nargout == 3
        s is the length-m solution; err_mse[i] is the residual MSE after
        iteration i+1; iter_time[i] is the time in seconds (time.time)
        elapsed since the start of the main loop at the end of iteration i+1.
    None is returned, after printing a message, when x is not 1-D or when
    nargout is not 1, 2 or 3.

    Raises
    ------
    TypeError
        If stopTol/maxIter are not numeric, P_trans is not callable, or A is
        callable but no callable P_trans is given.
    ValueError
        If start_val does not have m elements.
    MemoryError
        If the Q/R work arrays cannot be allocated.

    See also (Matlab sparsify toolbox): greed_omp_chol, greed_omp_cg,
    greed_omp_cgp, greed_omp_pinv, greed_omp_linsolve, greed_gp, greed_nomp.

    Copyright (c) 2007 Thomas Blumensath, The University of Edinburgh
    (thomas.blumensath@ed.ac.uk). Comments and bug reports welcome.
    This file is part of sparsity Version 0.1, created April 2007. Part of
    this toolbox was developed with the support of EPSRC Grant D000246/1.
    Please read COPYRIGHT.m for terms and conditions.

    Change history (Matlab original): 8 February -- the algorithm no longer
    stops if the dictionary is not normalised.
    """
    if x.ndim != 1:
        print('x must be a vector.')
        return
    n = x.size

    # Work on a private copy: the caller's dict must not be mutated, and the
    # legacy default of an empty list must keep working.
    opts = dict(opts) if opts else {}

    sigsize = np.vdot(x, x) / n
    STOPCRIT = 'M'
    STOPTOL = int(math.ceil(n / 4.0))
    MAXITER = n
    verbose = False
    s_initial = np.zeros(m)
    initial_given = False
    Pt = None

    # Emulate Matlab's nargout to decide which outputs are computed.
    nargout = opts.setdefault('nargout', 3)
    if nargout == 0:
        print('Please assign output variable.')
        return
    if nargout not in (1, 2, 3):
        print('Too many output arguments specified')
        return
    comp_err = nargout >= 2
    comp_time = nargout == 3

    # Parse options; unknown options are silently ignored.
    if 'stopCrit' in opts:
        STOPCRIT = opts['stopCrit']
    if 'stopTol' in opts:
        if not hasattr(opts['stopTol'], '__int__'):  # crude "is numeric" test
            raise TypeError('stopTol must be number. Exiting.')
        STOPTOL = opts['stopTol']
    if 'P_trans' in opts:
        if not hasattr(opts['P_trans'], '__call__'):
            raise TypeError('P_trans must be function _handle. Exiting.')
        Pt = opts['P_trans']
    if 'maxIter' in opts:
        if not hasattr(opts['maxIter'], '__int__'):  # crude "is numeric" test
            raise TypeError('maxIter must be a number. Exiting.')
        MAXITER = opts['maxIter']
    if 'verbose' in opts:
        verbose = opts['verbose']
    if 'start_val' in opts:
        start_val = np.asarray(opts['start_val'])
        if start_val.size != m:
            raise ValueError('start_val must be a vector of length m. Exiting.')
        s_initial = start_val.ravel()
        initial_given = True

    # Number of iterations for which storage is needed. The loop always runs
    # at least once, so never allocate zero-sized buffers.
    maxM = STOPTOL if STOPCRIT == 'M' else MAXITER
    maxM = max(1, int(math.ceil(maxM)))

    # Forward operator P and its adjoint Pt.
    if hasattr(A, '__call__'):
        if not hasattr(Pt, '__call__'):
            raise TypeError('If P is a function handle, Pt also needs to be a function handle.')
        P = A
    else:
        def P(z):
            return np.dot(A, z)

        def Pt(z):
            return np.dot(A.T, z)

    def atom(index):
        # Column `index` of the dictionary, obtained by applying P to a
        # canonical basis vector (works for matrices and operators alike).
        mask = np.zeros(m)
        mask[index] = 1
        return P(mask)

    # Initial support (non-empty only when start_val has non-zeros). Q and R
    # must hold the initial atoms plus up to maxM newly selected ones.
    IN = np.nonzero(s_initial)[0].tolist() if initial_given else []
    ncols = maxM + len(IN)
    try:
        Q = np.zeros((n, ncols))
        R = np.zeros((ncols, ncols))
    except MemoryError:
        print('Variable size is too large. Please try greed_omp_chol algorithm or reduce MAXITER.')
        raise

    err_mse = np.zeros(maxM) if comp_err else np.array([])
    iter_time = np.zeros(maxM) if comp_time else np.array([])

    # z holds the coordinates of x in the orthonormal basis Q[:, :k].
    z = []
    for k, index in enumerate(IN):
        q = _greed_omp_qr_append(Q, R, k, atom(index))
        z.append(np.vdot(q, x))
    k = len(IN)  # number of atoms selected so far
    s = s_initial.astype(np.result_type(s_initial, float))
    if k > 0:
        Residual = x - np.dot(Q[:, :k], z)
        oldERR = np.vdot(Residual, Residual) / n
    else:
        Residual = x.copy()
        oldERR = sigsize

    # Main algorithm
    if verbose:
        print('Main iterations...')
    tic = time.time()
    last_report = 0.0  # elapsed seconds at the last progress message
    DR = Pt(Residual)
    iteration = 1
    done = False
    while not done:
        # Select the atom most correlated with the residual, excluding atoms
        # already in the support.
        DR[IN] = 0
        I = np.abs(DR).argmax()
        IN.append(I)

        q = _greed_omp_qr_append(Q, R, k, atom(I))
        z.append(np.vdot(q, x))

        # q is orthogonal to all previous basis vectors, so removing its
        # component gives the residual of the new least-squares fit.
        Residual = Residual - q * z[k]
        DR = Pt(Residual)

        ERR = np.vdot(Residual, Residual) / n
        elapsed = time.time() - tic
        if comp_err:
            err_mse[iteration - 1] = ERR
        if comp_time:
            iter_time[iteration - 1] = elapsed

        # Are we done yet? `progress` is the status text shown when verbose.
        progress = None
        if STOPCRIT == 'M':
            if iteration >= STOPTOL:
                done = True
            else:
                progress = '%s iterations to go' % (STOPTOL - iteration)
        elif STOPCRIT == 'mse':
            if ERR < STOPTOL:
                done = True
            else:
                progress = '%s mse' % ERR
        elif STOPCRIT == 'mse_change' and iteration >= 2:
            # oldERR equals the previous iteration's error (err_mse[iteration-2]).
            change = (oldERR - ERR) / sigsize
            if change < STOPTOL:
                done = True
            else:
                progress = '%s mse change' % change
        elif STOPCRIT == 'corr':
            maxcorr = np.abs(DR).max()
            if maxcorr < STOPTOL:
                done = True
            else:
                progress = '%s corr' % maxcorr
        # Matlab prints progress at most every 10 seconds.
        if progress is not None and verbose and elapsed - last_report > 10:
            print('Iteration %d. --- %s' % (iteration, progress))
            last_report = elapsed

        # Also stop if the residual is (numerically) zero.
        if (comp_err or iteration > 1) and ERR < 1e-14:
            done = True
            if verbose:
                print('Stopping. Exact signal representation found!')

        if iteration >= MAXITER:
            done = True
            if verbose:
                print('Stopping. Maximum number of iterations reached!')

        if not done:
            iteration += 1
            oldERR = ERR
        k += 1

    # R is upper triangular: solve for the coefficients by back-substitution.
    s[IN] = scipy.linalg.solve_triangular(R[:k, :k], np.array(z))

    # Only return as many entries as iterations were performed.
    if comp_err:
        err_mse = err_mse[:iteration]
    if comp_time:
        iter_time = iter_time[:iteration]
    if verbose:
        print('Done')

    if nargout == 1:
        return s
    if nargout == 2:
        return s, err_mse
    return s, err_mse, iter_time

# End of greed_omp_qr() function
#--------------------------------
nikcleju@2 774
nikcleju@2 775
def omp_qr(x, dict, D, natom, tolerance):
    """Recover a sparse representation of x with a QR implementation of OMP.

    Parameters
    ----------
    x : numpy array
        Measurements, either 1-D of length msize or a column of shape
        (msize, 1).
    dict : 2-D numpy array, shape (msize, dictsize)
        Dictionary. The QR update assumes unit-norm columns (R[0, 0] is set
        to 1 and R[k, k] = sqrt(1 - ||w||^2)).
    D : 2-D numpy array, shape (dictsize, dictsize)
        Gramian of the dictionary (dict.T @ dict).
    natom : int
        Maximum number of atoms to select (iterations).
    tolerance : float
        Relative residual-energy tolerance: iterations stop once
        ||residual||^2 <= tolerance * ||x||^2.

    Returns
    -------
    x_hat : numpy array, shape (dictsize, 1)
        Estimated coefficients, zero outside gamma.
    gamma : list
        Indices of the selected atoms, in selection order.

    For more information, see
    http://media.aau.dk/null_space_pursuits/2011/10/efficient-omp.html
    """
    # NOTE: the parameter name `dict` shadows the builtin; it is kept for
    # backward compatibility with keyword callers.
    x = np.asarray(x)
    if x.ndim == 1:
        # Work with a column vector so that all intermediate products and the
        # final assignment into the (dictsize, 1) solution have matching shapes.
        x = x[:, np.newaxis]
    msize, dictsize = dict.shape
    normr2 = np.vdot(x, x)
    normtol2 = tolerance * normr2
    R = np.zeros((natom, natom))
    Q = np.zeros((msize, natom))
    gamma = []

    # Initial correlations of x with every atom, shape (1, dictsize).
    origprojections = np.dot(x.T, dict)
    origprojectionsT = origprojections.T
    projections = origprojections.copy()

    k = 0
    while normr2 > normtol2 and k < natom:
        # Index of the atom with the largest-magnitude projection.
        newgam = np.argmax(np.abs(projections))
        gamma.append(newgam)
        # Update QR factorisation, projections and residual energy.
        if k == 0:
            R[0, 0] = 1
            Q[:, 0] = dict[:, newgam]
            # Orthogonal projector onto the span of the selected atoms.
            QtempQtempT = np.outer(Q[:, 0], Q[:, 0])
            projections -= np.dot(x.T, np.dot(QtempQtempT, dict))
            normr2 -= np.vdot(x, np.dot(QtempQtempT, x))
        else:
            # New column of R from the Gramian: R[:k, :k]^T w = D[gamma, newgam].
            w = scipy.linalg.solve_triangular(R[:k, :k], D[gamma[:k], newgam], trans=1)
            R[k, k] = np.sqrt(1 - np.vdot(w, w))
            R[:k, k] = w
            Q[:, k] = (dict[:, newgam] - np.dot(QtempQtempT, dict[:, newgam])) / R[k, k]
            QkQkT = np.outer(Q[:, k], Q[:, k])
            xTQkQkT = np.dot(x.T, QkQkT)
            QtempQtempT += QkQkT
            projections -= np.dot(xTQkQkT, dict)
            # .item() keeps normr2 a scalar instead of a (1, 1) array.
            normr2 -= np.dot(xTQkQkT, x).item()
        k += 1

    x_hat = np.zeros((dictsize, 1))
    if k == 0:
        # Nothing was selected (e.g. x == 0): avoid solving an empty system.
        return x_hat, gamma

    # Build solution: solve R^T R c = dict[:, gamma]^T x by two triangular solves.
    tempR = R[:k, :k]
    w = scipy.linalg.solve_triangular(tempR, origprojectionsT[gamma], trans=1)
    x_hat[gamma] = scipy.linalg.solve_triangular(tempR, w)
    return x_hat, gamma