function [D] = SMALL_rlsdla(X, params)
%% Recursive Least Squares Dictionary Learning Algorithm
%
%   D = SMALL_rlsdla(X, params) - runs the RLS-DLA algorithm on the
%   training signals given as columns of matrix X, with parameters
%   specified in the params structure, and returns the learned dictionary D.
%
%   Each training signal is visited once, in column order: it is
%   sparse-coded against the current dictionary with omp2, and the
%   dictionary is then corrected with a rank-one RLS update.
%
%   Fields in params structure:
%     Required:
%       'Tdata' / 'Edata'        sparse-coding target
%       'initdict' / 'dictsize'  initial dictionary / dictionary size
%
%     Optional (default values in parentheses):
%       'codemode'          'sparsity' or 'error' (if absent: 'sparsity'
%                           when 'Tdata' is given, otherwise 'error')
%       'maxatoms'          max # of atoms in error sparse-coding (none)
%       'forgettingMode'    'fix' - fixed forgetting factor; other modes
%                           (exponential etc.) are not implemented in
%                           this version
%       'forgettingFactor'  forgetting factor, must be > 0 (1)
%       'show_dict'         show the dictionary every # of iterations
%                           specified (values less than 100 can make it
%                           run slowly). This version assumes an image
%                           dictionary with 8x8 atoms.
%
%   'initdict' is either a dictionary matrix with size(X,1) rows or a
%   vector of column indices of X. When 'dictsize' is also given it takes
%   precedence and only the first 'dictsize' atoms/indices are used.
%   Without 'initdict', 'dictsize' randomly chosen non-zero columns of X
%   are used. The initial atoms are normalized with normcols; the
%   returned D is NOT renormalized after the RLS updates.
%
%   - RLS-DLA - Skretting, K.; Engan, K.; "Recursive Least Squares
%   Dictionary Learning Algorithm," IEEE Transactions on Signal
%   Processing, vol.58, no.4, pp.2121-2130, April 2010
%


%   Centre for Digital Music, Queen Mary, University of London.
%   This file copyright 2011 Ivan Damnjanovic.
%
%   This program is free software; you can redistribute it and/or
%   modify it under the terms of the GNU General Public License as
%   published by the Free Software Foundation; either version 2 of the
%   License, or (at your option) any later version.  See the file
%   COPYING included with this distribution for more information.
%
%%

% --- validate and read all parameters before doing any work ---

[errormode, thresh, maxatoms] = parse_coding_params(params);

dictsize = get_dictsize(params);
if (size(X,2) < dictsize)
  error('Number of training signals is smaller than number of atoms to train');
end

lambda = get_forgetting_factor(params);
invlambda = 1/lambda;   % hoisted: identical value to the per-iteration 1/lambda

% --- initial dictionary ---

D = normcols(init_dictionary(X, params, dictsize));

% show dictionary every specified number of iterations
show_dictionary = isfield(params,'show_dict');
if (show_dictionary)
  show_iter = params.show_dict;
  show_dictionary_image(D);
else
  show_iter = 0;
end

% --- RLS state ---
% C is kept as the inverse of a matrix that accumulates w*w' for every
% coded signal: the update at the end of each iteration is the
% Sherman-Morrison rank-one inverse update. A large initial C therefore
% corresponds to a small initial regularizer.
% NOTE(review): the scaling of the initial C by thresh (Tdata or Edata)
% is an inherited heuristic - confirm it is intended for both modes.
C = (100000*thresh)*eye(dictsize);

