function DL = wrapper_mm_DL(Problem, DL)
%WRAPPER_MM_DL Dictionary learning by alternating sparse coding and dictionary update.
%   DL = WRAPPER_MM_DL(Problem, DL) trains a dictionary on the training
%   signals stored column-wise in Problem.b. Each outer iteration:
%     1. computes a sparse representation of Problem.b over the current
%        dictionary, using the solver in DL.param.solver (via SMALL_solve);
%     2. updates the dictionary with the method selected by DL.name
%        (case-insensitive; Mehrdad Yaghoobi's implementations):
%        'MM_cn', 'MM_fn', 'MOD_cn', 'MAP_cn' or 'KSVD_cn';
%     3. optionally decorrelates the atoms (if DL.param.coherence is set);
%     4. optionally displays the dictionary (if DL.param.show_dict is set).
%   The learned dictionary is returned in DL.D.
%
%   Fields of DL.param read by this function:
%     solver         - solver structure passed to SMALL_solve (required)
%     initdict       - initial dictionary (matrix with size(Problem.b,1)
%                      rows) or a vector of column indices into Problem.b
%     dictsize       - number of atoms; overrides the size implied by
%                      initdict. At least one of initdict/dictsize is required.
%     iternum        - number of outer iterations (default 40)
%     iterDictUpdate - maximum number of iterations passed to the MM/MAP
%                      dictionary updates (default 1000)
%     epsDictUpdate  - stopping tolerance passed to the MM/MAP dictionary
%                      updates (default 1e-7)
%     cvset          - atom constraint: 0 = non-convex ||d|| = 1,
%                      1 = convex ||d|| <= 1 (default 0)
%     muMAP          - parameter passed to dict_update_MAP_cn (default 1e-4)
%     coherence      - if present, passed as mu to dico_decorr after every update
%     show_dict      - if present, show the dictionary every show_dict iterations
%
%   Errors are raised when the dictionary size cannot be determined, when
%   there are too few training signals, when initdict is inconsistent with
%   the data, or when DL.name is not one of the supported update types.

param  = DL.param;
solver = DL.param.solver;
sig    = Problem.b;

% Type of dictionary update. Validate it before doing any expensive work.
typeUpdate = lower(DL.name);
validUpdates = {'mm_cn', 'mm_fn', 'mod_cn', 'map_cn', 'ksvd_cn'};
if ~any(strcmp(typeUpdate, validUpdates))
    error('Dictionary update is not defined');
end

% --- Determine the dictionary size -------------------------------------
hasInit = isfield(param, 'initdict');
if hasInit
    initdict = param.initdict;
    % A vector of whole numbers is interpreted as column indices into sig.
    initIsIndex = any(size(initdict) == 1) && all(iswhole(initdict(:)));
    if initIsIndex
        dictsize = length(initdict);
    else
        dictsize = size(initdict, 2);
    end
end

if isfield(param, 'dictsize')
    % An explicit dictsize supersedes the size implied by initdict.
    dictsize = param.dictsize;
elseif ~hasInit
    error('Dictionary size is not specified: set DL.param.dictsize or DL.param.initdict');
end

if (size(sig, 2) < dictsize)
    error('Number of training signals is smaller than number of atoms to train');
end

% --- Initialise the dictionary -----------------------------------------
if hasInit
    if initIsIndex
        if numel(initdict) < dictsize
            error('Invalid initial dictionary');
        end
        dico = sig(:, initdict(1:dictsize));
    else
        if (size(initdict, 1) ~= size(sig, 1) || size(initdict, 2) < dictsize)
            error('Invalid initial dictionary');
        end
        dico = initdict(:, 1:dictsize);
    end
else
    % Pick dictsize random training signals, skipping (near-)zero columns.
    data_ids = find(colnorms_squared(sig) > 1e-6);
    if numel(data_ids) < dictsize
        error('Number of non-zero training signals is smaller than number of atoms to train');
    end
    perm = randperm(length(data_ids));
    dico = sig(:, data_ids(perm(1:dictsize)));
end

% --- Algorithm parameters (all loop-invariant) -------------------------
iternum = param_or_default(param, 'iternum', 40);
maxIT   = param_or_default(param, 'iterDictUpdate', 1000);
epsD    = param_or_default(param, 'epsDictUpdate', 1e-7);
cvset   = param_or_default(param, 'cvset', 0);
muMAP   = param_or_default(param, 'muMAP', 1e-4);

% Optional decorrelation of the atoms after every update.
decorrelate = isfield(param, 'coherence');
if decorrelate
    mu = param.coherence;
end

% Optional display of the dictionary every show_iter iterations.
show_dictionary = isfield(param, 'show_dict');
show_iter = param_or_default(param, 'show_dict', 0);

% Problem.b1 may hold the signal that is ultimately to be represented
% (e.g. a whole image). During learning we want the sparse representation
% of the training set, so b1 is pointed at sig. Problem is a local copy
% and is not returned, so the caller's struct is unaffected.
if isfield(Problem, 'b1')
    Problem.b1 = sig;
end
if isfield(Problem, 'reconstruct')
    Problem = rmfield(Problem, 'reconstruct');
end
solver.profile = 0;

% --- Main loop ---------------------------------------------------------
for iter = 1:iternum
    Problem.A = dico;

    solver = SMALL_solve(Problem, solver);

    switch typeUpdate
        case 'mm_cn'
            [dico, solver.solution] = ...
                dict_update_REG_cn(dico, sig, solver.solution, maxIT, epsD, cvset);
        case 'mm_fn'
            [dico, solver.solution] = ...
                dict_update_REG_fn(dico, sig, solver.solution, maxIT, epsD, cvset);
        case 'mod_cn'
            [dico, solver.solution] = dict_update_MOD_cn(sig, solver.solution, cvset);
        case 'map_cn'
            [dico, solver.solution] = ...
                dict_update_MAP_cn(dico, sig, solver.solution, muMAP, maxIT, epsD, cvset);
        case 'ksvd_cn'
            [dico, solver.solution] = dict_update_KSVD_cn(dico, sig, solver.solution);
    end

    % For MMbox, use the previous solution as the initial guess for the
    % next sparse-coding pass.
    if strcmpi(solver.toolbox, 'MMbox')
        solver.param.initcoeff = solver.solution;
    end

    % Optional decorrelation of atoms (from Boris Mailhe); its interaction
    % with Mehrdad's updates still needs testing.
    if decorrelate
        dico = dico_decorr(dico, mu, solver.solution);
    end

    if show_dictionary && mod(iter, show_iter) == 0
        nTiles = round(sqrt(size(dico, 2)));
        dictimg = SMALL_showdict(dico, [8 8], nTiles, nTiles, 'lines', 'highcontrast');
        figure(2); imagesc(dictimg); colormap(gray); axis off; axis image;
        pause(0.02);
    end
end

DL.D = dico;

end

function value = param_or_default(param, name, default)
% Return param.(name) when that field exists, otherwise default.
if isfield(param, name)
    value = param.(name);
else
    value = default;
end
end
|
function Y = colnorms_squared(X, blocksize)
%COLNORMS_SQUARED Squared Euclidean norm of every column of X.
%   Y = COLNORMS_SQUARED(X) returns a 1-by-size(X,2) row vector with
%   Y(j) = sum(abs(X(:,j)).^2).
%   Y = COLNORMS_SQUARED(X, BLOCKSIZE) processes BLOCKSIZE columns at a
%   time (default 2000) to conserve memory on wide matrices.

if nargin < 2
    blocksize = 2000;
end

ncols = size(X, 2);
Y = zeros(1, ncols);
for first = 1:blocksize:ncols
    blockids = first : min(first + blocksize - 1, ncols);
    % Sum explicitly along dimension 1: for a single-row X the default
    % dimension of sum would be 2, collapsing the block to one scalar.
    Y(blockids) = sum(abs(X(:, blockids)).^2, 1);
end

end