ivan@155
|
function DL = wrapper_mm_DL(Problem, DL)
%% SMALL wrapper for Majorization Minimization Dictionary Learning Algorithm
%
%   Function gets as input Problem and Dictionary Learning (DL) structures
%   and outputs the learned Dictionary in field DL.D.
%
%   In Problem structure field b with the training set (one training
%   signal per column) needs to be defined.
%
%   In DL structure field name with the name of the Dictionary update
%   method needs to be present (matched case-insensitively). For the
%   original version of MM algorithm the update method should be:
%       - 'mm_cn'    Regularized DL with column norm constraint
%       - 'mm_fn'    Regularized DL with Frobenius norm constraint
%   Alternatively, for comparison purposes the following Dictionary update
%   methods (which do not represent the optimised version of the algorithm)
%   can be used:
%       - 'mod_cn'   Method of Optimal Directions
%       - 'map_cn'   Maximum a Posteriori Dictionary update
%                    (step parameter DL.param.muMAP, default is 1e-4)
%       - 'ksvd_cn'  KSVD update
%
%   The structure DL.param with parameters is also required. These are:
%       - solver          solver structure with fields toolbox, solver and
%                         param. For the original version of the algorithm
%                         toolbox should be 'MMbox' and solver field should
%                         be left empty ''. Type HELP WRAPPER_MM_SOLVER for
%                         more details on how to set the parameters.
%       - initdict        Initial Dictionary: either a matrix with one atom
%                         per column, or a vector of column indices of
%                         Problem.b to use as initial atoms. If absent, the
%                         initial atoms are random training signals with
%                         squared norm above 1e-6.
%       - dictsize        Dictionary size (optional if initdict is given;
%                         supersedes the size implied by initdict).
%                         At least one of initdict / dictsize is required.
%       - iternum         Number of iterations (default is 40)
%       - iterDictUpdate  Number of iterations for Dictionary Update
%                         (default is 1000)
%       - epsDictUpdate   Stopping criterion for MM dictionary update
%                         (default is 1e-7)
%       - cvset           Dictionary constraint - 0 = Non convex ||d|| = 1,
%                         1 = Convex ||d||<=1 (default is 0)
%       - coherence       If this field is present, atoms are decorrelated
%                         with DICO_DECORR in every iteration, and its value
%                         is passed as DICO_DECORR's mu argument (presumably
%                         the target coherence). If absent, no decorrelation.
%       - show_dict       Show dictionary every specified number of
%                         iterations (display assumes [8 8] blocks).
%
%
%   - MM-DL - Yaghoobi, M.; Blumensath, T.; Davies, M.; "Dictionary
%     Learning for Sparse Approximation with Majorization Method," IEEE
%     Transactions on Signal Processing, vol.57, no.6, pp.2178-2191, 2009.

%
%   Centre for Digital Music, Queen Mary, University of London.
%   This file copyright 2011 Ivan Damnjanovic.
%
%   This program is free software; you can redistribute it and/or
%   modify it under the terms of the GNU General Public License as
%   published by the Free Software Foundation; either version 2 of the
%   License, or (at your option) any later version. See the file
%   COPYING included with this distribution for more information.
%%

% solver used for the sparse representation step
solver = DL.param.solver;

% type of dictionary update (Mehrdad Yaghoobi implementations)
typeUpdate = lower(DL.name);
validUpdates = {'mm_cn', 'mm_fn', 'mod_cn', 'map_cn', 'ksvd_cn'};
if ~any(strcmp(typeUpdate, validUpdates))
    % fail before the first (expensive) sparse coding step
    error('Dictionary update is not defined');
end

sig = Problem.b;

% determine dictionary size
hasInitDict = isfield(DL.param, 'initdict');
if hasInitDict
    initdict = DL.param.initdict;
    % a vector of whole numbers is interpreted as column indices into sig
    initIsIndex = any(size(initdict) == 1) && all(iswhole(initdict(:)));
    if initIsIndex
        dictsize = length(initdict);
    else
        dictsize = size(initdict, 2);
    end
end

if isfield(DL.param, 'dictsize')   % this supersedes the size determined by initdict
    dictsize = DL.param.dictsize;
elseif ~hasInitDict
    error('Either DL.param.initdict or DL.param.dictsize must be specified');
end

if size(sig, 2) < dictsize
    error('Number of training signals is smaller than number of atoms to train');
end

% initialize the dictionary
if hasInitDict
    if initIsIndex
        if numel(initdict) < dictsize
            error('Invalid initial dictionary');
        end
        dico = sig(:, initdict(1:dictsize));
    else
        if size(initdict, 1) ~= size(sig, 1) || size(initdict, 2) < dictsize
            error('Invalid initial dictionary');
        end
        dico = initdict(:, 1:dictsize);
    end
