annotate DL/Majorization Minimization DL/wrapper_mm_DL.m @ 211:0c7c20f3246c luisf_dev

Added comments to ~/DL/Majorization Minimization DL/wrapper_mm_DL.m and ~/DL/Majorization Minimization DL/wrapper_mm_solver.m files.
author Aris Gretsistas <aris.gretsistas@elec.qmul.ac.uk>
date Wed, 21 Mar 2012 17:48:39 +0000
parents b14209313ba4
children b9b4dc87f1aa
rev   line source
function DL = wrapper_mm_DL(Problem, DL)
%% SMALL wrapper for Majorization Minimization Dictionary Learning Algorithm
%
%   DL = wrapper_mm_DL(Problem, DL)
%
%   Takes a Problem structure and a Dictionary Learning (DL) structure and
%   returns DL with the learned dictionary stored in field DL.D.
%
%   Problem must define field b, the training set (one training signal per
%   column).
%
%   DL.name selects the dictionary update method (matched
%   case-insensitively). For the original version of the MM algorithm it
%   should be one of:
%     - 'mm_cn'   - Regularized DL with column norm constraint
%     - 'mm_fn'   - Regularized DL with Frobenius norm constraint
%   Alternatively, for comparison purposes, the following dictionary update
%   methods (which do not represent the optimised version of the algorithm)
%   can be used:
%     - 'mod_cn'  - Method of Optimal Directions
%     - 'map_cn'  - Maximum a Posteriori dictionary update
%     - 'ksvd_cn' - K-SVD update
%
%   DL.param fields:
%     solver         - solver structure passed to SMALL_solve (required)
%     initdict       - initial dictionary: either a matrix whose columns are
%                      atoms, or a vector of whole numbers used as column
%                      indices into Problem.b
%     dictsize       - number of atoms; supersedes the size implied by
%                      initdict. At least one of initdict/dictsize is required.
%     iternum        - number of learning iterations (default 40)
%     iterDictUpdate - maximum iterations of the dictionary update step
%                      (default 1000)
%     epsDictUpdate  - stopping criterion of the dictionary update step
%                      (default 1e-7)
%     cvset          - dictionary constraint: 0 = non-convex ||d|| = 1,
%                      1 = convex ||d|| <= 1 (default 0)
%     muMAP          - parameter passed to dict_update_MAP_cn for 'map_cn'
%                      (default 1e-4)
%     coherence      - if present, dico_decorr is applied after every
%                      iteration with this value as its mu argument
%     show_dict      - if present, the dictionary is displayed every
%                      show_dict iterations (rendered with SMALL_showdict
%                      assuming 8x8 atoms)
%
%   - MM-DL - Yaghoobi, M.; Blumensath, T.; Davies M.; "Dictionary
%   Learning for Sparse Approximation with Majorization Method," IEEE
%   Transactions on Signal Processing, vol.57, no.6, pp.2178-2191, 2009.

%   Centre for Digital Music, Queen Mary, University of London.
%   This file copyright 2011 Ivan Damnjanovic.
%
%   This program is free software; you can redistribute it and/or
%   modify it under the terms of the GNU General Public License as
%   published by the Free Software Foundation; either version 2 of the
%   License, or (at your option) any later version. See the file
%   COPYING included with this distribution for more information.
%%

% solver used for the sparse representation step
solver = DL.param.solver;

% type of dictionary update to use
% (Mehrdad Yaghoobi implementations: 'MM_cn', 'MM_fn', 'MOD_cn',
% 'MAP_cn', 'KSVD_cn')
typeUpdate = DL.name;

sig = Problem.b;

% determine dictionary size %

% An initdict that is a vector of whole numbers is interpreted as column
% indices into the training set rather than as a dictionary matrix.
hasInitDict = isfield(DL.param, 'initdict');
initByIndex = hasInitDict && any(size(DL.param.initdict)==1) && ...
    all(iswhole(DL.param.initdict(:)));

dictsize = [];
if (hasInitDict)
    if (initByIndex)
        dictsize = length(DL.param.initdict);
    else
        dictsize = size(DL.param.initdict,2);
    end
end

if (isfield(DL.param,'dictsize')) % this supersedes the size determined by initdict
    dictsize = DL.param.dictsize;
end

if (isempty(dictsize))
    error('Dictionary size is not defined: set DL.param.dictsize or DL.param.initdict');
end

if (size(sig,2) < dictsize)
    error('Number of training signals is smaller than number of atoms to train');
end

% initialize the dictionary %

if (initByIndex)
    dico = sig(:,DL.param.initdict(1:dictsize));
elseif (hasInitDict)
    if (size(DL.param.initdict,1)~=size(sig,1) || size(DL.param.initdict,2)<dictsize)
        error('Invalid initial dictionary');
    end
    dico = DL.param.initdict(:,1:dictsize);
