Mercurial > hg > smallbox
view DL/RLS-DLA/SMALL_rlsdla.m @ 137:9207d56c5547 ivand_dev
New ompbox in utils for testing purposes
author | Ivan Damnjanovic lnx <ivan.damnjanovic@eecs.qmul.ac.uk> |
---|---|
date | Thu, 21 Jul 2011 14:07:41 +0100 |
parents | 8e660fd14774 |
children | 4337e28183f1 |
line wrap: on
line source
function [D] = SMALL_rlsdla(X, params)
%% Recursive Least Squares Dictionary Learning Algorithm
%
%   D = SMALL_rlsdla(X, params) - runs RLS-DLA algorithm for
%   training signals specified as columns of matrix X with parameters
%   specified in params structure returning the learned dictionary D.
%
%   Fields in params structure:
%     Required:
%       'Tdata' / 'Edata'          sparse-coding target
%       'initdict' / 'dictsize'    initial dictionary / dictionary size
%
%     Optional (default values in parentheses):
%       'codemode'            'sparsity' or 'error' ('sparsity')
%       'maxatoms'            max # of atoms in error sparse-coding (none)
%       'forgettingMode'      'fix' - fixed forgetting factor;
%                             other modes (exponential etc.) are not
%                             implemented in this version
%       'forgettingFactor'    for 'fix' mode (default is 1)
%       'show_dict'           show the dictionary every # of iterations
%                             specified (less than 100 can make it run
%                             slowly). In this version it assumes that it
%                             is an image dictionary with 8x8 atoms
%
%   Sparse coding: in 'sparsity' mode each training signal is coded with
%   sparsity-constrained OMP (omp, target Tdata atoms); in 'error' mode
%   with error-constrained OMP (omp2, target error Edata, optionally
%   limited to 'maxatoms' atoms).
%
%   Errors are raised when the coding target or the dictionary size is
%   missing, when there are fewer training signals (or fewer non-zero
%   training signals, for random initialization) than atoms, or when an
%   unknown codemode / forgettingMode is requested.
%
%   - RLS-DLA - Skretting, K.; Engan, K.; , "Recursive Least Squares
%   Dictionary Learning Algorithm," Signal Processing, IEEE Transactions on,
%   vol.58, no.4, pp.2121-2130, April 2010
%
%   Centre for Digital Music, Queen Mary, University of London.
%   This file copyright 2011 Ivan Damnjanovic.
%
%   This program is free software; you can redistribute it and/or
%   modify it under the terms of the GNU General Public License as
%   published by the Free Software Foundation; either version 2 of the
%   License, or (at your option) any later version.  See the file
%   COPYING included with this distribution for more information.
%
%%

CODE_SPARSITY = 1;
CODE_ERROR = 2;

% Determine which method will be used for the sparse representation step:
% Sparsity or Error mode
if (isfield(params,'codemode'))
  switch lower(params.codemode)
    case 'sparsity'
      codemode = CODE_SPARSITY;
      thresh = params.Tdata;
    case 'error'
      codemode = CODE_ERROR;
      thresh = params.Edata;
    otherwise
      error('Invalid coding mode specified');
  end
elseif (isfield(params,'Tdata'))
  codemode = CODE_SPARSITY;
  thresh = params.Tdata;
elseif (isfield(params,'Edata'))
  codemode = CODE_ERROR;
  thresh = params.Edata;
else
  error('Data sparse-coding target not specified');
end

% max number of atoms (only used in error mode; -1 is passed to omp2
% when no limit was requested)
if (codemode==CODE_ERROR && isfield(params,'maxatoms'))
  maxatoms = params.maxatoms;
else
  maxatoms = -1;
end

% determine dictionary size
dictsize = [];
if (isfield(params,'initdict'))
  % initdict is either a vector of column indices into X or a matrix
  if (any(size(params.initdict)==1) && all(iswhole(params.initdict(:))))
    dictsize = length(params.initdict);
  else
    dictsize = size(params.initdict,2);
  end
end
if (isfield(params,'dictsize'))    % this supersedes the size determined by initdict
  dictsize = params.dictsize;
end
if (isempty(dictsize))
  error('Dictionary size not specified (provide params.initdict or params.dictsize)');
end

if (size(X,2) < dictsize)
  error('Number of training signals is smaller than number of atoms to train');
end

% initialize the dictionary
if (isfield(params,'initdict'))
  if (any(size(params.initdict)==1) && all(iswhole(params.initdict(:))))
    D = X(:,params.initdict(1:dictsize));
  else
    if (size(params.initdict,1)~=size(X,1) || size(params.initdict,2)<dictsize)
      error('Invalid initial dictionary');
    end
    D = params.initdict(:,1:dictsize);
  end
else
  % pick random training signals, avoiding (near-)zero columns
  data_ids = find(colnorms_squared(X) > 1e-6);
  if (length(data_ids) < dictsize)
    error('Number of non-zero training signals is smaller than number of atoms to train');
  end
  perm = randperm(length(data_ids));
  D = X(:,data_ids(perm(1:dictsize)));
end

% normalize the dictionary
D = normcols(D);

% show dictionary every specified number of iterations
if (isfield(params,'show_dict'))
  show_dictionary = 1;
  show_iter = params.show_dict;
else
  show_dictionary = 0;
  show_iter = 0;
end

if (show_dictionary)
  display_dictionary(D);
end

% Forgetting factor
if (isfield(params,'forgettingMode'))
  switch lower(params.forgettingMode)
    case 'fix'
      if (isfield(params,'forgettingFactor'))
        lambda = params.forgettingFactor;
      else
        lambda = 1;
      end
    otherwise
      error('This mode is still not implemented');
  end
elseif (isfield(params,'forgettingFactor'))
  lambda = params.forgettingFactor;
else
  lambda = 1;
end

% Training data
cnt = size(X,2);

% C is initialized to a large multiple of the identity.
% NOTE(review): the scaling by the coding target (thresh) is kept from the
% original implementation; its rationale is not evident from this file.
C = (100000*thresh)*eye(dictsize);

for i = 1:cnt
  x = X(:,i);

  if (codemode == CODE_SPARSITY)
    % sparsity-constrained OMP: thresh is the target number of atoms
    w = omp(D, x, [], thresh, 'checkdict', 'off');
  else
    % error-constrained OMP: thresh is the target representation error
    w = omp2(D, x, [], thresh, 'maxatoms', maxatoms, 'checkdict', 'off');
  end

  spind = find(w);
  residual = x - D*w;

  if (lambda ~= 1)
    C = C * (1/lambda);
  end

  % RLS update: only the columns of C matching the used atoms contribute
  u = C(:,spind) * w(spind);
  alfa = 1/(1 + w'*u);
  D = D + (alfa*residual)*u';
  C = C - (alfa*u)*u';

  if (show_dictionary && (mod(i,show_iter)==0))
    display_dictionary(D);
    pause(0.02);
  end
end

end

function display_dictionary(D)
% Render the atoms (columns) of D as 8x8 image patches in figure 2.
% NOTE: assumes 64-dimensional atoms (8x8 patches), as documented above.
n = round(sqrt(size(D,2)));
dictimg = SMALL_showdict(D, [8 8], n, n, 'lines', 'highcontrast');
figure(2);
imagesc(dictimg); colormap(gray); axis off; axis image;
end