ivan@155
|
function DL = wrapper_mm_DL(Problem, DL)
%% SMALL wrapper for Majorization Minimization Dictionary Learning Algorithm
%
%   Function gets as input Problem and Dictionary Learning (DL) structures
%   and outputs the learned Dictionary in DL.D.
%
%   In Problem structure field b with the training set needs to be defined.
%
%   In DL, the field name holds the Dictionary update method, and DL.param
%   holds the parameters for the particular dictionary learning technique.
%   For the original version of the MM algorithm the update method should be:
%    - 'mm_cn'   - Regularized DL with column norm constraint
%    - 'mm_fn'   - Regularized DL with Frobenius norm constraint
%   Alternatively, for comparison purposes the following Dictionary update
%   methods (which do not represent the optimised version of the algorithm)
%   can be used:
%    - 'mod_cn'  - Method of Optimal Directions
%    - 'map_cn'  - Maximum a Posteriori Dictionary update
%    - 'ksvd_cn' - KSVD update
%   Method names are matched case-insensitively.
%
%   DL.param fields read by this wrapper:
%    solver         - solver structure passed to SMALL_solve (required)
%    initdict       - initial dictionary (same number of rows as Problem.b),
%                     or a vector of whole numbers used as column indices
%                     into Problem.b
%    dictsize       - number of atoms; supersedes the size implied by
%                     initdict. At least one of initdict/dictsize is required.
%    iternum        - number of learning iterations (default 40)
%    iterDictUpdate - iteration limit passed to the dictionary update
%                     (default 1000)
%    epsDictUpdate  - stopping criterion for the dictionary update
%                     (default 1e-7)
%    cvset          - dictionary constraint: 0 = non-convex ||d|| = 1,
%                     1 = convex ||d|| <= 1 (default 0)
%    muMAP          - parameter passed to dict_update_MAP_cn (default 1e-4)
%    coherence      - if present, dico_decorr is applied with this value
%                     after every iteration
%    show_dict      - if present, the dictionary is displayed every
%                     show_dict iterations
%
% - MM-DL - Yaghoobi, M.; Blumensath, T.; Davies M.; , "Dictionary
%   Learning for Sparse Approximation with Majorization Method," IEEE
%   Transactions on Signal Processing, vol.57, no.6, pp.2178-2191, 2009.
%
%   Centre for Digital Music, Queen Mary, University of London.
%   This file copyright 2011 Ivan Damnjanovic.
%
%   This program is free software; you can redistribute it and/or
%   modify it under the terms of the GNU General Public License as
%   published by the Free Software Foundation; either version 2 of the
%   License, or (at your option) any later version. See the file
%   COPYING included with this distribution for more information.
%%

% determine which solver is used for sparse representation
solver = DL.param.solver;

% determine which type of update to use
% (Mehrdad Yaghoobi implementations: 'MM_cn', 'MM_fn', 'MOD_cn',
% 'MAP_cn', 'KSVD_cn')
typeUpdate = DL.name;
validUpdates = {'mm_cn', 'mm_fn', 'mod_cn', 'map_cn', 'ksvd_cn'};
if ~any(strcmpi(typeUpdate, validUpdates))
    % fail fast, before spending time in the sparse coding step
    error('Dictionary update is not defined');
end

sig = Problem.b;

% determine dictionary size
% An initdict that is a vector of whole numbers is interpreted as column
% indices into the training set; otherwise it is the dictionary itself.
dictsize = [];
hasInitDict = isfield(DL.param, 'initdict');
initIsIndex = false;
if hasInitDict
    initdict = DL.param.initdict;
    initIsIndex = any(size(initdict) == 1) && all(iswhole(initdict(:)));
    if initIsIndex
        dictsize = length(initdict);
    else
        dictsize = size(initdict, 2);
    end
end

if isfield(DL.param, 'dictsize') % this supersedes the size determined by initdict
    dictsize = DL.param.dictsize;
end

if isempty(dictsize)
    error('Dictionary size is undefined: set DL.param.initdict or DL.param.dictsize');
end

if (size(sig,2) < dictsize)
    error('Number of training signals is smaller than number of atoms to train');
end

% initialize the dictionary
if hasInitDict
    if initIsIndex
        if (length(initdict) < dictsize)
            error('Invalid initial dictionary');
        end
        dico = sig(:, initdict(1:dictsize));
    else
        if (size(initdict,1) ~= size(sig,1) || size(initdict,2) < dictsize)
            error('Invalid initial dictionary');
        end
        dico = initdict(:, 1:dictsize);
    end
