import numpy as np
import scipy.linalg
import time
import math


# Copyright (c) 2007 Thomas Blumensath
#
# The University of Edinburgh
# Email: thomas.blumensath@ed.ac.uk
# Comments and bug reports welcome
#
# This file is part of sparsity Version 0.1
# Created: April 2007
#
# Part of this toolbox was developed with the support of EPSRC Grant
# D000246/1
#
# Please read COPYRIGHT.m for terms and conditions.
#
# Nic: translated to Python on 19.10.2011. Original Matlab code by
# Thomas Blumensath.


def _qr_append_atom(Q, R, k, new_element):
    """Append ``new_element`` as column ``k`` of the thin QR factorisation Q R.

    The vector is orthogonalised (Gram-Schmidt) against ``Q[:, :k]``; the
    arrays ``Q[:, k]``, ``R[:k, k]`` and ``R[k, k]`` are updated in place and
    the new unit-norm column ``q`` is returned.
    """
    if k > 0:
        qP = np.dot(Q[:, :k].T, new_element)
        q = new_element - np.dot(Q[:, :k], qP)
        R[:k, k] = qP
    else:
        q = new_element
    nq = np.linalg.norm(q)
    q = q / nq
    R[k, k] = nq
    Q[:, k] = q
    return q


def greed_omp_qr(x, A, m, opts=None):
    """Orthogonal Matching Pursuit (OMP) using an incremental QR factorisation.

    In each iteration the atom (column of ``A``) whose inner product with the
    current residual is largest in magnitude is added to the support. The QR
    factorisation of the selected atoms is extended by one column and the
    residual is updated by orthogonal projection. The non-zero entries of
    ``s`` are finally obtained by back-substitution with the triangular ``R``.

    Parameters
    ----------
    x : 1-D array_like of length n
        Observation vector to be decomposed.
    A : ndarray of shape (n, m), or callable
        Dictionary matrix, or a function computing ``A @ z``. If callable,
        ``opts['P_trans']`` must be a callable computing ``A.T @ r``.
    m : int
        Length of the solution vector ``s``.
    opts : dict, optional
        Options (unknown keys are ignored). The caller's dict is not modified.

        ==========  ===============================================  ==========
        key         meaning                                          default
        ==========  ===============================================  ==========
        stopCrit    'M', 'corr', 'mse' or 'mse_change'               'M'
        stopTol     value for the stopping criterion                 ceil(n/4)
        P_trans     callable computing A.T @ r                       --
        maxIter     maximum number of iterations                     n
        verbose     print progress messages                          False
        start_val   length-m vector; its non-zeros form the          zeros
                    initial support
        nargout     number of returned values (1, 2 or 3)            3
        ==========  ===============================================  ==========

        Stopping criteria:

        * ``M`` -- add exactly ``stopTol`` new atoms.
        * ``corr`` -- stop when ``max |A.T r|`` falls below ``stopTol``.
        * ``mse`` -- stop when the mean squared residual falls below ``stopTol``.
        * ``mse_change`` -- stop when the decrease of the mean squared error,
          relative to ``x.x / n``, falls below ``stopTol``.

        The algorithm also stops when the mean squared residual drops below
        1e-14 or after ``maxIter`` iterations.

    Returns
    -------
    s, or (s, err_mse), or (s, err_mse, iter_time), depending on ``nargout``.
        ``s`` is the solution vector. ``err_mse`` holds the mean squared
        residual after each iteration. ``iter_time`` holds the elapsed
        seconds after each iteration. Both have one entry per iteration.
        If ``x`` is not 1-D or ``nargout`` is invalid, a message is printed
        and None is returned.

    Raises
    ------
    TypeError, ValueError
        For malformed options.

    See also: greed_omp_chol, greed_omp_cg, greed_omp_cgp, greed_omp_pinv,
    greed_omp_linsolve, greed_gp, greed_nomp (Matlab sparsity toolbox).
    """
    x = np.asarray(x)
    if x.ndim != 1:
        print('x must be a vector.')
        return
    n = x.size

    # Private copy: the original code mutated the caller's dict (and crashed
    # on the former list default).
    opts = dict(opts) if opts else {}

    ###########################################################################
    # Output variables (emulation of Matlab's nargout)
    ###########################################################################
    nargout = opts.get('nargout', 3)
    if nargout == 0:
        print('Please assign output variable.')
        return
    if nargout not in (1, 2, 3):
        print('Too many output arguments specified')
        return
    comp_err = nargout >= 2
    comp_time = nargout == 3

    ###########################################################################
    # Default values
    ###########################################################################
    sigsize = np.vdot(x, x) / n
    STOPCRIT = 'M'
    STOPTOL = math.ceil(n / 4.0)
    MAXITER = n
    verbose = False
    s_initial = np.zeros(m)
    initial_given = False
    Pt = None

    ###########################################################################
    # Look through options
    ###########################################################################
    if 'stopCrit' in opts:
        STOPCRIT = opts['stopCrit']
        if STOPCRIT not in ('M', 'corr', 'mse', 'mse_change'):
            raise ValueError('stopCrit must be char string [M, corr, mse, mse_change]. Exiting.')
    if 'stopTol' in opts:
        if hasattr(opts['stopTol'], '__int__'):  # numeric check
            STOPTOL = opts['stopTol']
        else:
            raise TypeError('stopTol must be number. Exiting.')
    if 'P_trans' in opts:
        if callable(opts['P_trans']):
            Pt = opts['P_trans']
        else:
            raise TypeError('P_trans must be function _handle. Exiting.')
    if 'maxIter' in opts:
        if hasattr(opts['maxIter'], '__int__'):  # numeric check
            MAXITER = opts['maxIter']
        else:
            raise TypeError('maxIter must be a number. Exiting.')
    if 'verbose' in opts:
        verbose = opts['verbose']
    if 'start_val' in opts:
        start_val = np.asarray(opts['start_val']).ravel()
        if start_val.size == m:
            s_initial = start_val
            initial_given = True
        else:
            raise ValueError('start_val must be a vector of length m. Exiting.')

    if verbose:
        print('Initialising...')

    # Number of iterations that can be performed; the 'M' criterion stops as
    # soon as iter >= stopTol, i.e. after ceil(stopTol) iterations. At least
    # one iteration always runs.
    if STOPCRIT == 'M':
        maxM = int(math.ceil(STOPTOL))
    else:
        maxM = int(math.ceil(MAXITER))
    maxM = max(maxM, 1)

    err_mse = np.zeros(maxM) if comp_err else np.array([])
    iter_time = np.zeros(maxM) if comp_time else np.array([])

    ###########################################################################
    # Make P and Pt functions
    ###########################################################################
    if callable(A):
        if not callable(Pt):
            raise TypeError('If P is a function handle, Pt also needs to be a function handle.')
        P = A
    else:
        P = lambda z: np.dot(A, z)
        Pt = lambda z: np.dot(A.T, z)

    ###########################################################################
    # Random check to see if dictionary is normalised (warning only)
    ###########################################################################
    mask = np.zeros(m)
    mask[np.random.randint(m)] = 1
    nP = np.linalg.norm(P(mask))
    if abs(1 - nP) > 1e-3:
        print('Dictionary appears not to have unit norm columns.')

    ###########################################################################
    # Initial support and memory allocation
    ###########################################################################
    IN = [int(i) for i in np.nonzero(s_initial)[0]] if initial_given else []
    ncols = maxM + len(IN)  # initial atoms plus at most maxM new ones
    try:
        Q = np.zeros((n, ncols))
        R = np.zeros((ncols, ncols))
    except (MemoryError, ValueError):
        print('Variable size is too large. Please try greed_omp_chol algorithm or reduce MAXITER.')
        raise

    # Float (or complex) copy so the back-substituted values are not truncated.
    s = s_initial.astype(np.result_type(s_initial, np.float64))
    z = []
    for j, idx in enumerate(IN):
        mask = np.zeros(m)
        mask[idx] = 1
        q = _qr_append_atom(Q, R, j, P(mask))
        z.append(np.vdot(q, x))
    k = len(IN)  # number of atoms currently in the QR factorisation
    if k > 0:
        Residual = x - np.dot(Q[:, :k], np.asarray(z))
        oldERR = np.vdot(Residual, Residual) / n
    else:
        Residual = x.copy()
        oldERR = sigsize

    ###########################################################################
    # Main algorithm
    ###########################################################################
    if verbose:
        print('Main iterations...')
    tic = time.time()
    t = 0.0  # elapsed time (s) of the last progress message
    DR = Pt(Residual)
    done = False
    it = 1  # 1-based iteration counter

    while not done:
        # Select new element (never re-select an atom already in the support)
        DR[IN] = 0
        I = int(np.abs(DR).argmax())
        IN.append(I)

        # Extract the new atom and extend the QR factorisation
        mask = np.zeros(m)
        mask[I] = 1
        q = _qr_append_atom(Q, R, k, P(mask))
        z.append(np.vdot(q, x))

        # New residual
        Residual = Residual - q * z[-1]
        DR = Pt(Residual)

        ERR = np.vdot(Residual, Residual) / n
        elapsed = time.time() - tic
        if comp_err:
            err_mse[it - 1] = ERR
        if comp_time:
            iter_time[it - 1] = elapsed

        #######################################################################
        # Are we done yet?  (progress printed at most every 10 seconds)
        #######################################################################
        report = verbose and elapsed - t > 10.0
        if STOPCRIT == 'M':
            if it >= STOPTOL:
                done = True
            elif report:
                print('Iteration %d. --- %d iterations to go' % (it, STOPTOL - it))
                t = elapsed
        elif STOPCRIT == 'mse':
            if ERR < STOPTOL:
                done = True
            elif report:
                print('Iteration %d. --- %g mse' % (it, ERR))
                t = elapsed
        elif STOPCRIT == 'mse_change' and it >= 2:
            # oldERR is the mse of the previous iteration
            change = (oldERR - ERR) / sigsize
            if change < STOPTOL:
                done = True
            elif report:
                print('Iteration %d. --- %g mse change' % (it, change))
                t = elapsed
        elif STOPCRIT == 'corr':
            maxcorr = np.abs(DR).max()
            if maxcorr < STOPTOL:
                done = True
            elif report:
                print('Iteration %d. --- %g corr' % (it, maxcorr))
                t = elapsed

        # Also stop if residual gets too small or maxIter reached
        if ERR < 1e-14:
            done = True
            if verbose:
                print('Stopping. Exact signal representation found!')

        if it >= MAXITER:
            done = True
            if verbose:
                print('Stopping. Maximum number of iterations reached!')

        # If not done, take another round
        if not done:
            it += 1
            oldERR = ERR
        k += 1

    ###########################################################################
    # Now we can solve for s by back-substitution (R is upper triangular)
    ###########################################################################
    s[IN] = scipy.linalg.solve_triangular(R[:k, :k], np.asarray(z))

    ###########################################################################
    # Only return as many elements as iterations
    ###########################################################################
    if comp_err:
        err_mse = err_mse[:it]
    if comp_time:
        iter_time = iter_time[:it]
    if verbose:
        print('Done')

    if nargout == 1:
        return s
    if nargout == 2:
        return s, err_mse
    return s, err_mse, iter_time

