import numpy as np
import scipy.linalg
import time
import math


def _qr_append(Q, R, k, new_element):
    """Orthogonalise new_element against Q[:, :k] and store it as column k.

    One Gram-Schmidt step of the incremental QR factorisation: R[:k, k]
    receives the projections of new_element onto the existing columns,
    R[k, k] the norm of the orthogonal remainder and Q[:, k] the normalised
    remainder, which is also returned.  Q and R are modified in place.
    """
    if k > 0:
        coeffs = np.dot(Q[:, :k].T, new_element)
        q = new_element - np.dot(Q[:, :k], coeffs)
        R[:k, k] = coeffs
    else:
        q = new_element
    nq = np.linalg.norm(q)
    q = q / nq
    R[k, k] = nq
    Q[:, k] = q
    return q


def greed_omp_qr(x, A, m, opts=None):
    """Orthogonal Matching Pursuit based on an incremental QR factorisation.

    Translated to Python by Nic on 19.10.2011 from the Matlab code of
    Thomas Blumensath (sparsify toolbox, Copyright (c) 2007 Thomas
    Blumensath, The University of Edinburgh, thomas.blumensath@ed.ac.uk;
    developed with the support of EPSRC Grant D000246/1; see COPYRIGHT.m
    for terms and conditions).

    In each iteration the atom (column of P) with the largest absolute inner
    product with the current residual is added to the support.  It is
    orthogonalised against the previously selected atoms, which updates the
    QR factorisation, and the residual is updated.  At the end the non-zero
    coefficients of s are obtained by back-substitution with R.

    Parameters
    ----------
    x : 1-D array
        Observation vector to be decomposed (length n).
    A : 2-D array or callable
        Either an n x m matrix, or a function computing P(z).  In the latter
        case opts['P_trans'] must provide the transpose operator.
    m : int
        Length of the solution vector s.
    opts : dict, optional
        Options.  Unknown keys are ignored and the dict is not modified.

        stopCrit   : 'M' (default), 'corr', 'mse' or 'mse_change'
                     'M'          - stop after stopTol iterations
                     'corr'       - stop when max |P^T r| < stopTol
                     'mse'        - stop when the residual mse < stopTol
                     'mse_change' - stop when the decrease of the mse,
                                    relative to x'x/n, is < stopTol
        stopTol    : number, default ceil(n/4)
        P_trans    : callable; required when A is callable
        maxIter    : positive number, default n
        verbose    : bool, default False
        start_val  : vector of length m.  Its non-zero entries form the
                     initial support (their values are re-estimated).
        nargout    : 1, 2 or 3 (default 3); number of values returned,
                     emulating Matlab's nargout

    Returns
    -------
    s, or (s, err_mse), or (s, err_mse, iter_time), depending on nargout.
    err_mse[i] is the residual mse after iteration i + 1.  iter_time[i] is
    the wall-clock time in seconds elapsed since the main loop started.
    None is returned (after printing a message) if x is not 1-D or nargout
    is not 1, 2 or 3.

    Raises
    ------
    TypeError
        stopTol or maxIter is not numeric, P_trans is not callable, or A is
        callable but no callable P_trans was given.
    ValueError
        Unknown stopCrit, or start_val does not have m elements.
    MemoryError
        Q or R cannot be allocated.
    """
    # Private copy: the 'nargout' default is stored below, and neither a
    # shared default nor the caller's dict must be modified by that.
    opts = dict(opts) if opts else {}

    x = np.asarray(x)
    if x.ndim != 1:
        print('x must be a vector.')
        return
    n = x.size
    sigsize = np.vdot(x, x) / float(n)

    # ----------------------------------------------------------------------
    # Number of requested outputs (emulates Matlab's nargout)
    # ----------------------------------------------------------------------
    nargout = opts.setdefault('nargout', 3)
    if nargout == 0:
        print('Please assign output variable.')
        return
    if nargout not in (1, 2, 3):
        print('Too many output arguments specified')
        return
    comp_err = nargout >= 2
    comp_time = nargout == 3

    # ----------------------------------------------------------------------
    # Options
    # ----------------------------------------------------------------------
    STOPCRIT = opts.get('stopCrit', 'M')
    if STOPCRIT not in ('M', 'corr', 'mse', 'mse_change'):
        raise ValueError('stopCrit must be char string [M, corr, mse, mse_change]. Exiting.')

    STOPTOL = int(math.ceil(n / 4.0))
    if 'stopTol' in opts:
        if not hasattr(opts['stopTol'], '__int__'):  # numeric check
            raise TypeError('stopTol must be number. Exiting.')
        STOPTOL = opts['stopTol']

    Pt = None
    if 'P_trans' in opts:
        if not callable(opts['P_trans']):
            raise TypeError('P_trans must be function _handle. Exiting.')
        Pt = opts['P_trans']

    MAXITER = n
    if 'maxIter' in opts:
        if not hasattr(opts['maxIter'], '__int__'):  # numeric check
            raise TypeError('maxIter must be a number. Exiting.')
        MAXITER = opts['maxIter']

    verbose = opts.get('verbose', False)

    s_initial = np.zeros(m)
    if 'start_val' in opts:
        start_val = np.asarray(opts['start_val'])
        if start_val.size != m:
            raise ValueError('start_val must be a vector of length m. Exiting.')
        s_initial = start_val.astype(float).ravel()

    if verbose:
        print('Initialising...')

    # Upper bound on the number of main-loop iterations.  The loop stops at
    # MAXITER, and in 'M' mode also at STOPTOL.  At least one iteration
    # always runs.
    maxM = int(math.ceil(MAXITER))
    if STOPCRIT == 'M':
        maxM = min(maxM, int(math.ceil(STOPTOL)))
    maxM = max(maxM, 1)

    err_mse = np.zeros(maxM) if comp_err else None
    iter_time = np.zeros(maxM) if comp_time else None

    # ----------------------------------------------------------------------
    # Forward and transpose operators
    # ----------------------------------------------------------------------
    if callable(A):
        if not callable(Pt):
            raise TypeError('If P is a function handle, Pt also needs to be a function handle.')
        P = A
    else:
        # A is treated as an explicit matrix.
        P = lambda z: np.dot(A, z)
        Pt = lambda z: np.dot(A.T, z)

    # ----------------------------------------------------------------------
    # Random check whether the dictionary has unit-norm columns
    # ----------------------------------------------------------------------
    mask = np.zeros(m)
    mask[int(np.random.rand() * m)] = 1
    if abs(1 - np.linalg.norm(P(mask))) > 1e-3:
        print('Dictionary appears not to have unit norm columns.')

    # ----------------------------------------------------------------------
    # Initial support (from start_val) and QR storage
    # ----------------------------------------------------------------------
    IN = np.nonzero(s_initial)[0].tolist()
    capacity = maxM + len(IN)  # initial atoms plus one per iteration
    try:
        Q = np.zeros((n, capacity))
        R = np.zeros((capacity, capacity))
    except MemoryError:
        print('Variable size is too large. Please try greed_omp_chol algorithm or reduce MAXITER.')
        raise

    z = []  # z[j] = <Q[:, j], x>
    for k, index in enumerate(IN):
        mask = np.zeros(m)
        mask[index] = 1
        q = _qr_append(Q, R, k, P(mask))
        z.append(np.vdot(q, x))
    # k is the number of atoms in the factorisation, i.e. the next column.
    k = len(IN)

    s = np.array(s_initial, dtype=float)
    if k > 0:
        Residual = x - np.dot(Q[:, :k], z)
        oldERR = np.vdot(Residual, Residual) / n
    else:
        Residual = x.copy()
        oldERR = sigsize

    # ----------------------------------------------------------------------
    # Main algorithm
    # ----------------------------------------------------------------------
    if verbose:
        print('Main iterations...')
    tic = time.time()
    t = tic  # time of the last progress message
    DR = Pt(Residual)
    done = False
    iteration = 1

    while not done:
        # Select the atom most correlated with the residual, excluding the
        # atoms already in the support.
        DR[IN] = 0
        index = int(np.abs(DR).argmax())
        IN.append(index)

        mask = np.zeros(m)
        mask[index] = 1
        q = _qr_append(Q, R, k, P(mask))
        z.append(np.vdot(q, x))

        # New residual
        Residual = Residual - q * z[k]
        DR = Pt(Residual)

        ERR = np.vdot(Residual, Residual) / n
        if comp_err:
            err_mse[iteration - 1] = ERR
        if comp_time:
            iter_time[iteration - 1] = time.time() - tic

        # ------------------------------------------------------------------
        # Are we done yet?  Progress is reported at most every 10 seconds.
        # ------------------------------------------------------------------
        report = verbose and time.time() - t > 10.0
        message = None
        if STOPCRIT == 'M':
            if iteration >= STOPTOL:
                done = True
            elif report:
                message = 'Iteration %d. --- %d iterations to go' % (iteration, STOPTOL - iteration)
        elif STOPCRIT == 'mse':
            if ERR < STOPTOL:
                done = True
            elif report:
                message = 'Iteration %d. --- %g mse' % (iteration, ERR)
        elif STOPCRIT == 'mse_change' and iteration >= 2:
            # oldERR is the mse of the previous iteration.
            change = (oldERR - ERR) / sigsize
            if change < STOPTOL:
                done = True
            elif report:
                message = 'Iteration %d. --- %g mse change' % (iteration, change)
        elif STOPCRIT == 'corr':
            corr = np.abs(DR).max()
            if corr < STOPTOL:
                done = True
            elif report:
                message = 'Iteration %d. --- %g corr' % (iteration, corr)
        if message is not None:
            print(message)
            t = time.time()

        # Also stop if the residual vanishes or maxIter is reached.
        if ERR < 1e-14:
            done = True
            if verbose:
                print('Stopping. Exact signal representation found!')

        if iteration >= MAXITER:
            done = True
            if verbose:
                print('Stopping. Maximum number of iterations reached!')

        if not done:
            iteration += 1
            oldERR = ERR

        k += 1

    # ----------------------------------------------------------------------
    # Solve R s_IN = z by back-substitution (R is upper triangular)
    # ----------------------------------------------------------------------
    s[IN] = scipy.linalg.solve_triangular(R[:k, :k], np.asarray(z))

    # Only return as many entries as iterations were performed.
    if comp_err:
        err_mse = err_mse[:iteration]
    if comp_time:
        iter_time = iter_time[:iteration]
    if verbose:
        print('Done')

    if nargout == 1:
        return s
    if nargout == 2:
        return s, err_mse
    return s, err_mse, iter_time

