comparison DL/Majorization Minimization DL/wrapper_mm_DL.m @ 155:b14209313ba4 ivand_dev

Integration of Majorization Minimisation Dictionary Learning
author Ivan Damnjanovic lnx <ivan.damnjanovic@eecs.qmul.ac.uk>
date Mon, 22 Aug 2011 11:46:35 +0100
parents
children 0c7c20f3246c
comparison
equal deleted inserted replaced
154:0de08f68256b 155:b14209313ba4
function DL = wrapper_mm_DL(Problem, DL)
%WRAPPER_MM_DL Dictionary learning with Majorization-Minimisation updates.
%
%   DL = WRAPPER_MM_DL(Problem, DL) alternates a sparse-coding step
%   (SMALL_solve with the solver structure DL.param.solver) and a
%   dictionary-update step selected by DL.name, for DL.param.iternum
%   iterations, and stores the resulting dictionary in DL.D.
%
%   Inputs:
%     Problem.b    - training signals, one signal per column.
%     Problem.b1   - (optional) replaced by the training set inside this
%                    function so that sparse coding acts on Problem.b.
%     DL.name      - dictionary update type, case-insensitive (Mehrdad
%                    Yaghoobi implementations): 'MM_cn', 'MM_fn', 'MOD_cn',
%                    'MAP_cn' or 'KSVD_cn'.
%     DL.param fields:
%       solver         - solver structure passed to SMALL_solve (required).
%       initdict       - (optional) initial dictionary with size(Problem.b,1)
%                        rows, or a vector of whole numbers used as column
%                        indices into Problem.b.
%       dictsize       - (optional) number of atoms; supersedes the size
%                        implied by initdict. At least one of initdict and
%                        dictsize must be given.
%       iternum        - number of learning iterations (default 40).
%       iterDictUpdate - max iterations passed to the MM/MAP dictionary
%                        updates (default 1000).
%       epsDictUpdate  - stopping tolerance passed to the MM/MAP dictionary
%                        updates (default 1e-7).
%       cvset          - atom constraint: 0 = non-convex ||d|| = 1,
%                        1 = convex ||d|| <= 1 (default 0).
%       muMAP          - parameter passed to dict_update_MAP_cn
%                        (default 1e-4; only used by 'MAP_cn').
%       coherence      - (optional) if present, dico_decorr is applied with
%                        this value after every dictionary update.
%       show_dict      - (optional) display the dictionary every show_dict
%                        iterations.
%
%   Output:
%     DL.D - learned dictionary, one atom per column.

param  = DL.param;
solver = param.solver;

% Validate the update type before doing any (expensive) sparse coding.
updateType = lower(DL.name);
if ~ismember(updateType, {'mm_cn', 'mm_fn', 'mod_cn', 'map_cn', 'ksvd_cn'})
    error('Dictionary update is not defined');
end

sig = Problem.b;

% ---- determine dictionary size ----

dictsize  = [];
indexInit = false;
if isfield(param, 'initdict')
    initdict = param.initdict;
    % A vector of whole numbers is interpreted as column indices into sig.
    indexInit = any(size(initdict) == 1) && all(iswhole(initdict(:)));
    if indexInit
        dictsize = length(initdict);
    else
        dictsize = size(initdict, 2);
    end
end

if isfield(param, 'dictsize')   % this supersedes the size determined by initdict
    dictsize = param.dictsize;
end

if isempty(dictsize)
    error('Either DL.param.initdict or DL.param.dictsize must be specified');
end

if size(sig, 2) < dictsize
    error('Number of training signals is smaller than number of atoms to train');
end

% ---- initialize the dictionary ----

if isfield(param, 'initdict')
    if indexInit
        if length(initdict) < dictsize
            error('Invalid initial dictionary');
        end
        dico = sig(:, initdict(1:dictsize));
    else
        if size(initdict, 1) ~= size(sig, 1) || size(initdict, 2) < dictsize
            error('Invalid initial dictionary');
        end
        dico = initdict(:, 1:dictsize);
    end
else
    % Pick dictsize random training signals, avoiding (near-)zero columns.
    dataIds = find(colnorms_squared(sig) > 1e-6);
    if numel(dataIds) < dictsize
        error('Not enough non-zero training signals to initialize the dictionary');
    end
    perm = randperm(numel(dataIds));
    dico = sig(:, dataIds(perm(1:dictsize)));
