function [D] = SMALL_rlsdla(X, params)
%% Recursive Least Squares Dictionary Learning Algorithm
%
%   D = SMALL_rlsdla(X, params) - runs the RLS-DLA algorithm on the
%   training signals given as columns of matrix X, with parameters
%   specified in the params structure, and returns the learned
%   dictionary D (columns normalised at initialisation).
%
%   Fields in params structure:
%     Required:
%       'Tdata' / 'Edata'          sparse-coding target: max number of
%                                  atoms per signal ('sparsity' mode, OMP)
%                                  or target residual norm ('error' mode,
%                                  OMP2)
%       'initdict' / 'dictsize'    initial dictionary (matrix, or vector of
%                                  column indices into X) / dictionary size
%
%     Optional (default values in parentheses):
%       'codemode'                 'sparsity' or 'error' ('sparsity')
%       'maxatoms'                 max # of atoms in error sparse-coding (none)
%       'forgettingMode'           'fix' - fixed forgetting factor;
%                                  other modes (exponential etc.) are not
%                                  implemented in this version
%       'forgettingFactor'         forgetting factor for 'fix' mode (1)
%       'show_dict'                display the dictionary every show_dict
%                                  training signals (values less than 100
%                                  can make it run slowly). The dictionary
%                                  is assumed to be an image dictionary;
%                                  atoms are shown as sqrt(n) x sqrt(n)
%                                  patches when the signal dimension n is a
%                                  perfect square, otherwise as 8x8.
%
%   - RLS-DLA - Skretting, K.; Engan, K.; "Recursive Least Squares
%     Dictionary Learning Algorithm," Signal Processing, IEEE Transactions on,
%     vol.58, no.4, pp.2121-2130, April 2010
%
%   Centre for Digital Music, Queen Mary, University of London.
%   This file copyright 2011 Ivan Damnjanovic.
%
%   This program is free software; you can redistribute it and/or
%   modify it under the terms of the GNU General Public License as
%   published by the Free Software Foundation; either version 2 of the
%   License, or (at your option) any later version. See the file
%   COPYING included with this distribution for more information.
%
%%

CODE_SPARSITY = 1;
CODE_ERROR = 2;

% Determine which method will be used for the sparse representation step -
% Sparsity or Error mode

if (isfield(params,'codemode'))
    switch lower(params.codemode)
        case 'sparsity'
            codemode = CODE_SPARSITY;
            thresh = params.Tdata;
        case 'error'
            codemode = CODE_ERROR;
            thresh = params.Edata;
        otherwise
            error('Invalid coding mode specified');
    end
elseif (isfield(params,'Tdata'))
    codemode = CODE_SPARSITY;
    thresh = params.Tdata;
elseif (isfield(params,'Edata'))
    codemode = CODE_ERROR;
    thresh = params.Edata;
else
    error('Data sparse-coding target not specified');
end

% max number of atoms (only used in error mode; -1 when not given,
% presumably interpreted by omp2 as "no limit")

if (codemode==CODE_ERROR && isfield(params,'maxatoms'))
    maxatoms = params.maxatoms;
else
    maxatoms = -1;
end

% determine dictionary size %

dictsize = [];
if (isfield(params,'initdict'))
    if (any(size(params.initdict)==1) && all(iswhole(params.initdict(:))))
        % initdict is a vector of column indices into X
        dictsize = length(params.initdict);
    else
        dictsize = size(params.initdict,2);
    end
end
if (isfield(params,'dictsize'))    % this supersedes the size determined by initdict
    dictsize = params.dictsize;
end
if (isempty(dictsize))
    error('Dictionary size not specified');
end

if (size(X,2) < dictsize)
    error('Number of training signals is smaller than number of atoms to train');
end

% initialize the dictionary %

if (isfield(params,'initdict'))
    if (any(size(params.initdict)==1) && all(iswhole(params.initdict(:))))
        if (length(params.initdict) < dictsize)
            error('Invalid initial dictionary');
        end
        D = X(:,params.initdict(1:dictsize));
    else
        if (size(params.initdict,1)~=size(X,1) || size(params.initdict,2)<dictsize)
            error('Invalid initial dictionary');
        end
        D = params.initdict(:,1:dictsize);
    end
else
    data_ids = find(colnorms_squared(X) > 1e-6);   % ensure no zero data elements are chosen
    if (length(data_ids) < dictsize)
        error('Number of non-zero training signals is smaller than number of atoms to train');
    end
    perm = randperm(length(data_ids));
    D = X(:,data_ids(perm(1:dictsize)));
end

% normalize the dictionary %

D = normcols(D);

% show dictionary every specified number of iterations

if (isfield(params,'show_dict'))
    show_dictionary = 1;
    show_iter = params.show_dict;
else
    show_dictionary = 0;
    show_iter = 0;
end

if (show_dictionary)
    show_dictionary_image(D);
end

% Forgetting factor

if (isfield(params,'forgettingMode'))
    switch lower(params.forgettingMode)
        case 'fix'
            if (isfield(params,'forgettingFactor'))
                lambda = params.forgettingFactor;
            else
                lambda = 1;
            end
        otherwise
            error('This mode is still not implemented');
    end
elseif (isfield(params,'forgettingFactor'))
    lambda = params.forgettingFactor;
else
    lambda = 1;
end

% RLS state: C is maintained by rank-one (Sherman-Morrison style) updates
% below, so it tracks the inverse of the forgetting-weighted sum of w*w'.
% It starts as a large multiple of the identity (scale kept from the
% original implementation).

cnt = size(X,2);
C = (100000*thresh)*eye(dictsize);

for i = 1:cnt
    x = X(:,i);

    % sparse-code the current training signal with the current dictionary:
    % omp is sparsity-constrained (thresh = Tdata atoms), omp2 is
    % error-constrained (thresh = Edata residual norm)
    if (codemode == CODE_SPARSITY)
        w = omp(D,x,[],thresh,'checkdict','off');
    else
        w = omp2(D,x,[],thresh,'maxatoms',maxatoms,'checkdict','off');
    end

    spind = find(w);
    residual = x - D * w;

    % apply the forgetting factor: C <- C / lambda
    if (lambda ~= 1)
        C = C * (1/lambda);
    end

    % u = C*w, computed only over the non-zero coefficients of w
    u = C(:,spind) * w(spind);
    alfa = 1/(1 + w' * u);

    % rank-one dictionary update, then rank-one downdate of C
    D = D + (alfa * residual) * u';
    C = C - (alfa * u) * u';

    if (show_dictionary && (mod(i,show_iter)==0))
        show_dictionary_image(D);
        pause(0.02);
    end
end

end


function show_dictionary_image(D)
% Display the dictionary atoms as image patches in figure 2.
% Atoms are shown as p x p patches when the signal dimension is a perfect
% square p^2; otherwise the original fixed 8x8 assumption is used.
n = size(D,1);
p = round(sqrt(n));
if (p*p == n)
    atomsize = [p p];
else
    atomsize = [8 8];
end
ntiles = round(sqrt(size(D,2)));
dictimg = SMALL_showdict(D,atomsize,ntiles,ntiles,'lines','highcontrast');
figure(2); imagesc(dictimg); colormap(gray); axis off; axis image;
end