# Change history
#
# 8 of February: Algo no longer stops if the dictionary is not normalised.
# End of greed_omp_qr() function
#--------------------------------


def omp_qr(x, dict, D, natom, tolerance):
    """ Recover x using QR implementation of OMP

    Parameter
    ---------
    x: measurements, shape (msize,) or (msize, 1)
    dict: dictionary, shape (msize, dictsize).  Columns are assumed to have
          unit norm: R[0,0] = 1 and R[k,k] = sqrt(1 - w.w) rely on it.
    D: Gramian of dictionary (dict.T dict)
    natom: maximum number of iterations (selected atoms)
    tolerance: relative error tolerance; iterations stop once the residual
               energy is <= tolerance * ||x||^2

    Return
    ------
    x_hat : estimate of the coefficients, shape (dictsize, 1)
    gamma : indices where non-zero, in order of selection

    For more information, see http://media.aau.dk/null_space_pursuits/2011/10/efficient-omp.html
    """
    msize, dictsize = dict.shape
    # Work on a flat float vector so that 1-D and column inputs behave the
    # same, and so the in-place updates below never hit integer casting.
    x = np.asarray(x, dtype=float).reshape(-1)
    normr2 = np.vdot(x, x)
    normtol2 = tolerance * normr2
    R = np.zeros((natom, natom))
    Q = np.zeros((msize, natom))
    gamma = []

    # find initial projections <d_j, x>
    origprojections = np.dot(dict.T, x)
    projections = origprojections.copy()

    k = 0
    while (normr2 > normtol2) and (k < natom):
        # find index of maximum magnitude projection
        newgam = np.argmax(np.abs(projections))
        gamma.append(newgam)
        # update QR factorization
        if k == 0:
            R[0, 0] = 1
            Q[:, 0] = dict[:, newgam]
        else:
            # w = Q^T d_new, obtained from the Gramian: R^T w = D[gamma, new]
            w = scipy.linalg.solve_triangular(R[:k, :k], D[gamma[:k], newgam], trans=1)
            R[:k, k] = w
            R[k, k] = np.sqrt(1 - np.vdot(w, w))
            Q[:, k] = (dict[:, newgam] - np.dot(Q[:, :k], w)) / R[k, k]
        # Rank-one updates with the new orthonormal column q_k (equivalent to
        # applying q_k q_k^T, without forming the msize x msize matrix).
        qk = Q[:, k]
        coef = np.vdot(qk, x)
        # update projections
        projections -= coef * np.dot(qk, dict)
        # update residual energy
        normr2 -= coef * coef

        k += 1

    # No atom selected (zero signal or natom == 0): nothing to solve.
    if k == 0:
        return np.zeros((dictsize, 1)), gamma

    # build solution: R^T R c = dict_gamma^T x
    tempR = R[:k, :k]
    w = scipy.linalg.solve_triangular(tempR, origprojections[gamma], trans=1)
    x_hat = np.zeros((dictsize, 1))
    x_hat[gamma, 0] = scipy.linalg.solve_triangular(tempR, w)

    return x_hat, gamma