annotate OMP/omp_QR.py @ 1:2a2abf5092f8

Organized into test file and lib files. Changed sklearn cholesky function to behave as the others: tol does not override the number of atoms; the two conditions work together.
author nikcleju
date Thu, 20 Oct 2011 21:06:06 +0000
parents
children
rev   line source
nikcleju@1 1 import numpy as np
nikcleju@1 2 import scipy.linalg
nikcleju@1 3 import time
nikcleju@1 4 import math
nikcleju@1 5
nikcleju@1 6
nikcleju@1 7 #function [s, err_mse, iter_time]=greed_omp_qr(x,A,m,varargin)
def _qr_append_atom(Q, R, k, atom):
    """Append ``atom`` as column ``k`` of the thin QR factorisation in Q, R.

    The atom is orthogonalised against the first ``k`` columns of ``Q``
    (one classical Gram-Schmidt pass), normalised and stored in
    ``Q[:, k]``.  The projection coefficients go to ``R[0:k, k]`` and the
    norm of the orthogonal remainder to ``R[k, k]``.  ``Q`` and ``R`` are
    modified in place; the new unit-norm column is returned.
    """
    if k > 0:
        qP = np.dot(Q[:, 0:k].T, atom)
        q = atom - np.dot(Q[:, 0:k], qP)
        R[0:k, k] = qP
    else:
        q = atom
    nq = np.linalg.norm(q)
    q = q / nq
    R[k, k] = nq
    Q[:, k] = q
    return q