end

% ---- learning parameters (with defaults) ----

iternum = param_or_default(param, 'iternum', 40);           % learning iterations
maxIT   = param_or_default(param, 'iterDictUpdate', 1000);  % max dictionary-update iterations
epsD    = param_or_default(param, 'epsDictUpdate', 1e-7);   % dictionary-update stopping tolerance
cvset   = param_or_default(param, 'cvset', 0);              % 0: ||d|| = 1, 1: ||d|| <= 1
muMAP   = param_or_default(param, 'muMAP', 1e-4);           % only used by 'map_cn'

% Optional decorrelation of atoms in every iteration.
decorrelate = isfield(param, 'coherence');
if decorrelate
    mu = param.coherence;
end

% Optional display of the dictionary every show_dict iterations.
showDictionary = isfield(param, 'show_dict');
showIter       = param_or_default(param, 'show_dict', 0);

% During learning we want the sparse representation of the training set,
% while Problem.b1 may hold the signal to be represented afterwards (for
% example the whole image). Problem is a local copy, so these changes are
% not visible to the caller and need not be undone.
if isfield(Problem, 'b1')
    Problem.b1 = sig;
end
if isfield(Problem, 'reconstruct')
    Problem = rmfield(Problem, 'reconstruct');
end
solver.profile = 0;

% ---- main loop ----

for iter = 1:iternum
    Problem.A = dico;

    % Sparse-coding step with the current dictionary.
    solver = SMALL_solve(Problem, solver);

    % Dictionary-update step.
    switch updateType
        case 'mm_cn'
            [dico, solver.solution] = ...
                dict_update_REG_cn(dico, sig, solver.solution, maxIT, epsD, cvset);
        case 'mm_fn'
            [dico, solver.solution] = ...
                dict_update_REG_fn(dico, sig, solver.solution, maxIT, epsD, cvset);
        case 'mod_cn'
            [dico, solver.solution] = dict_update_MOD_cn(sig, solver.solution, cvset);
        case 'map_cn'
            [dico, solver.solution] = ...
                dict_update_MAP_cn(dico, sig, solver.solution, muMAP, maxIT, epsD, cvset);
        case 'ksvd_cn'
            [dico, solver.solution] = dict_update_KSVD_cn(dico, sig, solver.solution);
        otherwise
            % Defensive: unreachable as long as the validation list above
            % matches the cases of this switch.
            error('Dictionary update is not defined');
    end

    % Use the previous solution as the initial guess for the next
    % iteration of iterative soft thresholding (MMbox solvers only).
    if strcmpi(solver.toolbox, 'MMbox')
        solver.param.initcoeff = solver.solution;
    end

    % Optional decorrelation of atoms (from Boris Mailhe); its interplay
    % with Mehrdad's updates still needs to be tested.
    if decorrelate
        dico = dico_decorr(dico, mu, solver.solution);
    end

    if showDictionary && mod(iter, showIter) == 0
        nSide = round(sqrt(size(dico, 2)));
        % NOTE(review): the [8 8] patch size is hard-coded; presumably this
        % assumes 64-dimensional signals (8x8 patches) - confirm.
        dictimg = SMALL_showdict(dico, [8 8], nSide, nSide, 'lines', 'highcontrast');
        figure(2); imagesc(dictimg); colormap(gray); axis off; axis image;
        pause(0.02);
    end
end

DL.D = dico;

end

function value = param_or_default(param, name, default)
% Return param.(name) when that field exists, otherwise default.
if isfield(param, name)
    value = param.(name);
else
    value = default;
end
end
174
function Y = colnorms_squared(X)
%COLNORMS_SQUARED Squared Euclidean norm of every column of X.
%   Y = COLNORMS_SQUARED(X) returns a 1-by-size(X,2) row vector with
%   Y(j) = sum(X(:,j).^2).
%
%   Columns are processed in blocks of 2000 to limit the size of the
%   temporary array X(:,blockids).^2. The sum is taken explicitly along
%   dimension 1: without it, a single-row X would be summed along the row
%   and every entry of Y would receive the row total.

Y = zeros(1, size(X, 2));
blocksize = 2000;
for i = 1:blocksize:size(X, 2)
    blockids = i : min(i + blocksize - 1, size(X, 2));
    Y(blockids) = sum(X(:, blockids).^2, 1);
end

end