Mercurial > hg > smallbox
comparison DL/Majorization Minimization DL/wrapper_mm_DL.m @ 155:b14209313ba4 ivand_dev
Integration of Majorization Minimisation Dictionary Learning
author | Ivan Damnjanovic lnx <ivan.damnjanovic@eecs.qmul.ac.uk> |
---|---|
date | Mon, 22 Aug 2011 11:46:35 +0100 |
parents | |
children | 0c7c20f3246c |
comparison
equal
deleted
inserted
replaced
154:0de08f68256b | 155:b14209313ba4 |
---|---|
function DL = wrapper_mm_DL(Problem, DL)
%WRAPPER_MM_DL Dictionary learning with Mehrdad Yaghoobi's dictionary updates.
%
%   DL = WRAPPER_MM_DL(Problem, DL) alternates, for DL.param.iternum
%   iterations, between sparse coding of the training signals Problem.b
%   (SMALL_solve with DL.param.solver) and a dictionary update selected by
%   DL.name. The learned dictionary is returned in DL.D.
%
%   DL.name (case-insensitive) selects the dictionary update:
%     'MM_cn'   - dict_update_REG_cn
%     'MM_fn'   - dict_update_REG_fn
%     'MOD_cn'  - dict_update_MOD_cn
%     'MAP_cn'  - dict_update_MAP_cn (uses DL.param.muMAP)
%     'KSVD_cn' - dict_update_KSVD_cn
%   Any other name raises 'Dictionary update is not defined'.
%
%   Fields of DL.param read here:
%     solver         - solver structure passed to SMALL_solve (required)
%     initdict       - initial dictionary (matrix with size(Problem.b,1)
%                      rows), or a vector of whole numbers used as column
%                      indices into Problem.b
%     dictsize       - number of atoms; supersedes the size implied by
%                      initdict. At least one of initdict/dictsize is
%                      required. Without initdict, dictsize randomly chosen
%                      non-zero columns of Problem.b form the initial dictionary.
%     iternum        - number of learning iterations (default 40)
%     iterDictUpdate - max iterations of the dictionary update (default 1000)
%     epsDictUpdate  - stopping criterion of the dictionary update
%                      (default 1e-7)
%     cvset          - dictionary constraint: 0 = non-convex ||d|| = 1,
%                      1 = convex ||d|| <= 1 (default 0)
%     muMAP          - parameter of the 'MAP_cn' update (default 1e-4)
%     coherence      - if present, dico_decorr is applied after every
%                      update with this value passed as its mu argument
%     show_dict      - if present, the dictionary is displayed every
%                      show_dict iterations
%
%   Errors:
%     'wrapper_mm_DL:noDictSize'    - neither initdict nor dictsize given
%     'wrapper_mm_DL:tooFewSignals' - fewer non-zero training signals than
%                                     atoms when initialising randomly
%     'Number of training signals is smaller than number of atoms to train'
%     'Invalid initial dictionary'
%     'Dictionary update is not defined'

solver = DL.param.solver;
sig = Problem.b;

% type of dictionary update (Mehrdad Yaghoobi implementations: 'MM_cn',
% 'MM_fn', 'MOD_cn', 'MAP_cn', 'KSVD_cn'). Validate it before any
% (possibly expensive) sparse coding is done.
typeUpdate = DL.name;
if ~any(strcmpi(typeUpdate, {'mm_cn', 'mm_fn', 'mod_cn', 'map_cn', 'ksvd_cn'}))
    error('Dictionary update is not defined');
end

% dictionary size and initial dictionary
dictsize = determine_dictsize(DL.param);

if (size(sig,2) < dictsize)
    error('Number of training signals is smaller than number of atoms to train');
end

dico = initialize_dictionary(sig, DL.param, dictsize);

% algorithm parameters (with defaults)
iternum = param_or_default(DL.param, 'iternum', 40);          % learning iterations
maxIT   = param_or_default(DL.param, 'iterDictUpdate', 1000); % max dictionary-update iterations
epsD    = param_or_default(DL.param, 'epsDictUpdate', 1e-7);  % dictionary-update stopping criterion
cvset   = param_or_default(DL.param, 'cvset', 0);             % 0: ||d|| = 1, 1: ||d|| <= 1
muMAP   = param_or_default(DL.param, 'muMAP', 1e-4);          % only used by 'MAP_cn'

% optional decorrelation of atoms in every iteration
decorrelate = isfield(DL.param, 'coherence');
if decorrelate
    mu = DL.param.coherence;
end

% optionally show the dictionary every show_iter iterations
show_dictionary = isfield(DL.param, 'show_dict');
show_iter = param_or_default(DL.param, 'show_dict', 0);

% In this version of the software Problem.b1 stores the signal that has to
% be represented (for example the whole image), but during learning we want
% the sparse representation of the training set, so point b1 at it.
% Problem is a local copy (MATLAB value semantics), so the caller's b1 is
% untouched and nothing has to be restored afterwards.
if isfield(Problem,'b1')
    Problem.b1 = sig;