else
    % ensure no (near-)zero training signals are chosen as initial atoms
    data_ids = find(colnorms_squared(sig) > 1e-6);
    if length(data_ids) < dictsize
        error('Number of non-zero training signals is smaller than number of atoms to train');
    end
    perm = randperm(length(data_ids));
    dico = sig(:, data_ids(perm(1:dictsize)));
end

% algorithm parameters and their defaults
iternum = param_or_default(DL.param, 'iternum', 40);          % DL iterations
maxIT   = param_or_default(DL.param, 'iterDictUpdate', 1000); % dictionary update iterations
epsD    = param_or_default(DL.param, 'epsDictUpdate', 1e-7);  % MM dictionary update stopping criterion
cvset   = param_or_default(DL.param, 'cvset', 0);             % 0: ||d|| = 1, 1: ||d|| <= 1
muMAP   = param_or_default(DL.param, 'muMAP', 1e-4);          % only used by 'map_cn'

% decorrelation is performed whenever the coherence field is present
decorrelate = isfield(DL.param, 'coherence');
if decorrelate
    mu = DL.param.coherence;
end

% show dictionary every show_iter iterations
show_dictionary = isfield(DL.param, 'show_dict');
show_iter = param_or_default(DL.param, 'show_dict', 0);

% In this version of the software Problem.b1 stores the signal that has to
% be represented (for example the whole image), while during dictionary
% learning we want the sparse representation of the training set.
% Problem is a local copy here, so the caller's structure is not affected.
if isfield(Problem, 'b1')
    Problem.b1 = sig;
end
if isfield(Problem, 'reconstruct')
    Problem = rmfield(Problem, 'reconstruct');
end
solver.profile = 0;

% main loop: alternate sparse coding and dictionary update
for i = 1:iternum
    Problem.A = dico;

    solver = SMALL_solve(Problem, solver);

    switch typeUpdate
        case 'mm_cn'
            [dico, solver.solution] = ...
                dict_update_REG_cn(dico, sig, solver.solution, maxIT, epsD, cvset);
        case 'mm_fn'
            [dico, solver.solution] = ...
                dict_update_REG_fn(dico, sig, solver.solution, maxIT, epsD, cvset);
        case 'mod_cn'
            [dico, solver.solution] = dict_update_MOD_cn(sig, solver.solution, cvset);
        case 'map_cn'
            [dico, solver.solution] = ...
                dict_update_MAP_cn(dico, sig, solver.solution, muMAP, maxIT, epsD, cvset);
        case 'ksvd_cn'
            [dico, solver.solution] = dict_update_KSVD_cn(dico, sig, solver.solution);
    end

    % Set previous solution as the best initial guess
    % for the next iteration of iterative soft thresholding
    if strcmpi(solver.toolbox, 'MMbox')
        solver.param.initcoeff = solver.solution;
    end

    % Optional decorrelation of atoms - this is from Boris Mailhe and
    % we need to test how it performs with Mehrdad's updates
    if decorrelate
        dico = dico_decorr(dico, mu, solver.solution);
    end

    if show_dictionary && mod(i, show_iter) == 0
        atomsPerSide = round(sqrt(size(dico, 2)));
        dictimg = SMALL_showdict(dico, [8 8], ...
            atomsPerSide, atomsPerSide, 'lines', 'highcontrast');
        figure(2); imagesc(dictimg); colormap(gray); axis off; axis image;
        pause(0.02);
    end
end

DL.D = dico;

end

function value = param_or_default(param, fieldName, defaultValue)
% PARAM_OR_DEFAULT  Return param.(fieldName) if present, else defaultValue.
if isfield(param, fieldName)
    value = param.(fieldName);
else
    value = defaultValue;
end
end
|
ivan@155
|
225
|
ivan@155
|
function Y = colnorms_squared(X, blocksize)
% COLNORMS_SQUARED  Squared Euclidean norm of every column of X.
%
%   Y = COLNORMS_SQUARED(X) returns a 1-by-size(X,2) row vector with
%   Y(j) = sum(X(:,j).^2).
%
%   Y = COLNORMS_SQUARED(X, BLOCKSIZE) processes BLOCKSIZE columns at a
%   time (default 2000) to conserve memory: X.^2 is only formed for one
%   block of columns at once.

if nargin < 2 || isempty(blocksize)
    blocksize = 2000;
end
if ~isscalar(blocksize) || blocksize < 1
    error('blocksize must be a positive scalar');
end

ncols = size(X, 2);
Y = zeros(1, ncols);
for first = 1:blocksize:ncols
    blockids = first : min(first + blocksize - 1, ncols);
    % Sum along dimension 1 explicitly: for a single-row X the default
    % SUM would add along the row and return one scalar for the block.
    Y(blockids) = sum(X(:, blockids).^2, 1);
end

end