function Dictionary = SMALL_rlsdla(X, params)
%% Recursive Least Squares Dictionary Learning Algorithm
%
%   D = SMALL_rlsdla(X, params) - runs the RLS-DLA algorithm for
%   training signals specified as columns of matrix X with parameters
%   specified in the params structure, returning the learned dictionary D.
%
%   Fields in params structure:
%       Required:
%               'Tdata' / 'Edata'        sparse-coding target
%               'initdict' / 'dictsize'  initial dictionary / dictionary size
%                                        ('dictsize' takes precedence over
%                                        the size implied by 'initdict')
%
%       Optional (default values in parentheses):
%               'codemode'               'sparsity' or 'error'. When absent,
%                                        the mode is inferred from the target
%                                        field present: 'Tdata' selects
%                                        sparsity mode, otherwise 'Edata'
%                                        selects error mode.
%               'maxatoms'               max # of atoms in error sparse-coding
%                                        (none, passed to omp2 as -1)
%               'forgettingMode'         'fix' - fixed forgetting factor;
%                                        other modes (exponential etc.) are
%                                        not implemented in this version
%               'forgettingFactor'       forgetting factor (default is 1)
%               'show_dict'              shows the dictionary every specified
%                                        number of iterations (less than 100
%                                        can make it run slowly). In this
%                                        version it assumes an image
%                                        dictionary with atoms of size 8x8.
%
%   Errors are raised when the coding mode is invalid, the sparse-coding
%   target is missing, neither 'initdict' nor 'dictsize' is given, there
%   are fewer training signals than atoms, the initial dictionary has the
%   wrong shape, or an unsupported forgetting mode is requested.
%
%   - RLS-DLA - Skretting, K.; Engan, K.; , "Recursive Least Squares
%       Dictionary Learning Algorithm," Signal Processing, IEEE Transactions on,
%       vol.58, no.4, pp.2121-2130, April 2010
%

%   Centre for Digital Music, Queen Mary, University of London.
%   This file copyright 2011 Ivan Damnjanovic.
%
%   This program is free software; you can redistribute it and/or
%   modify it under the terms of the GNU General Public License as
%   published by the Free Software Foundation; either version 2 of the
%   License, or (at your option) any later version.  See the file
%   COPYING included with this distribution for more information.
%
%%

CODE_SPARSITY = 1;
CODE_ERROR = 2;


% Determine which method will be used for the sparse representation step -
% Sparsity or Error mode

if (isfield(params,'codemode'))
    switch lower(params.codemode)
        case 'sparsity'
            codemode = CODE_SPARSITY;
            thresh = params.Tdata;
        case 'error'
            codemode = CODE_ERROR;
            thresh = params.Edata;
        otherwise
            error('Invalid coding mode specified');
    end
elseif (isfield(params,'Tdata'))
    codemode = CODE_SPARSITY;
    thresh = params.Tdata;
elseif (isfield(params,'Edata'))
    codemode = CODE_ERROR;
    thresh = params.Edata;
else
    error('Data sparse-coding target not specified');
end


% max number of atoms (only meaningful in error mode; -1 means "no limit"
% as far as this file is concerned - it is passed straight to omp2) %

if (codemode==CODE_ERROR && isfield(params,'maxatoms'))
    maxatoms = params.maxatoms;
else
    maxatoms = -1;
end


% determine dictionary size %

if (isfield(params,'dictsize'))
    % an explicit size supersedes the size determined by initdict
    dictsize = params.dictsize;
elseif (isfield(params,'initdict'))
    if (any(size(params.initdict)==1) && all(iswhole(params.initdict(:))))
        % initdict is a vector of whole numbers: treated as column indices into X
        dictsize = length(params.initdict);
    else
        % initdict is an explicit dictionary matrix: one atom per column
        dictsize = size(params.initdict,2);
    end
else
    % Without this check dictsize would be undefined below and the failure
    % would surface as an obscure "undefined variable" error.
    error('Initial dictionary or dictionary size not specified');
end

if (size(X,2) < dictsize)
    error('Number of training signals is smaller than number of atoms to train');
end


% initialize the dictionary %

if (isfield(params,'initdict'))
    if (any(size(params.initdict)==1) && all(iswhole(params.initdict(:))))
        D = X(:,params.initdict(1:dictsize));
    else
        % NOTE(review): this validation branch was garbled in the source this
        % file was recovered from; it is reconstructed following the KSVDbox
        % convention - confirm against the upstream repository.
        if (size(params.initdict,1)~=size(X,1) || size(params.initdict,2)<dictsize)
            error('Invalid initial dictionary');
        end
        D = params.initdict(:,1:dictsize);
    end
else
    % NOTE(review): the original norm expression was lost; squared column
    % norms are used here as a reconstruction - confirm against upstream.
    data_ids = find(sum(X.^2,1) > 1e-6);   % ensure no zero data elements are chosen
    if (length(data_ids) < dictsize)
        error('Number of non-zero training signals is smaller than number of atoms to train');
    end
    perm = randperm(length(data_ids));
    D = X(:,data_ids(perm(1:dictsize)));
end


% normalize the dictionary %

D = normcols(D);


% show dictionary every specified number of iterations

show_dictionary = isfield(params,'show_dict');
if (show_dictionary)
    show_iter = params.show_dict;
    displayDictionary(D);   % show the initial dictionary before training
else
    show_iter = 0;
end


% Forgetting factor: only the 'fix' mode is supported; lambda defaults to 1

if (isfield(params,'forgettingMode'))
    switch lower(params.forgettingMode)
        case 'fix'
            % nothing extra to do - lambda is read below
        otherwise
            error('This mode is still not implemented');
    end
end

if (isfield(params,'forgettingFactor'))
    lambda = params.forgettingFactor;
else
    lambda = 1;
end


% C is the square (dictsize x dictsize) matrix updated recursively alongside
% D; it starts as a large multiple of the identity scaled by the coding target.

C = (100000*thresh)*eye(dictsize);


% Single pass over the training signals, one recursive update per signal

for i = 1:size(X,2)

    x = X(:,i);

    % NOTE(review): sparsity mode calls omp2 (the same routine used for the
    % error-constrained case) with Tdata as its threshold argument - confirm
    % this is intended rather than a sparsity-constrained OMP call.
    if (codemode == CODE_SPARSITY)
        w = omp2(D,x,[],thresh,'checkdict','off');
    else
        w = omp2(D,x,[],thresh,'maxatoms',maxatoms, 'checkdict','off');
    end

    spind = find(w);            % indices of the non-zero coefficients

    residual = x - D * w;       % representation error for this signal

    if (lambda~=1)
        C = C * (1/lambda);     % apply the forgetting factor
    end

    % Only the columns of C matching non-zero coefficients contribute
    u = C(:,spind) * w(spind);

    alfa = 1/(1 + w' * u);

    D = D + (alfa * residual) * u';     % rank-one dictionary update

    C = C - (alfa * u) * u';            % rank-one update of C

    % mod(i,0) equals i in MATLAB, so show_iter == 0 never triggers a redraw
    if (show_dictionary && (mod(i,show_iter)==0))
        displayDictionary(D);
        pause(0.02);
    end
end

Dictionary = D;

end


function displayDictionary(D)
% Render dictionary D in figure 2 as a grid of 8x8 atoms, upscaled by 2.
gridSide = round(sqrt(size(D,2)));
dictimg = showdict(D,[8 8],gridSide,gridSide,'lines','highcontrast');
figure(2); imshow(imresize(dictimg,2,'nearest'));
end