else
    % pick random training signals, ensuring no (near-)zero columns are chosen
    data_ids = find(colnorms_squared(sig) > 1e-6);
    if (length(data_ids) < dictsize)
        error('Number of non-zero training signals is smaller than number of atoms to train');
    end
    perm = randperm(length(data_ids));
    dico = sig(:,data_ids(perm(1:dictsize)));
end

% number of learning iterations (default is 40) %

if isfield(DL.param,'iternum')
    iternum = DL.param.iternum;
else
    iternum = 40;
end

% maximum number of iterations of the dictionary update (default is 1000) %

if isfield(DL.param,'iterDictUpdate')
    maxIT = DL.param.iterDictUpdate;
else
    maxIT = 1000;
end

% stopping criterion for the MM dictionary update (default is 1e-7) %

if isfield(DL.param,'epsDictUpdate')
    epsD = DL.param.epsDictUpdate;
else
    epsD = 1e-7;
end

% dictionary constraint - 0 = non-convex ||d|| = 1, 1 = convex ||d|| <= 1
% (default cvset is 0) %

if isfield(DL.param,'cvset')
    cvset = DL.param.cvset;
else
    cvset = 0;
end

% determine if we should do decorrelation in every iteration %

if isfield(DL.param,'coherence')
    decorrelate = 1;
    mu = DL.param.coherence;
else
    decorrelate = 0;
end

% show dictionary every specified number of iterations %

if isfield(DL.param,'show_dict')
    show_dictionary = 1;
    show_iter = DL.param.show_dict;
else
    show_dictionary = 0;
    show_iter = 0;
end

% This is a small patch that needs to be resolved: in dictionary learning
% we want the sparse representation of the training set, whereas in this
% version of the software Problem.b1 stores the signal that needs to be
% represented (for example the whole image). Problem is a local copy of the
% caller's structure (and is not returned), so overwriting b1 here does not
% affect the caller and no restore is needed afterwards.
if isfield(Problem,'b1')
    Problem.b1 = sig;
end
if isfield(Problem,'reconstruct')
    Problem = rmfield(Problem, 'reconstruct');
end
solver.profile = 0;

% main loop %

for i = 1:iternum
    Problem.A = dico;

    % sparse representation step
    solver = SMALL_solve(Problem, solver);

    % dictionary update step
    switch lower(typeUpdate)
        case 'mm_cn'
            [dico, solver.solution] = ...
                dict_update_REG_cn(dico, sig, solver.solution, maxIT, epsD, cvset);
        case 'mm_fn'
            [dico, solver.solution] = ...
                dict_update_REG_fn(dico, sig, solver.solution, maxIT, epsD, cvset);
        case 'mod_cn'
            [dico, solver.solution] = dict_update_MOD_cn(sig, solver.solution, cvset);
        case 'map_cn'
            if isfield(DL.param,'muMAP')
                muMAP = DL.param.muMAP;
            else
                muMAP = 1e-4;
            end
            [dico, solver.solution] = ...
                dict_update_MAP_cn(dico, sig, solver.solution, muMAP, maxIT, epsD, cvset);
        case 'ksvd_cn'
            [dico, solver.solution] = dict_update_KSVD_cn(dico, sig, solver.solution);
        otherwise
            error('Dictionary update is not defined');
    end

    % Set the previous solution as the best initial guess for the next
    % iteration of iterative soft thresholding (MMbox only)

    if (strcmpi(solver.toolbox, 'MMbox'))
        solver.param.initcoeff = solver.solution;
    end

    % Optional decorrelation of atoms - this is from Boris Mailhe and
    % we need to test how it performs with Mehrdad's updates

    if (decorrelate)
        dico = dico_decorr(dico, mu, solver.solution);
    end

    if ((show_dictionary)&&(mod(i,show_iter)==0))
        dictimg = SMALL_showdict(dico,[8 8],...
            round(sqrt(size(dico,2))),round(sqrt(size(dico,2))),'lines','highcontrast');
        figure(2); imagesc(dictimg);colormap(gray);axis off; axis image;
        pause(0.02);
    end
end

DL.D = dico;

end
ivan@155 206
function Y = colnorms_squared(X)
% COLNORMS_SQUARED  Squared Euclidean norm of every column of X.
%
%   Y = colnorms_squared(X) returns a 1-by-size(X,2) row vector with
%   Y(j) = sum(X(:,j).^2).
%
%   Columns are processed in blocks of 2000 so that the element-wise square
%   of the whole of X is never held in memory at once.
%
%   The sum is taken explicitly along dimension 1. Without the dimension
%   argument, a single-row X makes sum() reduce along dimension 2. That
%   produces one scalar, which was then assigned to every entry of the block.

Y = zeros(1,size(X,2));
blocksize = 2000;
for i = 1:blocksize:size(X,2)
    blockids = i : min(i+blocksize-1,size(X,2));
    Y(blockids) = sum(X(:,blockids).^2, 1);
end

end