view DL/RLS-DLA/SMALL_rlsdla.m @ 86:f6cc633fd94b

cpu/tic-toc time comments
author Maria Jafari <maria.jafari@eecs.qmul.ac.uk>
date Mon, 11 Apr 2011 16:44:31 +0100
parents fd1c32cda22c
children 04cce72a4dc8
line wrap: on
line source
function Dictionary = SMALL_rlsdla(X, params)
%% Recursive Least Squares Dictionary Learning Algorithm
%
%   D = SMALL_rlsdla(X, params) - runs RLS-DLA algorithm for
%   training signals specified as columns of matrix X with parameters 
%   specified in params structure returning the learned dictionary D.
%   
%   Fields in params structure:
%       Required:
%               'Tdata' / 'Edata'        sparse-coding target
%               'initdict' / 'dictsize'  initial dictionary / dictionary size
%
%       Optional (default values in parentheses):
%               'codemode'               'sparsity' or 'error' ('sparsity')
%               'maxatoms'               max # of atoms in error sparse-coding (none)
%               'forgettingMode'         'fix' - fixed forgetting factor;
%                                        other modes (exponential etc.) are
%                                        not implemented in this version
%               'forgettingFactor'       for 'fix' mode (1); also honoured
%                                        when 'forgettingMode' is not given
%               'show_dict'              show the dictionary every show_dict
%                                        training signals (values below 100
%                                        can make it run slowly). This
%                                        version assumes an image dictionary
%                                        with 8x8 atoms.
%
%   If 'codemode' is not given, 'Tdata' takes precedence over 'Edata'.
%   In 'sparsity' mode each training signal is coded with sparsity-
%   constrained OMP (omp, using Tdata as the atom count); in 'error' mode
%   with error-constrained OMP (omp2, using Edata as the error target,
%   optionally capped at maxatoms atoms).
%
%   'initdict' may be either a matrix whose columns are the initial atoms,
%   or a vector of column indices into X. Without 'initdict', dictsize
%   randomly chosen non-zero columns of X are used.
%
%   -   RLS-DLA - Skretting, K.; Engan, K.; , "Recursive Least Squares
%       Dictionary Learning Algorithm," Signal Processing, IEEE Transactions on,
%       vol.58, no.4, pp.2121-2130, April 2010
%


%   Centre for Digital Music, Queen Mary, University of London.
%   This file copyright 2011 Ivan Damnjanovic.
%
%   This program is free software; you can redistribute it and/or
%   modify it under the terms of the GNU General Public License as
%   published by the Free Software Foundation; either version 2 of the
%   License, or (at your option) any later version.  See the file
%   COPYING included with this distribution for more information.
%   
%%

CODE_SPARSITY = 1;
CODE_ERROR = 2;


% Determine which method will be used for sparse representation step -
% Sparsity or Error mode

if (isfield(params,'codemode'))
    switch lower(params.codemode)
        case 'sparsity'
            codemode = CODE_SPARSITY;
            thresh = params.Tdata;
        case 'error'
            codemode = CODE_ERROR;
            thresh = params.Edata;
        otherwise
            error('Invalid coding mode specified');
    end
elseif (isfield(params,'Tdata'))
    codemode = CODE_SPARSITY;
    thresh = params.Tdata;
elseif (isfield(params,'Edata'))
    codemode = CODE_ERROR;
    thresh = params.Edata;
else
    error('Data sparse-coding target not specified');
end


% max number of atoms (only used in error mode; -1 is passed to omp2
% when no limit is requested)

if (codemode==CODE_ERROR && isfield(params,'maxatoms'))
  maxatoms = params.maxatoms;
else
  maxatoms = -1;
end


% determine dictionary size %

if (~isfield(params,'initdict') && ~isfield(params,'dictsize'))
  error('Initial dictionary or dictionary size not specified');
end

if (isfield(params,'initdict'))
  % a whole-valued vector is interpreted as column indices into X
  if (any(size(params.initdict)==1) && all(iswhole(params.initdict(:))))
    dictsize = length(params.initdict);
  else
    dictsize = size(params.initdict,2);
  end
end
if (isfield(params,'dictsize'))    % this supersedes the size determined by initdict
  dictsize = params.dictsize;
end

if (size(X,2) < dictsize)
  error('Number of training signals is smaller than number of atoms to train');
end


% initialize the dictionary %

if (isfield(params,'initdict'))
  if (any(size(params.initdict)==1) && all(iswhole(params.initdict(:))))
    if (length(params.initdict) < dictsize)
      error('Invalid initial dictionary');
    end
    D = X(:,params.initdict(1:dictsize));
  else
    if (size(params.initdict,1)~=size(X,1) || size(params.initdict,2)<dictsize)
      error('Invalid initial dictionary');
    end
    D = params.initdict(:,1:dictsize);
  end
else
  data_ids = find(colnorms_squared(X) > 1e-6);   % ensure no zero data elements are chosen
  if (length(data_ids) < dictsize)
    error('Number of non-zero training signals is smaller than number of atoms to train');
  end
  perm = randperm(length(data_ids));
  D = X(:,data_ids(perm(1:dictsize)));
end


% normalize the dictionary %

D = normcols(D);

% show dictionary every specified number of iterations

if (isfield(params,'show_dict'))
    show_dictionary = 1;
    show_iter = params.show_dict;
else
    show_dictionary = 0;
    show_iter = 0;
end

if (show_dictionary)
    show_rlsdla_dictionary(D);
end

% Forgetting factor

if (isfield(params,'forgettingMode'))
     switch lower(params.forgettingMode)
       case 'fix'
                if (isfield(params,'forgettingFactor'))
                    lambda=params.forgettingFactor;
                else
                    lambda=1;
                end
       otherwise
              error('This mode is still not implemented');
     end
elseif (isfield(params,'forgettingFactor'))
     lambda=params.forgettingFactor;
else
     lambda=1;
end

% Training data

data = X;
cnt = size(data,2);

% C is the matrix updated recursively by RLS-DLA (the paper's C). It starts
% as a large multiple of the identity.
% NOTE(review): the scaling by thresh is kept from the original
% implementation - confirm it is intended.
C = (100000*thresh)*eye(dictsize);


for i = 1:cnt

   x = data(:,i);

   % sparse-code the current training signal with the current dictionary
   if (codemode == CODE_SPARSITY)
        % sparsity-constrained OMP: thresh is the number of atoms (Tdata)
        w = omp(D,x,[],thresh,'checkdict','off');
   else
        % error-constrained OMP: thresh is the target error (Edata)
        w = omp2(D,x,[],thresh,'maxatoms',maxatoms,'checkdict','off');
   end

   spind = find(w);

   residual = x - D*w;

   % apply the forgetting factor to C
   if (lambda ~= 1)
       C = C*(1/lambda);
   end

   % equals C*w, using only the columns of C matching non-zero coefficients
   u = C(:,spind)*w(spind);

   alfa = 1/(1 + w'*u);

   % rank-one updates of the dictionary and of C
   D = D + (alfa*residual)*u';
   C = C - (alfa*u)*u';

   if (show_dictionary && (mod(i,show_iter)==0))
       show_rlsdla_dictionary(D);
       pause(0.02);
   end
end

Dictionary = D;

end


function show_rlsdla_dictionary(D)
% Display the dictionary atoms as 8x8 patches via showdict in figure 2,
% upscaled 2x with nearest-neighbour interpolation.
ncells = round(sqrt(size(D,2)));
dictimg = showdict(D,[8 8],ncells,ncells,'lines','highcontrast');
figure(2); imshow(imresize(dictimg,2,'nearest'));
end