def greed_omp_qr(x, A, m, opts=None):
    """Orthogonal Matching Pursuit based on an incremental QR factorisation.

    Python translation (Nic, 19.10.2011) of the Matlab function
    ``greed_omp_qr`` from the "sparsity" toolbox, Version 0.1.
    Copyright (c) 2007 Thomas Blumensath, The University of Edinburgh
    (thomas.blumensath@ed.ac.uk).  Part of the toolbox was developed with
    the support of EPSRC Grant D000246/1.  Please read COPYRIGHT.m of the
    toolbox for terms and conditions.

    In each iteration the atom with the largest absolute inner product
    with the current residual is selected, orthogonalised against the
    atoms selected so far and appended to the QR factorisation.  The
    coefficients are obtained at the end by back-substitution.

    Usage
    -----
    s, err_mse, iter_time = greed_omp_qr(x, A, m, {'option_name': value})

    Parameters
    ----------
    x : 1-D array
        Observation vector to be decomposed (length n).
    A : array or callable
        Either a matrix for which ``np.dot(A, z)`` maps a length-m vector
        to a length-n vector, or a callable computing that product.  If
        it is a callable, the option ``P_trans`` must be given as well.
    m : int
        Length of the solution vector s.
    opts : dict, optional
        Options; unknown keys are ignored and the mapping passed in is
        never modified.  ``None`` or an empty container selects defaults.

        ========= ============================== ==============
        key       values                         default
        ========= ============================== ==============
        stopCrit  'M', 'corr', 'mse',            'M'
                  'mse_change'
        stopTol   number                         ceil(n / 4)
        P_trans   callable (transpose operator)
        maxIter   number                         n
        verbose   True, False                    False
        start_val vector of length m             zeros
        nargout   1, 2, 3                        3
        ========= ============================== ==============

        Stopping criteria (``stopTol`` is the associated value):

        * ``M``          - stop after stopTol iterations.
        * ``corr``       - stop when the maximum absolute correlation
          between residual and atoms is below stopTol.
        * ``mse``        - stop when the mean squared error of the
          residual is below stopTol.
        * ``mse_change`` - stop when the decrease of the mean squared
          error, relative to the signal energy per sample, is below
          stopTol (checked from the second iteration on).

        Any other ``stopCrit`` value leaves only the two unconditional
        stops below.  Independently of the criterion the iteration stops
        after ``maxIter`` iterations, and when the mean squared error
        drops below 1e-14 (with ``nargout == 1`` that last test is only
        applied from the second iteration on).

        ``start_val`` only provides an initial support: its non-zero
        positions are selected first and their coefficients are
        recomputed by projecting x onto them.

        ``nargout`` selects how many values are returned.

    Returns
    -------
    s : 1-D array of length m
        Solution vector (returned alone if ``nargout == 1``).
    err_mse : 1-D array
        Mean squared error of the residual after each iteration
        (``nargout >= 2``).
    iter_time : 1-D array
        Elapsed time in seconds at the end of each iteration
        (``nargout == 3``).

    ``None`` is returned, after printing a message, if x is not
    one-dimensional or ``nargout`` is not 1, 2 or 3.

    Raises
    ------
    TypeError
        If ``stopTol`` or ``maxIter`` is not numeric, if ``P_trans`` is
        not callable, or if A is callable and no ``P_trans`` was given.
    ValueError
        If ``start_val`` does not have m elements.
    """
    x = np.asarray(x)
    if x.ndim != 1:
        print('x must be a vector.')
        return
    n = x.size

    # Work on a private copy so that the caller's options are not modified.
    options = dict(opts) if opts else {}

    ###########################################################################
    # Output variables
    ###########################################################################
    nargout = options.get('nargout', 3)
    if nargout == 0:
        print('Please assign output variable.')
        return
    if nargout not in (1, 2, 3):
        print('Too many output arguments specified')
        return
    comp_err = nargout >= 2
    comp_time = nargout == 3

    ###########################################################################
    # Look through options (unknown options are silently ignored)
    ###########################################################################
    stop_crit = options.get('stopCrit', 'M')

    stop_tol = int(math.ceil(n / 4.0))
    if 'stopTol' in options:
        if not hasattr(options['stopTol'], '__int__'):  # numeric check
            raise TypeError('stopTol must be number. Exiting.')
        stop_tol = options['stopTol']

    Pt = None
    if 'P_trans' in options:
        if not callable(options['P_trans']):
            raise TypeError('P_trans must be function _handle. Exiting.')
        Pt = options['P_trans']

    max_iter = n
    if 'maxIter' in options:
        if not hasattr(options['maxIter'], '__int__'):  # numeric check
            raise TypeError('maxIter must be a number. Exiting.')
        max_iter = options['maxIter']

    verbose = options.get('verbose', False)

    s = np.zeros(m)
    IN = []
    if 'start_val' in options:
        start_val = np.array(options['start_val'], dtype=float)
        if start_val.size != m:
            raise ValueError('start_val must be a vector of length m. Exiting.')
        s = start_val.reshape(-1)
        IN = np.nonzero(s)[0].tolist()

    # Upper bound on the number of main-loop iterations.  The loop always
    # runs at least once and stops as soon as the iteration counter
    # reaches maxIter (or stopTol for the 'M' criterion).
    if stop_crit == 'M':
        max_rounds = min(stop_tol, max_iter)
    else:
        max_rounds = max_iter
    max_rounds = max(1, int(math.ceil(max_rounds)))
    # Room for the warm-start atoms plus one atom per iteration.
    capacity = len(IN) + max_rounds

    err_mse = np.zeros(max_rounds) if comp_err else np.array([])
    iter_time = np.zeros(max_rounds) if comp_time else np.array([])

    ###########################################################################
    # Make P and Pt functions
    ###########################################################################
    if callable(A):
        if not callable(Pt):
            raise TypeError('If P is a function handle, Pt also needs to be a function handle.')
        P = A
    else:
        # Anything that is not callable is treated as a matrix.
        P = lambda v: np.dot(A, v)
        Pt = lambda v: np.dot(A.T, v)

    ###########################################################################
    # Random check to see if dictionary is normalised
    ###########################################################################
    mask = np.zeros(m)
    mask[int(np.random.rand() * m)] = 1
    if abs(1 - np.linalg.norm(P(mask))) > 1e-3:
        print('Dictionary appears not to have unit norm columns.')

    ###########################################################################
    # Check if we have enough memory and initialise
    ###########################################################################
    try:
        Q = np.zeros((n, capacity))
        R = np.zeros((capacity, capacity))
    except (MemoryError, ValueError):
        print('Variable size is too large. Please try greed_omp_chol algorithm or reduce MAXITER.')
        raise

    ###########################################################################
    # Do we start from zero or not?
    ###########################################################################
    z = []
    k = 0  # number of atoms currently held in the factorisation
    for index in IN:
        mask = np.zeros(m)
        mask[index] = 1
        q = _qr_append_atom(Q, R, k, P(mask))
        z.append(np.vdot(q, x))
        k += 1

    if k > 0:
        Residual = x - np.dot(Q[:, 0:k], np.array(z))
    else:
        Residual = x.copy()
    sigsize = np.vdot(x, x) / float(n)
    oldERR = np.vdot(Residual, Residual) / float(n)

    ###########################################################################
    # Main algorithm
    ###########################################################################
    if verbose:
        print('Main iterations...')
    # Progress is reported at most every 10 seconds, as in the Matlab
    # original ("toc - t > 10", with toc measured in seconds).
    report_every = 10.0
    tic = time.time()
    last_report = 0.0
    DR = Pt(Residual)
    done = False
    it = 1  # 1-based iteration counter

    while not done:
        # Select new element, never re-selecting an already chosen atom.
        if IN:
            DR[IN] = 0
        IN.append(int(np.argmax(np.abs(DR))))

        # Extract the new atom and add it to the QR factorisation.
        mask = np.zeros(m)
        mask[IN[k]] = 1
        q = _qr_append_atom(Q, R, k, P(mask))
        z.append(np.vdot(q, x))

        # New residual
        Residual = Residual - q * z[k]
        DR = Pt(Residual)

        ERR = np.vdot(Residual, Residual) / float(n)
        if comp_err:
            err_mse[it - 1] = ERR
        if comp_time:
            iter_time[it - 1] = time.time() - tic

        #######################################################################
        # Are we done yet?
        #######################################################################
        report = verbose and (time.time() - tic) - last_report > report_every
        message = None
        if stop_crit == 'M':
            if it >= stop_tol:
                done = True
            elif report:
                message = 'Iteration %d. --- %d iterations to go' % (it, stop_tol - it)
        elif stop_crit == 'mse':
            if ERR < stop_tol:
                done = True
            elif report:
                message = 'Iteration %d. --- %g mse' % (it, ERR)
        elif stop_crit == 'mse_change' and it >= 2:
            change = (oldERR - ERR) / sigsize
            if change < stop_tol:
                done = True
            elif report:
                message = 'Iteration %d. --- %g mse change' % (it, change)
        elif stop_crit == 'corr':
            max_corr = np.abs(DR).max()
            if max_corr < stop_tol:
                done = True
            elif report:
                message = 'Iteration %d. --- %g corr' % (it, max_corr)
        if message is not None:
            print(message)
            last_report = time.time() - tic

        # Also stop if residual gets too small or maxIter reached.  When
        # the error is not recorded the test is skipped in the first
        # iteration, as in the original code.
        if (comp_err or it > 1) and ERR < 1e-14:
            done = True
            if verbose:
                print('Stopping. Exact signal representation found!')

        if it >= max_iter:
            done = True
            if verbose:
                print('Stopping. Maximum number of iterations reached!')

        # The atom of this iteration is now part of the factorisation.
        k += 1

        #######################################################################
        # If not done, take another round
        #######################################################################
        if not done:
            it += 1
            oldERR = ERR

    ###########################################################################
    # Now we can solve for s by back-substitution (R is upper triangular)
    ###########################################################################
    s[IN] = scipy.linalg.solve_triangular(R[0:k, 0:k], np.array(z))

    ###########################################################################
    # Only return as many elements as iterations
    ###########################################################################
    if comp_err:
        err_mse = err_mse[0:it]
    if comp_time:
        iter_time = iter_time[0:it]
    if verbose:
        print('Done')

    # Return
    if nargout == 1:
        return s
    if nargout == 2:
        return s, err_mse
    return s, err_mse, iter_time
