annotate DL/Majorization Minimization DL/wrapper_mm_DL.m @ 234:c96880c0c47c

renamed file.
author luisf <luis.figueira@eecs.qmul.ac.uk>
date Thu, 19 Apr 2012 17:21:05 +0100
parents 4337e28183f1
children
rev   line source
function DL = wrapper_mm_DL(Problem, DL)
%% SMALL wrapper for Majorization Minimization Dictionary Learning Algorithm
%
%   Function gets as input Problem and Dictionary Learning (DL) structures
%   and outputs the learned Dictionary in field DL.D.
%
%   In Problem structure field b with the training set needs to be defined
%   (one training signal per column).
%
%   In DL structure field name with the name of the Dictionary update
%   method needs to be present (matched case-insensitively). For the
%   original version of MM algorithm the update method should be:
%     - 'mm_cn'     Regularized DL with column norm constraint
%     - 'mm_fn'     Regularized DL with Frobenius norm constraint
%   Alternatively, for comparison purposes the following Dictionary update
%   methods (which do not represent the optimised version of the algorithm)
%   can be used:
%     - 'mod_cn'    Method of Optimal Directions
%     - 'map_cn'    Maximum a Posteriori Dictionary update
%                   (uses optional DL.param.muMAP, default 1e-4)
%     - 'ksvd_cn'   KSVD update
%
%   The structure DL.param with parameters is also required. These are:
%     - solver          structure with fields toolbox, solver and param.
%                       For the original version of the algorithm toolbox
%                       should be 'MMbox' and solver field should be left
%                       empty ''. Type HELP WRAPPER_MM_SOLVER for more
%                       details on how to set the parameters.
%     - initdict        Initial Dictionary: either a matrix whose columns
%                       are the initial atoms, or a vector of whole numbers
%                       used as column indices into Problem.b
%     - dictsize        Dictionary size (optional if initdict is given;
%                       overrides the size implied by initdict)
%     - iternum         Number of iterations (default is 40)
%     - iterDictUpdate  Number of iterations for Dictionary Update (default is 1000)
%     - epsDictUpdate   Stopping criterion for MM dictionary update (default = 1e-7)
%     - cvset           Dictionary constraint - 0 = Non convex ||d|| = 1, 1 = Convex ||d||<=1
%                       (default is 0)
%     - coherence       If this field is present, DICO_DECORR is applied to
%                       the dictionary after every iteration, with this
%                       value as its mu argument (presumably the target
%                       mutual coherence - see DICO_DECORR). Default: absent.
%     - show_dict       Show dictionary every specified number of iterations
%
%
%   - MM-DL - Yaghoobi, M.; Blumensath, T.; Davies, M.; "Dictionary
%   Learning for Sparse Approximation with Majorization Method," IEEE
%   Transactions on Signal Processing, vol.57, no.6, pp.2178-2191, 2009.

%
%   Centre for Digital Music, Queen Mary, University of London.
%   This file copyright 2011 Ivan Damnjanovic.
%
%   This program is free software; you can redistribute it and/or
%   modify it under the terms of the GNU General Public License as
%   published by the Free Software Foundation; either version 2 of the
%   License, or (at your option) any later version. See the file
%   COPYING included with this distribution for more information.
%%

% Sparse-representation solver configuration (passed to SMALL_solve).
solver = DL.param.solver;

% Dictionary update type (Mehrdad Yaghoobi implementations:
% 'mm_cn', 'mm_fn', 'mod_cn', 'map_cn', 'ksvd_cn').
typeUpdate = DL.name;

sig = Problem.b;

% --- determine dictionary size ---
% An initdict that is a vector of whole numbers is interpreted as column
% indices into the training set rather than as a dictionary matrix.
hasInitDict = isfield(DL.param, 'initdict');
initIsIndex = false;
dictsize = [];
if hasInitDict
    initdict = DL.param.initdict;
    initIsIndex = any(size(initdict)==1) && all(iswhole(initdict(:)));
    if initIsIndex
        dictsize = length(initdict);
    else
        dictsize = size(initdict, 2);
    end
end

if isfield(DL.param, 'dictsize')  % this supersedes the size determined by initdict
    dictsize = DL.param.dictsize;
end

if isempty(dictsize)
    % Without this check the comparison below fails with an opaque
    % "undefined variable" error.
    error('Dictionary size is not defined: set DL.param.initdict or DL.param.dictsize');
end

if (size(sig,2) < dictsize)
    error('Number of training signals is smaller than number of atoms to train');
end

% --- initialize the dictionary ---
if hasInitDict
    if initIsIndex
        % dictsize may have been overridden to exceed the index list.
        if numel(initdict) < dictsize
            error('Invalid initial dictionary');
        end
        dico = sig(:, initdict(1:dictsize));
    else
        if (size(initdict,1) ~= size(sig,1) || size(initdict,2) < dictsize)
            error('Invalid initial dictionary');
        end
        dico = initdict(:, 1:dictsize);
    end
