annotate DL/Majorization Minimization DL/wrapper_mm_DL.m @ 214:b9b4dc87f1aa luisf_dev

Additional help comments in ~/DL/Majorization Minimization DL/wrapper_mm_DL.m.
author Aris Gretsistas <aris.gretsistas@elec.qmul.ac.uk>
date Wed, 21 Mar 2012 18:27:23 +0000
parents 0c7c20f3246c
children 4337e28183f1
rev   line source
function DL = wrapper_mm_DL(Problem, DL)
%% SMALL wrapper for Majorization Minimization Dictionary Learning Algorithm
%
%   Function gets as input Problem and Dictionary Learning (DL) structures
%   and outputs the learned Dictionary in DL.D.
%
%   In Problem structure field b with the training set needs to be defined
%   (one training signal per column).
%
%   In DL structure field "name" with the Dictionary update method needs
%   to be present (matched case-insensitively). For the original version
%   of the MM algorithm the update method should be:
%       - 'mm_cn'   - Regularized DL with column norm constraint
%       - 'mm_fn'   - Regularized DL with Frobenius norm constraint
%   Alternatively, for comparison purposes the following Dictionary update
%   methods (which do not represent the optimised version of the algorithm)
%   can be used:
%       - 'mod_cn'  - Method of Optimized Direction
%       - 'map_cn'  - Maximum a Posteriori Dictionary update
%       - 'ksvd_cn' - KSVD update
%
%   DL.param.solver structure is also required. For the original version of
%   MM algorithm, DL.param.solver.toolbox should be 'MMbox'. The parameters
%   in DL.param.solver.param should be set accordingly. Type help
%   wrapper_mm_solver for more details.
%
%   Fields of DL.param read by this function:
%       solver         - (required) solver structure passed to SMALL_solve
%       initdict       - initial dictionary matrix, or a vector of whole
%                        numbers used as column indices into Problem.b
%       dictsize       - number of atoms; supersedes the size implied by
%                        initdict. One of dictsize/initdict is required.
%       iternum        - number of outer iterations (default 40)
%       iterDictUpdate - maximum iterations of the dictionary update
%                        (default 1000)
%       epsDictUpdate  - stopping criterion of the dictionary update
%                        (default 1e-7)
%       cvset          - dictionary constraint: 0 = non convex ||d|| = 1,
%                        1 = convex ||d|| <= 1 (default 0)
%       coherence      - if present, dico_decorr is applied with this value
%                        after every dictionary update
%       show_dict      - if present, the dictionary is displayed every
%                        show_dict iterations
%       muMAP          - step parameter passed to the 'map_cn' update
%                        (default 1e-4)
%
%   - MM-DL - Yaghoobi, M.; Blumensath, T,; Davies M.; , "Dictionary
%   Learning for Sparse Approximation with Majorization Method," IEEE
%   Transactions on Signal Processing, vol.57, no.6, pp.2178-2191, 2009.

%   Centre for Digital Music, Queen Mary, University of London.
%   This file copyright 2011 Ivan Damnjanovic.
%
%   This program is free software; you can redistribute it and/or
%   modify it under the terms of the GNU General Public License as
%   published by the Free Software Foundation; either version 2 of the
%   License, or (at your option) any later version.  See the file
%   COPYING included with this distribution for more information.
%%

param = DL.param;

% determine which solver is used for sparse representation %

solver = param.solver;

% determine which type of update to use
% (Mehrdad Yaghoobi implementations: 'MM_cn', 'MM_fn', 'MOD_cn',
% 'MAP_cn', 'KSVD_cn')

typeUpdate = DL.name;

sig = Problem.b;

% determine dictionary size %

hasInitDict = isfield(param, 'initdict');

% A vector of whole numbers is interpreted as a list of column indices into
% the training set rather than as a dictionary matrix.
initIsIndexList = hasInitDict && any(size(param.initdict) == 1) ...
    && all(iswhole(param.initdict(:)));

if isfield(param, 'dictsize')
    % an explicit dictsize supersedes the size determined by initdict
    dictsize = param.dictsize;
elseif hasInitDict
    if initIsIndexList
        dictsize = length(param.initdict);
    else
        dictsize = size(param.initdict, 2);
    end
else
    % Without this guard the size comparison below fails with an
    % "undefined variable" error that hides the real cause.
    error('Dictionary size is undefined: set DL.param.dictsize or DL.param.initdict');
end

if (size(sig, 2) < dictsize)
    error('Number of training signals is smaller than number of atoms to train');
end


% initialize the dictionary %

if hasInitDict
    if initIsIndexList
        if (length(param.initdict) < dictsize)
            error('Invalid initial dictionary');
        end
        dico = sig(:, param.initdict(1:dictsize));
    else
        if (size(param.initdict, 1) ~= size(sig, 1) || size(param.initdict, 2) < dictsize)
            error('Invalid initial dictionary');
        end
        dico = param.initdict(:, 1:dictsize);
    end