else
    data_ids = find(colnorms_squared(sig) > 1e-6); % ensure no zero data elements are chosen
    if (length(data_ids) < dictsize)
        error('Number of non-zero training signals is smaller than number of atoms to train');
    end
    perm = randperm(length(data_ids));
    dico = sig(:, data_ids(perm(1:dictsize)));
end

% learning parameters (meanings and defaults listed in the help text)
iternum = param_or_default(DL.param, 'iternum', 40);
maxIT   = param_or_default(DL.param, 'iterDictUpdate', 1000);
epsD    = param_or_default(DL.param, 'epsDictUpdate', 1e-7);
cvset   = param_or_default(DL.param, 'cvset', 0);
muMAP   = param_or_default(DL.param, 'muMAP', 1e-4); % only used by 'map_cn'

% determine if we should do decorrelation in every iteration
decorrelate = isfield(DL.param, 'coherence');
if decorrelate
    mu = DL.param.coherence;
end

% show dictionary every show_iter iterations
show_dictionary = isfield(DL.param, 'show_dict');
show_iter = param_or_default(DL.param, 'show_dict', 0);

% This is a small patch that needs to be resolved: in dictionary learning we
% want the sparse representation of the training set, while Problem.b1 in
% this version of the software stores the signal that needs to be
% represented (for example the whole image). Problem is a local copy
% (MATLAB value semantics, and it is not returned), so the caller's b1 is
% unaffected and nothing has to be restored afterwards.
if isfield(Problem, 'b1')
    Problem.b1 = sig;
end
if isfield(Problem, 'reconstruct')
    Problem = rmfield(Problem, 'reconstruct');
end
solver.profile = 0;

% main loop
for iter = 1:iternum
    Problem.A = dico;

    solver = SMALL_solve(Problem, solver);

    switch lower(typeUpdate)
        case 'mm_cn'
            [dico, solver.solution] = ...
                dict_update_REG_cn(dico, sig, solver.solution, maxIT, epsD, cvset);
        case 'mm_fn'
            [dico, solver.solution] = ...
                dict_update_REG_fn(dico, sig, solver.solution, maxIT, epsD, cvset);
        case 'mod_cn'
            [dico, solver.solution] = dict_update_MOD_cn(sig, solver.solution, cvset);
        case 'map_cn'
            [dico, solver.solution] = ...
                dict_update_MAP_cn(dico, sig, solver.solution, muMAP, maxIT, epsD, cvset);
        case 'ksvd_cn'
            [dico, solver.solution] = dict_update_KSVD_cn(dico, sig, solver.solution);
        otherwise
            error('Dictionary update is not defined');
    end

    % Set previous solution as the best initial guess
    % for the next iteration of iterative soft thresholding
    if strcmpi(solver.toolbox, 'MMbox')
        solver.param.initcoeff = solver.solution;
    end

    % Optional decorrelation of atoms - this is from Boris Mailhe and
    % we need to test how it performs with Mehrdad's updates
    if decorrelate
        dico = dico_decorr(dico, mu, solver.solution);
    end

    if (show_dictionary && mod(iter, show_iter) == 0)
        gridSide = round(sqrt(size(dico,2)));
        dictimg = SMALL_showdict(dico, [8 8], ...
            gridSide, gridSide, 'lines', 'highcontrast');
        figure(2); imagesc(dictimg); colormap(gray); axis off; axis image;
        pause(0.02);
    end
end

DL.D = dico;

end

function value = param_or_default(param, name, defaultValue)
% Return param.(name) when that field exists, otherwise defaultValue.
if isfield(param, name)
    value = param.(name);
else
    value = defaultValue;
end
end
|
ivan@155
|
206
|
ivan@155
|
function Y = colnorms_squared(X)
% COLNORMS_SQUARED  Squared Euclidean norm of every column of X.
%   Y = colnorms_squared(X) returns a 1-by-size(X,2) row vector with
%   Y(j) = sum(X(:,j).^2). Columns are processed in blocks of 2000 so the
%   temporary array X(:,block).^2 stays small (to conserve memory).

Y = zeros(1, size(X,2));
blocksize = 2000;
for i = 1:blocksize:size(X,2)
    blockids = i : min(i+blocksize-1, size(X,2));
    % Sum along dimension 1 explicitly: for a single-row X, plain sum()
    % would add across the row and return one scalar for the whole block.
    Y(blockids) = sum(X(:,blockids).^2, 1);
end

end