annotate DL/RLS-DLA/SMALL_rlsdla.m @ 160:e3035d45d014 danieleb

Added support classes
author Daniele Barchiesi <daniele.barchiesi@eecs.qmul.ac.uk>
date Wed, 31 Aug 2011 10:53:10 +0100
parents 8e660fd14774
children 4337e28183f1
rev   line source
function [D] = SMALL_rlsdla(X, params)
%% Recursive Least Squares Dictionary Learning Algorithm
%
% D = SMALL_rlsdla(X, params) - runs the RLS-DLA algorithm on the
% training signals given as the columns of matrix X, with the parameters
% given in the params structure, and returns the learned dictionary D.
% Each training signal is presented once, in column order.
%
% Fields in params structure:
%   Required:
%     'Tdata' / 'Edata'          sparse-coding target
%     'initdict' / 'dictsize'    initial dictionary / dictionary size
%
%   Optional (default values in parentheses):
%     'codemode'          'sparsity' or 'error'. When omitted, sparsity
%                         mode is used if 'Tdata' is present, otherwise
%                         error mode if 'Edata' is present.
%     'maxatoms'          max # of atoms in error sparse-coding (none)
%     'forgettingMode'    'fix' - fixed forgetting factor; other modes
%                         (exponential etc.) are not implemented in
%                         this version
%     'forgettingFactor'  forgetting factor for 'fix' mode (1)
%     'show_dict'         display the dictionary every show_dict
%                         iterations (values less than 100 can make the
%                         run slow). This version assumes an image
%                         dictionary with 8x8 atoms.
%
% 'initdict' is either a matrix whose first dictsize columns form the
% initial dictionary, or a vector of whole numbers that is treated as a
% list of column indices into X. If 'initdict' is omitted, dictsize
% randomly chosen columns of X with squared norm above 1e-6 are used.
% 'dictsize', when given, supersedes the size implied by 'initdict'.
%
% - RLS-DLA - Skretting, K.; Engan, K.; "Recursive Least Squares
% Dictionary Learning Algorithm," Signal Processing, IEEE Transactions
% on, vol.58, no.4, pp.2121-2130, April 2010
%


% Centre for Digital Music, Queen Mary, University of London.
% This file copyright 2011 Ivan Damnjanovic.
%
% This program is free software; you can redistribute it and/or
% modify it under the terms of the GNU General Public License as
% published by the Free Software Foundation; either version 2 of the
% License, or (at your option) any later version. See the file
% COPYING included with this distribution for more information.
%
%%

CODE_SPARSITY = 1;
CODE_ERROR = 2;


% Determine which method will be used for the sparse representation
% step - sparsity or error mode - and the corresponding target.

if (isfield(params,'codemode'))
    switch lower(params.codemode)
        case 'sparsity'
            codemode = CODE_SPARSITY;
            thresh = params.Tdata;
        case 'error'
            codemode = CODE_ERROR;
            thresh = params.Edata;
        otherwise
            error('Invalid coding mode specified');
    end
elseif (isfield(params,'Tdata'))
    codemode = CODE_SPARSITY;
    thresh = params.Tdata;
elseif (isfield(params,'Edata'))
    codemode = CODE_ERROR;
    thresh = params.Edata;
else
    error('Data sparse-coding target not specified');
end


% max number of atoms - only used in error mode; -1 is passed to omp2
% when no limit was requested (NOTE(review): assumed to be omp2's
% "no limit" convention - confirm against omp2 documentation)

if (codemode==CODE_ERROR && isfield(params,'maxatoms'))
    maxatoms = params.maxatoms;
else
    maxatoms = -1;
end


% determine dictionary size %

% An initdict that is a vector of whole numbers is interpreted as column
% indices into X rather than as an explicit dictionary.
init_is_index = isfield(params,'initdict') && ...
    any(size(params.initdict)==1) && all(iswhole(params.initdict(:)));

dictsize = [];
if (isfield(params,'initdict'))
    if (init_is_index)
        dictsize = length(params.initdict);
    else
        dictsize = size(params.initdict,2);
    end
end
if (isfield(params,'dictsize'))    % this supersedes the size determined by initdict
    dictsize = params.dictsize;
end

% without either field the size is unknown; fail with a clear message
if (isempty(dictsize))
    error('Dictionary size not specified');
end

if (size(X,2) < dictsize)
    error('Number of training signals is smaller than number of atoms to train');
end


% initialize the dictionary %

if (isfield(params,'initdict'))
    if (init_is_index)
        % an explicit, larger dictsize cannot be served by a short index list
        if (length(params.initdict) < dictsize)
            error('Invalid initial dictionary');
        end
        D = X(:,params.initdict(1:dictsize));
    else
        if (size(params.initdict,1)~=size(X,1) || size(params.initdict,2)<dictsize)
            error('Invalid initial dictionary');
        end
        D = params.initdict(:,1:dictsize);
    end
else
    data_ids = find(colnorms_squared(X) > 1e-6);   % ensure no zero data elements are chosen
    if (length(data_ids) < dictsize)
        error('Not enough non-zero training signals to initialize the dictionary');
    end
    perm = randperm(length(data_ids));
    D = X(:,data_ids(perm(1:dictsize)));
end


% normalize the dictionary %

D = normcols(D);


% show dictionary every specified number of iterations

show_dictionary = isfield(params,'show_dict');
if (show_dictionary)
    show_iter = params.show_dict;
    display_dictionary(D);
else
    show_iter = 0;
end


% Forgetting factor - only the fixed ('fix') mode is implemented; the
% factor defaults to 1 (no forgetting) whether or not a mode is given.

if (isfield(params,'forgettingMode') && ~strcmpi(params.forgettingMode,'fix'))
    error('This mode is still not implemented');
end
if (isfield(params,'forgettingFactor'))
    lambda = params.forgettingFactor;
else
    lambda = 1;
end


% C is the matrix propagated by the RLS update (see the cited paper).
% Its large initial value (scaled by the coding target, as in the
% original implementation) lets the first signals dominate the updates.

C = (100000*thresh)*eye(dictsize);

for i = 1:size(X,2)
    x = X(:,i);

    % sparse-code the current training signal with the current dictionary
    if (codemode == CODE_SPARSITY)
        w = omp2(D,x,[],thresh,'checkdict','off');
    else
        w = omp2(D,x,[],thresh,'maxatoms',maxatoms,'checkdict','off');
    end

    residual = x - D*w;

    % apply the forgetting factor: C <- C / lambda
    if (lambda ~= 1)
        C = C / lambda;
    end

    % u = C*w, evaluated only over the non-zero entries of w
    spind = find(w);
    u = C(:,spind) * w(spind);

    alfa = 1/(1 + w'*u);

    % rank-one updates of the dictionary and of C
    D = D + (alfa*residual)*u';
    C = C - (alfa*u)*u';

    if (show_dictionary && mod(i,show_iter)==0)
        display_dictionary(D);
        pause(0.02);
    end
end

end


function display_dictionary(D)
% Display D in figure 2 as a grid of 8x8 image atoms, using
% round(sqrt(#atoms)) rows and columns.
ngrid = round(sqrt(size(D,2)));
dictimg = SMALL_showdict(D,[8 8],ngrid,ngrid,'lines','highcontrast');
figure(2); imagesc(dictimg); colormap(gray); axis off; axis image;
end