nikcleju@1 768
nikcleju@1 769 # Change history
nikcleju@1 770 #
nikcleju@1 771 # 8th of February: Algorithm no longer stops if the dictionary is not normalised.
nikcleju@1 772
nikcleju@1 773 # End of greed_omp_qr() function
nikcleju@1 774 #--------------------------------
nikcleju@1 775
nikcleju@1 776
def omp_qr(x, dict, D, natom, tolerance):
    """ Recover x using QR implementation of OMP

    Atoms are selected greedily by largest projection of the residual.
    The triangular factor R of the selected atoms is grown one column per
    iteration from the Gramian, and the coefficients are obtained at the
    end from two triangular solves.  The update formulas set R[0,0] = 1
    and R[k,k] = sqrt(1 - w'w), i.e. they assume unit-norm atoms.

    Parameter
    ---------
    x: measurements, either a 1-D vector or a column of shape (msize, 1)
    dict: dictionary, array of shape (msize, dictsize)
    D: Gramian of dictionary
    natom: maximum number of iterations (atoms to select)
    tolerance: error tolerance; iterations stop once the squared norm of
        the residual is no larger than tolerance times the squared norm
        of x

    Return
    ------
    x_hat : estimate of x, array of shape (dictsize, 1) holding the
        coefficients of the selected atoms (zero elsewhere)
    gamma : indices where non-zero, in order of selection

    For more information, see http://media.aau.dk/null_space_pursuits/2011/10/efficient-omp.html
    """
    msize, dictsize = dict.shape
    # Flatten the measurements so that all quantities below stay scalars
    # or 1-D arrays, whichever layout the caller used.
    xvec = np.asarray(x).reshape(-1)
    normr2 = np.vdot(xvec, xvec)
    normtol2 = tolerance * normr2
    R = np.zeros((natom, natom))
    Q = np.zeros((msize, natom))
    gamma = []

    # find initial projections
    origprojections = np.dot(xvec, dict)
    # floating-point working copy, so the in-place updates below are valid
    # for integer input as well
    projections = origprojections.astype(
        np.result_type(origprojections.dtype, np.float64))

    k = 0
    while (normr2 > normtol2) and (k < natom):
        # find index of maximum magnitude projection
        newgam = int(np.argmax(np.abs(projections ** 2)))
        gamma.append(newgam)
        atom = dict[:, newgam]
        # update QR factorization
        if k == 0:
            R[0, 0] = 1
            q = atom.copy()
        else:
            w = scipy.linalg.solve_triangular(
                R[0:k, 0:k], D[gamma[0:k], newgam], trans=1)
            diag2 = 1 - np.vdot(w, w)
            if diag2 <= np.finfo(float).eps:
                # The candidate lies (numerically) in the span of the
                # atoms already selected, so no atom can reduce the
                # residual any further: discard it and stop instead of
                # dividing by zero.
                gamma.pop()
                break
            R[k, k] = np.sqrt(diag2)
            R[0:k, k] = w
            Qk = Q[:, 0:k]
            q = (atom - np.dot(Qk, np.dot(Qk.T, atom))) / R[k, k]
        Q[:, k] = q

        # Rank-one updates: x' q q' dict and x' q q' x, without forming
        # the msize-by-msize projector q q'.
        qx = np.dot(q, xvec)
        # update projections
        projections -= qx * np.dot(q, dict)
        # update residual energy
        normr2 -= qx * qx

        k += 1

    # build solution
    x_hat = np.zeros((dictsize, 1))
    if k > 0:
        tempR = R[0:k, 0:k]
        w = scipy.linalg.solve_triangular(tempR, origprojections[gamma], trans=1)
        x_hat[gamma, 0] = scipy.linalg.solve_triangular(tempR, w)

    return x_hat, gamma