else
    % Random subset of training signals, skipping (near-)zero columns.
    data_ids = find(colnorms_squared(sig) > 1e-6);
    if numel(data_ids) < dictsize
        error('Not enough non-zero training signals to initialize the dictionary');
    end
    perm = randperm(length(data_ids));
    dico = sig(:, data_ids(perm(1:dictsize)));
end

% --- optional parameters ---
iternum = param_or_default(DL.param, 'iternum', 40);          % outer iterations
maxIT   = param_or_default(DL.param, 'iterDictUpdate', 1000); % dictionary-update iterations
epsD    = param_or_default(DL.param, 'epsDictUpdate', 1e-7);  % MM update stopping criterion
cvset   = param_or_default(DL.param, 'cvset', 0);             % 0: ||d|| = 1, 1: ||d|| <= 1
muMAP   = param_or_default(DL.param, 'muMAP', 1e-4);          % only used by 'map_cn'

% Decorrelation is enabled by the mere presence of the 'coherence' field;
% its value is passed to dico_decorr as mu.
decorrelate = isfield(DL.param, 'coherence');
if decorrelate
    mu = DL.param.coherence;
end

% Show the dictionary every show_iter iterations (only if 'show_dict' is set).
show_dictionary = isfield(DL.param, 'show_dict');
show_iter = param_or_default(DL.param, 'show_dict', 0);

% In this version of the software Problem.b1 stores the signal that is
% ultimately to be represented (for example the whole image), while during
% dictionary learning we want the sparse representation of the training
% set. b1 is therefore replaced on this local copy of Problem only; MATLAB
% structs are passed by value and Problem is not returned, so the caller's
% structure is unaffected and no restore is needed.
if isfield(Problem, 'b1')
    Problem.b1 = sig;
end
if isfield(Problem, 'reconstruct')
    Problem = rmfield(Problem, 'reconstruct');
end
solver.profile = 0;

% --- main loop: alternate sparse coding and dictionary update ---
for i = 1:iternum
    Problem.A = dico;

    solver = SMALL_solve(Problem, solver);

    switch lower(typeUpdate)
        case 'mm_cn'
            [dico, solver.solution] = ...
                dict_update_REG_cn(dico, sig, solver.solution, maxIT, epsD, cvset);
        case 'mm_fn'
            [dico, solver.solution] = ...
                dict_update_REG_fn(dico, sig, solver.solution, maxIT, epsD, cvset);
        case 'mod_cn'
            [dico, solver.solution] = dict_update_MOD_cn(sig, solver.solution, cvset);
        case 'map_cn'
            [dico, solver.solution] = ...
                dict_update_MAP_cn(dico, sig, solver.solution, muMAP, maxIT, epsD, cvset);
        case 'ksvd_cn'
            [dico, solver.solution] = dict_update_KSVD_cn(dico, sig, solver.solution);
        otherwise
            error('Dictionary update is not defined');
    end

    % For the MMbox solver, use the previous solution as the initial
    % guess (warm start) for the next sparse-coding step.
    if (strcmpi(solver.toolbox, 'MMbox'))
        solver.param.initcoeff = solver.solution;
    end

    % Optional decorrelation of atoms - this is from Boris Mailhe and
    % we need to test how it performs with Mehrdad's updates.
    if (decorrelate)
        dico = dico_decorr(dico, mu, solver.solution);
    end

    % NOTE(review): display assumes 8x8 atoms laid out on a square grid.
    if (show_dictionary && mod(i, show_iter) == 0)
        dictimg = SMALL_showdict(dico, [8 8], ...
            round(sqrt(size(dico,2))), round(sqrt(size(dico,2))), 'lines', 'highcontrast');
        figure(2); imagesc(dictimg); colormap(gray); axis off; axis image;
        pause(0.02);
    end
end

DL.D = dico;

end

function value = param_or_default(param, name, default)
% PARAM_OR_DEFAULT  Return param.(name) if that field exists, else default.
if isfield(param, name)
    value = param.(name);
else
    value = default;
end
end
ivan@155 225
function Y = colnorms_squared(X)
% COLNORMS_SQUARED  Squared Euclidean norm of every column of X.
%   Y = COLNORMS_SQUARED(X) returns a 1-by-size(X,2) row vector with
%   Y(j) = sum(X(:,j).^2).
%
%   Columns are processed in blocks of 2000 so that the temporary
%   X(:,blockids).^2 stays small (to conserve memory).
%
%   The sum is taken explicitly along dimension 1: with the default SUM
%   dimension, a single-row X would be summed along its columns into one
%   scalar, which would then be broadcast into every entry of the block.

Y = zeros(1, size(X,2));
blocksize = 2000;
for i = 1:blocksize:size(X,2)
    blockids = i : min(i+blocksize-1, size(X,2));
    Y(blockids) = sum(X(:,blockids).^2, 1);
end

end