end
if isfield(Problem,'reconstruct')
    Problem = rmfield(Problem, 'reconstruct');
end
solver.profile = 0;

% main loop
for i = 1:iternum
    Problem.A = dico;

    % sparse coding step
    solver = SMALL_solve(Problem, solver);

    % dictionary update step
    switch lower(typeUpdate)
        case 'mm_cn'
            [dico, solver.solution] = ...
                dict_update_REG_cn(dico, sig, solver.solution, maxIT, epsD, cvset);
        case 'mm_fn'
            [dico, solver.solution] = ...
                dict_update_REG_fn(dico, sig, solver.solution, maxIT, epsD, cvset);
        case 'mod_cn'
            [dico, solver.solution] = dict_update_MOD_cn(sig, solver.solution, cvset);
        case 'map_cn'
            [dico, solver.solution] = ...
                dict_update_MAP_cn(dico, sig, solver.solution, muMAP, maxIT, epsD, cvset);
        case 'ksvd_cn'
            [dico, solver.solution] = dict_update_KSVD_cn(dico, sig, solver.solution);
        otherwise
            error('Dictionary update is not defined');
    end

    % For MMbox, use the previous solution as the initial guess of the
    % next iteration's iterative soft thresholding.
    if (strcmpi(solver.toolbox, 'MMbox'))
        solver.param.initcoeff = solver.solution;
    end

    % Optional decorrelation of atoms (from Boris Mailhe); how it performs
    % with Mehrdad's updates still has to be tested.
    if (decorrelate)
        dico = dico_decorr(dico, mu, solver.solution);
    end

    if (show_dictionary && mod(i,show_iter)==0)
        ngrid = round(sqrt(size(dico,2)));
        dictimg = SMALL_showdict(dico,[8 8],ngrid,ngrid,'lines','highcontrast');
        figure(2); imagesc(dictimg);colormap(gray);axis off; axis image;
        pause(0.02);
    end
end

DL.D = dico;

end

function dictsize = determine_dictsize(param)
% Number of atoms: param.dictsize if given, otherwise implied by param.initdict.
dictsize = [];
if isfield(param, 'initdict')
    if is_index_vector(param.initdict)
        dictsize = length(param.initdict);
    else
        dictsize = size(param.initdict, 2);
    end
end
if isfield(param, 'dictsize') % this supersedes the size determined by initdict
    dictsize = param.dictsize;
end
if isempty(dictsize)
    error('wrapper_mm_DL:noDictSize', ...
        'Dictionary size is undefined: set DL.param.dictsize or DL.param.initdict');
end
end

function dico = initialize_dictionary(sig, param, dictsize)
% Initial dictionary from param.initdict, or dictsize random non-zero columns of sig.
if isfield(param, 'initdict')
    if is_index_vector(param.initdict)
        dico = sig(:, param.initdict(1:dictsize));
    else
        if (size(param.initdict,1)~=size(sig,1) || size(param.initdict,2)<dictsize)
            error('Invalid initial dictionary');
        end
        dico = param.initdict(:, 1:dictsize);
    end
else
    % ensure no (near-)zero data elements are chosen
    data_ids = find(colnorms_squared(sig) > 1e-6);
    if (length(data_ids) < dictsize)
        error('wrapper_mm_DL:tooFewSignals', ...
            'Number of non-zero training signals is smaller than number of atoms to train');
    end
    perm = randperm(length(data_ids));
    dico = sig(:, data_ids(perm(1:dictsize)));
end
end

function tf = is_index_vector(initdict)
% True when initdict is a vector of whole numbers (column indices into the training signals).
tf = any(size(initdict)==1) && all(iswhole(initdict(:)));
end

function value = param_or_default(param, name, default)
% Field param.(name) when it exists, otherwise default.
if isfield(param, name)
    value = param.(name);
else
    value = default;
end
end
174 | |
function Y = colnorms_squared(X, blocksize)
%COLNORMS_SQUARED Squared Euclidean norm of every column of X.
%   Y = COLNORMS_SQUARED(X) returns a 1-by-size(X,2) row vector with
%   Y(j) = sum(abs(X(:,j)).^2).
%   Y = COLNORMS_SQUARED(X, BLOCKSIZE) processes BLOCKSIZE columns at a
%   time (default 2000) to limit the size of temporaries and conserve memory.

if nargin < 2 || isempty(blocksize)
    blocksize = 2000;
end

ncols = size(X, 2);
Y = zeros(1, ncols);
for i = 1:blocksize:ncols
    blockids = i : min(i+blocksize-1, ncols);
    % Sum along dimension 1 explicitly: for a single-row X the default
    % dimension would be the row, collapsing the block to one scalar.
    % abs() makes the result correct for complex data as well.
    Y(blockids) = sum(abs(X(:,blockids)).^2, 1);
end

end