ivan@155
|
function DL = wrapper_mm_DL(Problem, DL)
%% SMALL wrapper for Majorization Minimization Dictionary Learning Algorithm
%
%   Function gets as input Problem and Dictionary Learning (DL) structures
%   and outputs the learned Dictionary in DL.D.
%
%   In the Problem structure, field b with the training set (one training
%   signal per column) needs to be defined.
%
%   In the DL structure, field name with the name of the Dictionary update
%   method needs to be present. For the original version of the MM
%   algorithm the update method should be:
%   - 'mm_cn'   - Regularized DL with column norm constraint
%   - 'mm_fn'   - Regularized DL with Frobenius norm constraint
%   Alternatively, for comparison purposes, the following Dictionary update
%   methods (which do not represent the optimised version of the algorithm)
%   can be used:
%   - 'mod_cn'  - Method of Optimal Directions
%   - 'map_cn'  - Maximum a Posteriori Dictionary update
%   - 'ksvd_cn' - KSVD update
%   The method name is case-insensitive, and '-' is accepted in place of
%   '_' (e.g. 'map-cn' is the same as 'map_cn').
%
%   DL.param.solver structure is also required. For the original version of
%   the MM algorithm, DL.param.solver.toolbox should be 'MMbox'. The
%   parameters in DL.param.solver.param should be set accordingly. Type help
%   wrapper_mm_solver for more details.
%
%   The number of atoms must be given by DL.param.dictsize and/or
%   DL.param.initdict. Optional DL.param fields (defaults in brackets):
%   - initdict       - initial dictionary matrix, or a vector of column
%                      indices of Problem.b to use as initial atoms
%                      [randomly chosen non-zero training signals]
%   - dictsize       - number of atoms; supersedes the size of initdict
%   - iternum        - number of learning iterations [40]
%   - iterDictUpdate - max. iterations of the dictionary update [1000]
%   - epsDictUpdate  - stopping criterion of the dictionary update [1e-7]
%   - cvset          - 0: non-convex ||d|| = 1, 1: convex ||d|| <= 1 [0]
%   - muMAP          - parameter passed to the 'map_cn' update [1e-4]
%   - coherence      - if present, atoms are decorrelated with
%                      dico_decorr using this value after every iteration
%   - show_dict      - if present, the dictionary is displayed every
%                      show_dict iterations
%
%   - MM-DL - Yaghoobi, M.; Blumensath, T,; Davies M.; , "Dictionary
%   Learning for Sparse Approximation with Majorization Method," IEEE
%   Transactions on Signal Processing, vol.57, no.6, pp.2178-2191, 2009.

%   Centre for Digital Music, Queen Mary, University of London.
%   This file copyright 2011 Ivan Damnjanovic.
%
%   This program is free software; you can redistribute it and/or
%   modify it under the terms of the GNU General Public License as
%   published by the Free Software Foundation; either version 2 of the
%   License, or (at your option) any later version. See the file
%   COPYING included with this distribution for more information.
%%

% determine which solver is used for sparse representation %

solver = DL.param.solver;

% determine which type of update to use
% (Mehrdad Yaghoobi implementations: 'MM_cn', 'MM_fn', 'MOD_cn',
% 'MAP_cn', 'KSVD_cn'). Normalise case and accept '-' as an alias of '_'
% so that the spellings used in the help text above also work.

typeUpdate = strrep(lower(DL.name), '-', '_');

sig = Problem.b;

% determine dictionary size %

dictsize = [];
if (isfield(DL.param,'initdict'))
    if (any(size(DL.param.initdict)==1) && all(iswhole(DL.param.initdict(:))))
        % initdict is a vector of training-signal indices
        dictsize = length(DL.param.initdict);
    else
        % initdict is an explicit dictionary matrix
        dictsize = size(DL.param.initdict,2);
    end
end

if (isfield(DL.param,'dictsize')) % this supersedes the size determined by initdict
    dictsize = DL.param.dictsize;
end

% without this check a missing size surfaces later as an obscure
% "undefined variable" error
if isempty(dictsize)
    error('Dictionary size is not defined: set DL.param.dictsize or DL.param.initdict');
end

if (size(sig,2) < dictsize)
    error('Number of training signals is smaller than number of atoms to train');
end


% initialize the dictionary %

if (isfield(DL.param,'initdict'))
    if (any(size(DL.param.initdict)==1) && all(iswhole(DL.param.initdict(:))))
        dico = sig(:,DL.param.initdict(1:dictsize));
    else
        if (size(DL.param.initdict,1)~=size(sig,1) || size(DL.param.initdict,2)<dictsize)
            error('Invalid initial dictionary');
        end
        dico = DL.param.initdict(:,1:dictsize);
    end
