annotate DL/RLS-DLA/SMALL_rlsdla.m @ 226:2124e4884765 danieleb

Update to reflect default branch.
author luisf <luis.figueira@eecs.qmul.ac.uk>
date Thu, 12 Apr 2012 13:59:59 +0100
parents 4337e28183f1
children
rev   line source
function [D] = SMALL_rlsdla(X, params)
%% Recursive Least Squares Dictionary Learning Algorithm
%
%   D = SMALL_rlsdla(X, params) - runs the RLS-DLA algorithm on the
%   training signals given as the columns of matrix X, with parameters
%   specified in the params structure, and returns the learned dictionary D.
%
%   Fields in params structure:
%     Required:
%       'Tdata' / 'Edata'          sparse-coding target
%       'initdict' / 'dictsize'    initial dictionary / dictionary size
%
%     Optional (default values in parentheses):
%       'codemode'          'sparsity' or 'error' ('sparsity')
%       'maxatoms'          max # of atoms in error sparse-coding (none)
%       'forgettingMode'    'fix' - fixed forgetting factor; other modes
%                           (exponential etc.) are not implemented in
%                           this version
%       'forgettingFactor'  forgetting factor for 'fix' mode (1)
%       'show_dict'         show the dictionary every # of iterations
%                           specified (values less than 100 can make it
%                           run slowly). This version assumes an image
%                           dictionary with 8x8 atoms.
%
%   - RLS-DLA - Skretting, K.; Engan, K.; , "Recursive Least Squares
%   Dictionary Learning Algorithm," Signal Processing, IEEE Transactions on,
%   vol.58, no.4, pp.2121-2130, April 2010

%
%   Centre for Digital Music, Queen Mary, University of London.
%   This file copyright 2011 Ivan Damnjanovic.
%
%   This program is free software; you can redistribute it and/or
%   modify it under the terms of the GNU General Public License as
%   published by the Free Software Foundation; either version 2 of the
%   License, or (at your option) any later version. See the file
%   COPYING included with this distribution for more information.
%
%%

CODE_SPARSITY = 1;
CODE_ERROR = 2;

% Determine which method will be used for the sparse representation step:
% Sparsity mode (at most Tdata atoms) or Error mode (residual <= Edata).
if (isfield(params,'codemode'))
  switch lower(params.codemode)
    case 'sparsity'
      codemode = CODE_SPARSITY;
      thresh = params.Tdata;
    case 'error'
      codemode = CODE_ERROR;
      thresh = params.Edata;
    otherwise
      error('Invalid coding mode specified');
  end
elseif (isfield(params,'Tdata'))
  codemode = CODE_SPARSITY;
  thresh = params.Tdata;
elseif (isfield(params,'Edata'))
  codemode = CODE_ERROR;
  thresh = params.Edata;
else
  error('Data sparse-coding target not specified');
end

% Maximum number of atoms per signal, used in error mode only.
% -1 is passed when no limit was requested (ompbox's own omp2 default;
% NOTE(review): confirm against the installed ompbox version).
if (codemode==CODE_ERROR && isfield(params,'maxatoms'))
  maxatoms = params.maxatoms;
else
  maxatoms = -1;
end

% Determine the dictionary size. A vector of whole numbers in 'initdict'
% is treated as a list of column indices into X.
dictsize = [];
if (isfield(params,'initdict'))
  if (any(size(params.initdict)==1) && all(iswhole(params.initdict(:))))
    dictsize = length(params.initdict);
  else
    dictsize = size(params.initdict,2);
  end
end
if (isfield(params,'dictsize'))   % this supersedes the size determined by initdict
  dictsize = params.dictsize;
end
if (isempty(dictsize))
  error('Dictionary size not specified: provide ''initdict'' or ''dictsize''');
end

if (size(X,2) < dictsize)
  error('Number of training signals is smaller than number of atoms to train');
end

% Initialize the dictionary: from the given indices or matrix, or from
% randomly chosen non-zero training signals.
if (isfield(params,'initdict'))
  if (any(size(params.initdict)==1) && all(iswhole(params.initdict(:))))
    D = X(:,params.initdict(1:dictsize));
  else
    if (size(params.initdict,1)~=size(X,1) || size(params.initdict,2)<dictsize)
      error('Invalid initial dictionary');
    end
    D = params.initdict(:,1:dictsize);
  end
else
  data_ids = find(colnorms_squared(X) > 1e-6);   % ensure no zero data elements are chosen
  if (length(data_ids) < dictsize)
    error('Number of non-zero training signals is smaller than number of atoms to train');
  end
  perm = randperm(length(data_ids));
  D = X(:,data_ids(perm(1:dictsize)));
end

% Normalize the dictionary atoms.
D = normcols(D);

% Show the dictionary every specified number of iterations.
show_dictionary = isfield(params,'show_dict');
if (show_dictionary)
  show_iter = params.show_dict;
  local_show_dictionary(D);
else
  show_iter = 0;
end

% Forgetting factor lambda (1 means no forgetting).
if (isfield(params,'forgettingMode'))
  switch lower(params.forgettingMode)
    case 'fix'
      if (isfield(params,'forgettingFactor'))
        lambda = params.forgettingFactor;
      else
        lambda = 1;
      end
    otherwise
      error('This mode is still not implemented');
  end
elseif (isfield(params,'forgettingFactor'))
  lambda = params.forgettingFactor;
else
  lambda = 1;
end

cnt = size(X,2);

% C plays the role of the inverse of the accumulated coefficient
% covariance in RLS-DLA. It starts as a large multiple of the identity.
C = (100000*thresh)*eye(dictsize);

% Process the training signals one at a time (online RLS update).
for i = 1:cnt
  x = X(:,i);

  % NOTE(review): 'checkdict' is 'off', but the RLS update below does not
  % keep the atoms at unit norm. Confirm that this is acceptable for ompbox.
  if (codemode == CODE_SPARSITY)
    % Sparsity-constrained coding: at most thresh (= Tdata) atoms.
    w = omp(D,x,[],thresh,'checkdict','off');
  else
    % Error-constrained coding: residual norm at most thresh (= Edata).
    w = omp2(D,x,[],thresh,'maxatoms',maxatoms,'checkdict','off');
  end

  spind = find(w);

  % Residual of x with respect to the current (pre-update) dictionary.
  residual = x - D*w;

  if (lambda ~= 1)
    C = C*(1/lambda);   % apply forgetting: C <- C / lambda
  end

  % u = C*w, restricted to the non-zero coefficients of w.
  u = C(:,spind)*w(spind);

  alfa = 1/(1 + w'*u);

  % Rank-one updates of the dictionary and of C.
  D = D + (alfa*residual)*u';
  C = C - (alfa*u)*u';

  if (show_dictionary && (mod(i,show_iter)==0))
    local_show_dictionary(D);
    pause(0.02);
  end
end

end


function local_show_dictionary(D)
% Display the dictionary atoms as 8x8 image patches in figure 2.
n = round(sqrt(size(D,2)));
dictimg = SMALL_showdict(D,[8 8],n,n,'lines','highcontrast');
figure(2); imagesc(dictimg); colormap(gray); axis off; axis image;
end