# Change history
#
# 8 February: The algorithm no longer stops if the dictionary is not normalised.
# End of greed_omp_qr() function
#--------------------------------


def omp_qr(x, dict, D, natom, tolerance):
    """Recover a sparse representation of x with a QR implementation of OMP.

    Greedily selects atoms (columns of ``dict``) by largest absolute
    correlation with the residual. It maintains a thin QR factorisation of the
    selected atoms, built from the Gramian ``D``, and updates the atom
    correlations and the residual energy with rank-1 corrections.

    Parameters
    ----------
    x : ndarray of shape (msize,) or (msize, 1)
        Measurements (real-valued).
    dict : ndarray of shape (msize, dictsize)
        Dictionary. Columns need not have unit norm. (The name shadows the
        builtin and is kept for backward compatibility.)
    D : ndarray of shape (dictsize, dictsize)
        Gramian of the dictionary, ``dict.T @ dict``.
    natom : int
        Maximum number of atoms / iterations.
    tolerance : float
        Stop once the residual energy ``||r||^2`` is at most
        ``tolerance * ||x||^2``.

    Returns
    -------
    x_hat : ndarray of shape (dictsize, 1)
        Estimated coefficients; the least-squares fit of ``x`` on the
        selected atoms, zero elsewhere.
    gamma : list of int
        Indices of the selected atoms, in selection order.

    For more information, see
    http://media.aau.dk/null_space_pursuits/2011/10/efficient-omp.html
    """
    msize, dictsize = dict.shape
    xv = np.asarray(x).reshape(-1)  # accept both (msize,) and (msize, 1)
    normr2 = np.vdot(xv, xv)
    normtol2 = tolerance * normr2
    R = np.zeros((natom, natom))
    Q = np.zeros((msize, natom))
    gamma = []

    # initial projections dict.T @ x; kept for building the solution
    origprojections = np.dot(xv, dict)
    projections = origprojections.copy()

    k = 0
    while normr2 > normtol2 and k < natom:
        # find index of maximum magnitude projection
        newgam = int(np.argmax(np.abs(projections)))
        atom = dict[:, newgam]

        # new column of R: R[:k, k] = w with R[:k, :k].T w = D[gamma, newgam]
        if k == 0:
            w = np.zeros(0)
            q = atom.copy()
        else:
            w = scipy.linalg.solve_triangular(R[:k, :k], D[gamma, newgam], trans=1)
            q = atom - np.dot(Q[:, :k], np.dot(Q[:, :k].T, atom))
        rkk2 = D[newgam, newgam] - np.vdot(w, w)
        if not rkk2 > 0:
            # atom (numerically) lies in the span of the selected atoms
            break
        rkk = np.sqrt(rkk2)
        R[:k, k] = w
        R[k, k] = rkk
        Q[:, k] = q / rkk
        gamma.append(newgam)

        # rank-1 updates of the projections and of the residual energy
        xq = np.dot(xv, Q[:, k])
        projections -= xq * np.dot(Q[:, k], dict)
        normr2 -= xq * xq

        k += 1

    # build solution: solve R.T R c = dict[:, gamma].T x by two triangular solves
    x_hat = np.zeros((dictsize, 1))
    if k > 0:
        tempR = R[:k, :k]
        w = scipy.linalg.solve_triangular(tempR, origprojections[gamma], trans=1)
        x_hat[gamma, 0] = scipy.linalg.solve_triangular(tempR, w)

    return x_hat, gamma