else
    data_ids = find(colnorms_squared(sig) > 1e-6); % ensure no zero data elements are chosen
    % only the non-zero columns are candidates, so the earlier check on
    % size(sig,2) is not sufficient here
    if (length(data_ids) < dictsize)
        error('Number of non-zero training signals is smaller than number of atoms to train');
    end
    perm = randperm(length(data_ids));
    dico = sig(:,data_ids(perm(1:dictsize)));
end


% number of learning iterations (default is 40) %

if isfield(DL.param,'iternum')
    iternum = DL.param.iternum;
else
    iternum = 40;
end

% maximum number of iterations of the dictionary update (default is 1000) %

if isfield(DL.param,'iterDictUpdate')
    maxIT = DL.param.iterDictUpdate;
else
    maxIT = 1000;
end

% Stopping criterion for MM dictionary update (default = 1e-7)

if isfield(DL.param,'epsDictUpdate')
    epsD = DL.param.epsDictUpdate;
else
    epsD = 1e-7;
end

% Dictionary constraint - 0 = Non convex ||d|| = 1, 1 = Convex ||d||<=1
% (default cvset is 0) %

if isfield(DL.param,'cvset')
    cvset = DL.param.cvset;
else
    cvset = 0;
end

% parameter of the MAP dictionary update (default = 1e-4); loop-invariant,
% so it is read once here rather than on every iteration

if isfield(DL.param,'muMAP')
    muMAP = DL.param.muMAP;
else
    muMAP = 1e-4;
end

% determine if we should do decorrelation in every iteration %

if isfield(DL.param,'coherence')
    decorrelate = 1;
    mu = DL.param.coherence;
else
    decorrelate = 0;
end

% show dictionary every specified number of iterations

if isfield(DL.param,'show_dict')
    show_dictionary = 1;
    show_iter = DL.param.show_dict;
else
    show_dictionary = 0;
    show_iter = 0;
end

% This is a small patch that needs to be resolved: in dictionary learning
% we want a sparse representation of the training set, while in this
% version of the software Problem.b1 stores the signal that needs to be
% represented (for example the whole image). Problem is a local copy, so
% replacing b1 here does not affect the caller.
if isfield(Problem,'b1')
    Problem.b1 = sig;
end
if isfield(Problem,'reconstruct')
    Problem = rmfield(Problem, 'reconstruct');
end
solver.profile = 0; % presumably disables profiling inside SMALL_solve - TODO confirm

% main loop %

for i = 1:iternum
    Problem.A = dico;

    % sparse coding of the training set with the current dictionary
    solver = SMALL_solve(Problem, solver);

    % dictionary update
    switch typeUpdate
        case 'mm_cn'
            [dico, solver.solution] = ...
                dict_update_REG_cn(dico, sig, solver.solution, maxIT, epsD, cvset);
        case 'mm_fn'
            [dico, solver.solution] = ...
                dict_update_REG_fn(dico, sig, solver.solution, maxIT, epsD, cvset);
        case 'mod_cn'
            [dico, solver.solution] = dict_update_MOD_cn(sig, solver.solution, cvset);
        case 'map_cn'
            [dico, solver.solution] = ...
                dict_update_MAP_cn(dico, sig, solver.solution, muMAP, maxIT, epsD, cvset);
        case 'ksvd_cn'
            [dico, solver.solution] = dict_update_KSVD_cn(dico, sig, solver.solution);
        otherwise
            error('Dictionary update is not defined');
    end

    % Set previous solution as the best initial guess
    % for the next iteration of iterative soft thresholding

    if (strcmpi(solver.toolbox, 'MMbox'))
        solver.param.initcoeff = solver.solution;
    end

    % Optional decorrelation of atoms - this is from Boris Mailhe and
    % we need to test how it performs with Mehrdad's updates

    if (decorrelate)
        dico = dico_decorr(dico, mu, solver.solution);
    end

    % NOTE(review): the display assumes 8x8 image-patch atoms
    if ((show_dictionary)&&(mod(i,show_iter)==0))
        dictimg = SMALL_showdict(dico,[8 8],...
            round(sqrt(size(dico,2))),round(sqrt(size(dico,2))),'lines','highcontrast');
        figure(2); imagesc(dictimg);colormap(gray);axis off; axis image;
        pause(0.02);
    end
end

DL.D = dico;

end
|
ivan@155
|
211
|
ivan@155
|
function Y = colnorms_squared(X)
% COLNORMS_SQUARED  Squared Euclidean norm of every column of X.
%
%   Y = colnorms_squared(X) returns a 1-by-size(X,2) row vector with
%   Y(j) = sum(X(:,j).^2). Columns are processed in blocks of 2000 so that
%   the temporary X(:,blockids).^2 stays small (conserves memory).

Y = zeros(1,size(X,2));
blocksize = 2000;
for i = 1:blocksize:size(X,2)
    blockids = i : min(i+blocksize-1,size(X,2));
    % sum explicitly along dimension 1: for a single-row X the default sum
    % would run along the row and collapse the whole block into one scalar
    Y(blockids) = sum(X(:,blockids).^2, 1);
end

end