for k = 1:size(X,2)
  x = X(:,k);

  % sparse-code the current signal against the current dictionary
  if (errormode)
    w = omp2(D,x,[],thresh,'maxatoms',maxatoms,'checkdict','off');
  else
    w = omp2(D,x,[],thresh,'checkdict','off');
  end

  spind = find(w);             % only non-zero coefficients contribute to u
  residual = x - D * w;        % residual w.r.t. the dictionary BEFORE the update

  % fixed forgetting factor: scale C by 1/lambda before the update
  if (lambda ~= 1)
    C = C * invlambda;
  end

  u = C(:,spind) * w(spind);   % u = C*w, restricted to the support of w
  alfa = 1/(1 + w' * u);

  D = D + (alfa * residual) * u';
  C = C - (alfa * u) * u';

  if (show_dictionary && (mod(k,show_iter)==0))
    show_dictionary_image(D);
    pause(0.02);
  end
end

end


function [errormode, thresh, maxatoms] = parse_coding_params(params)
% Determine the sparse-coding mode (true = error mode, false = sparsity
% mode), its target value, and the maxatoms value passed to omp2.
if (isfield(params,'codemode'))
  switch lower(params.codemode)
    case 'sparsity'
      errormode = false;
    case 'error'
      errormode = true;
    otherwise
      error('Invalid coding mode specified');
  end
elseif (isfield(params,'Tdata'))
  errormode = false;
elseif (isfield(params,'Edata'))
  errormode = true;
else
  error('Data sparse-coding target not specified');
end

if (errormode)
  targetfield = 'Edata';
else
  targetfield = 'Tdata';
end
% an explicit 'codemode' still needs its matching target field
if (~isfield(params,targetfield))
  error('Data sparse-coding target not specified');
end
thresh = params.(targetfield);

% maxatoms is only used in error mode; -1 is passed to omp2 when unset
if (errormode && isfield(params,'maxatoms'))
  maxatoms = params.maxatoms;
else
  maxatoms = -1;
end
end


function dictsize = get_dictsize(params)
% Number of atoms to train: 'dictsize' if given (this supersedes the size
% determined by 'initdict'), otherwise inferred from 'initdict'.
if (isfield(params,'dictsize'))
  dictsize = params.dictsize;
elseif (isfield(params,'initdict'))
  if (is_index_vector(params.initdict))
    dictsize = length(params.initdict);
  else
    dictsize = size(params.initdict,2);
  end
else
  error('Dictionary size not specified');
end
end


function tf = is_index_vector(initdict)
% True when initdict is a vector of whole numbers, which is interpreted
% as a list of column indices of the training data.
tf = any(size(initdict)==1) && all(iswhole(initdict(:)));
end


function D = init_dictionary(X, params, dictsize)
% Build the initial (not yet normalized) dictionary with dictsize atoms.
if (isfield(params,'initdict'))
  if (is_index_vector(params.initdict))
    % initdict lists columns of X to use as the initial atoms
    if (length(params.initdict) < dictsize)
      error('Invalid initial dictionary');
    end
    D = X(:,params.initdict(1:dictsize));
  else
    if (size(params.initdict,1)~=size(X,1) || size(params.initdict,2)<dictsize)
      error('Invalid initial dictionary');
    end
    D = params.initdict(:,1:dictsize);
  end
else
  % random training signals; (near-)zero columns are never chosen
  data_ids = find(colnorms_squared(X) > 1e-6);
  if (length(data_ids) < dictsize)
    error('Number of non-zero training signals is smaller than number of atoms to train');
  end
  perm = randperm(length(data_ids));
  D = X(:,data_ids(perm(1:dictsize)));
end
end


function lambda = get_forgetting_factor(params)
% Read the forgetting factor. Only the 'fix' forgetting mode is
% implemented; the factor defaults to 1 (no forgetting).
if (isfield(params,'forgettingMode'))
  switch lower(params.forgettingMode)
    case 'fix'
      % fixed factor, read below
    otherwise
      error('This mode is still not implemented');
  end
end
if (isfield(params,'forgettingFactor'))
  lambda = params.forgettingFactor;
else
  lambda = 1;
end
% lambda <= 0 (or NaN) would turn C, and then D, into Inf/NaN
if (~(isscalar(lambda) && lambda > 0))
  error('Forgetting factor must be a positive scalar');
end
end


function show_dictionary_image(D)
% Display D in figure 2 as an image of 8x8 atoms arranged in a roughly
% square grid (assumes image atoms of size 8x8, see 'show_dict').
nside = round(sqrt(size(D,2)));
dictimg = SMALL_showdict(D,[8 8],nside,nside,'lines','highcontrast');
figure(2); imagesc(dictimg); colormap(gray); axis off; axis image;
end