else
    % pick random training signals, ensuring no zero data elements are chosen
    data_ids = find(colnorms_squared(sig) > 1e-6);
    if (length(data_ids) < dictsize)
        error('Not enough non-zero training signals to initialize the dictionary');
    end
    perm = randperm(length(data_ids));
    dico = sig(:, data_ids(perm(1:dictsize)));
end


% number of outer iterations (default is 40) %
iternum = param_or_default(param, 'iternum', 40);

% maximum number of iterations of the dictionary update (default is 1000) %
maxIT = param_or_default(param, 'iterDictUpdate', 1000);

% Stopping criterion for MM dictionary update (default = 1e-7)
epsD = param_or_default(param, 'epsDictUpdate', 1e-7);

% Dictionary constraint - 0 = Non convex ||d|| = 1, 1 = Convex ||d||<=1
% (default cvset is 0) %
cvset = param_or_default(param, 'cvset', 0);

% step parameter of the MAP update (only used by 'map_cn', default 1e-4)
muMAP = param_or_default(param, 'muMAP', 1e-4);

% determine if we should do decorrelation in every iteration %

decorrelate = isfield(param, 'coherence');
if decorrelate
    mu = param.coherence;
end

% show dictionary every specified number of iterations

show_dictionary = isfield(param, 'show_dict');
show_iter = param_or_default(param, 'show_dict', 0);

% This is a small patch that needs to be resolved: in dictionary learning we
% want sparse representation of training set, and in Problem.b1 in this
% version of software we store the signal that needs to be represented
% (for example the whole image).
% Problem is a by-value copy that is not returned, so the caller's b1 is
% untouched and no restore is needed afterwards.
if isfield(Problem, 'b1')
    Problem.b1 = sig;
end
if isfield(Problem, 'reconstruct')
    Problem = rmfield(Problem, 'reconstruct');
end
solver.profile = 0;

% main loop %

for i = 1:iternum
    Problem.A = dico;

    % sparse coding step with the current dictionary
    solver = SMALL_solve(Problem, solver);

    % dictionary update step
    switch lower(typeUpdate)
        case 'mm_cn'
            [dico, solver.solution] = ...
                dict_update_REG_cn(dico, sig, solver.solution, maxIT, epsD, cvset);
        case 'mm_fn'
            [dico, solver.solution] = ...
                dict_update_REG_fn(dico, sig, solver.solution, maxIT, epsD, cvset);
        case 'mod_cn'
            [dico, solver.solution] = dict_update_MOD_cn(sig, solver.solution, cvset);
        case 'map_cn'
            [dico, solver.solution] = ...
                dict_update_MAP_cn(dico, sig, solver.solution, muMAP, maxIT, epsD, cvset);
        case 'ksvd_cn'
            [dico, solver.solution] = dict_update_KSVD_cn(dico, sig, solver.solution);
        otherwise
            error('Dictionary update is not defined');
    end

    % Set previous solution as the best initial guess
    % for the next iteration of iterative soft thresholding

    if (strcmpi(solver.toolbox, 'MMbox'))
        solver.param.initcoeff = solver.solution;
    end

    % Optional decorrelation of atoms - this is from Boris Mailhe and
    % we need to test how it performs with Mehrdad's updates

    if (decorrelate)
        dico = dico_decorr(dico, mu, solver.solution);
    end

    if ((show_dictionary) && (mod(i, show_iter) == 0))
        dictimg = SMALL_showdict(dico, [8 8], ...
            round(sqrt(size(dico, 2))), round(sqrt(size(dico, 2))), 'lines', 'highcontrast');
        figure(2); imagesc(dictimg); colormap(gray); axis off; axis image;
        pause(0.02);
    end
end

DL.D = dico;

end

function value = param_or_default(params, fieldName, defaultValue)
% Return params.(fieldName) when that field exists, otherwise defaultValue.

if isfield(params, fieldName)
    value = params.(fieldName);
else
    value = defaultValue;
end

end
ivan@155 211
function Y = colnorms_squared(X)
% Squared Euclidean norm of every column of X, returned as a 1-by-size(X,2)
% row vector.

% compute in blocks to conserve memory
Y = zeros(1, size(X, 2));
blocksize = 2000;
for i = 1:blocksize:size(X, 2)
    blockids = i : min(i + blocksize - 1, size(X, 2));
    % Sum explicitly along dimension 1: without the dimension argument a
    % single-row X would be collapsed along its only non-singleton
    % dimension, giving one scalar instead of per-column norms.
    Y(blockids) = sum(X(:, blockids).^2, 1);
end

end