annotate DL/Majorization Minimization DL/wrapper_mm_DL.m @ 173:7426503fc4d1 danieleb

added ramirez_dl dictionary learning case
author Daniele Barchiesi <daniele.barchiesi@eecs.qmul.ac.uk>
date Thu, 17 Nov 2011 11:15:02 +0000
parents b14209313ba4
children 0c7c20f3246c
rev   line source
function DL = wrapper_mm_DL(Problem, DL)
%WRAPPER_MM_DL  Dictionary learning with Mehrdad Yaghoobi's dictionary updates.
%
%   DL = wrapper_mm_DL(Problem, DL) alternates, for DL.param.iternum
%   iterations, between a sparse coding step (SMALL_solve with the solver
%   configuration in DL.param.solver) and a dictionary update step selected
%   by DL.name. The learned dictionary is returned in DL.D.
%
%   Inputs
%     Problem.b            training signals, one signal per column
%     Problem.b1           (optional) signal to be represented after learning;
%                          replaced by the training set inside this function
%     Problem.reconstruct  (optional) removed before sparse coding
%     DL.name              dictionary update rule, case-insensitive:
%                          'MM_cn', 'MM_fn', 'MOD_cn', 'MAP_cn' or 'KSVD_cn'
%     DL.param.solver      solver struct passed to SMALL_solve
%     DL.param.initdict    (optional) either a vector of whole numbers used as
%                          column indices into Problem.b, or a matrix whose
%                          columns are the initial atoms.
%                          NOTE(review): a one-row/one-column dictionary made
%                          only of whole numbers is read as an index vector.
%     DL.param.dictsize    (optional) number of atoms; overrides the size
%                          implied by initdict. One of dictsize/initdict is
%                          required.
%     DL.param.iternum         outer iterations               (default 40)
%     DL.param.iterDictUpdate  max iterations of MM/MAP update (default 1000)
%     DL.param.epsDictUpdate   stopping tolerance of update   (default 1e-7)
%     DL.param.cvset       atom constraint: 0 = ||d|| = 1 (non-convex),
%                          1 = ||d|| <= 1 (convex)            (default 0)
%     DL.param.muMAP       MAP_cn regularization weight     (default 1e-4)
%     DL.param.coherence   (optional) if present, atoms are decorrelated with
%                          dico_decorr towards this coherence every iteration
%     DL.param.show_dict   (optional) display the dictionary every
%                          show_dict iterations
%
%   Output
%     DL.D                 learned dictionary, one atom per column
%
%   Errors are raised for an undefined dictionary size, too few training
%   signals, an invalid initial dictionary or an unknown DL.name.

% sparse coding solver configuration (passed to SMALL_solve)
solver = DL.param.solver;

% dictionary update rule (Mehrdad Yaghoobi implementations)
typeUpdate = DL.name;

% training signals, one per column
sig = Problem.b;

% ---------------------------------------------------------------------
% determine the dictionary size
% ---------------------------------------------------------------------

dictsize = [];
useIndexInit = false;
if isfield(DL.param, 'initdict')
    initdict = DL.param.initdict;
    % a whole-valued vector selects training columns; anything else is
    % an explicit initial dictionary
    useIndexInit = any(size(initdict) == 1) && all(iswhole(initdict(:)));
    if useIndexInit
        dictsize = length(initdict);
    else
        dictsize = size(initdict, 2);
    end
end

if isfield(DL.param, 'dictsize') % supersedes the size determined by initdict
    dictsize = DL.param.dictsize;
end

if isempty(dictsize)
    error('Dictionary size is undefined: set DL.param.dictsize or DL.param.initdict');
end

if size(sig, 2) < dictsize
    error('Number of training signals is smaller than number of atoms to train');
end

% ---------------------------------------------------------------------
% initialize the dictionary
% ---------------------------------------------------------------------

if isfield(DL.param, 'initdict')
    if useIndexInit
        if length(initdict) < dictsize
            error('Invalid initial dictionary');
        end
        dico = sig(:, initdict(1:dictsize));
    else
        if size(initdict, 1) ~= size(sig, 1) || size(initdict, 2) < dictsize
            error('Invalid initial dictionary');
        end
        dico = initdict(:, 1:dictsize);
    end
else
    % pick dictsize distinct random training signals, skipping (near-)zero
    % columns so that no zero atom is chosen
    data_ids = find(colnorms_squared(sig) > 1e-6);
    if length(data_ids) < dictsize
        error('Not enough non-zero training signals to initialize the dictionary');
    end
    perm = randperm(length(data_ids));
    dico = sig(:, data_ids(perm(1:dictsize)));
end

% ---------------------------------------------------------------------
% algorithm parameters
% ---------------------------------------------------------------------

iternum = wrapper_mm_param(DL.param, 'iternum', 40);         % outer iterations
maxIT   = wrapper_mm_param(DL.param, 'iterDictUpdate', 1000); % inner update iterations
epsD    = wrapper_mm_param(DL.param, 'epsDictUpdate', 1e-7);  % inner stopping tolerance
cvset   = wrapper_mm_param(DL.param, 'cvset', 0);             % atom norm constraint
muMAP   = wrapper_mm_param(DL.param, 'muMAP', 1e-4);          % only used by 'map_cn'

% optional decorrelation of atoms in every iteration
decorrelate = isfield(DL.param, 'coherence');
if decorrelate
    mu = DL.param.coherence;
end

% optional display of the dictionary every show_iter iterations
show_dictionary = isfield(DL.param, 'show_dict');
show_iter = wrapper_mm_param(DL.param, 'show_dict', 0);

% Patch: Problem.b1 holds the signal that must be represented after
% learning (e.g. the whole image), but during learning we want the sparse
% representation of the training set, so b1 is pointed at it here.
% Problem is a local (value) copy, so the caller's struct is unaffected
% and nothing has to be restored afterwards.
if isfield(Problem, 'b1')
    Problem.b1 = sig;
end
if isfield(Problem, 'reconstruct')
    Problem = rmfield(Problem, 'reconstruct');
end
solver.profile = 0;

% ---------------------------------------------------------------------
% main loop
% ---------------------------------------------------------------------

for it = 1:iternum
    Problem.A = dico;

    % sparse coding step
    solver = SMALL_solve(Problem, solver);

    % dictionary update step; each rule returns the new dictionary and the
    % (possibly rescaled) coefficients
    switch lower(typeUpdate)
        case 'mm_cn'
            [dico, solver.solution] = ...
                dict_update_REG_cn(dico, sig, solver.solution, maxIT, epsD, cvset);
        case 'mm_fn'
            [dico, solver.solution] = ...
                dict_update_REG_fn(dico, sig, solver.solution, maxIT, epsD, cvset);
        case 'mod_cn'
            [dico, solver.solution] = dict_update_MOD_cn(sig, solver.solution, cvset);
        case 'map_cn'
            [dico, solver.solution] = ...
                dict_update_MAP_cn(dico, sig, solver.solution, muMAP, maxIT, epsD, cvset);
        case 'ksvd_cn'
            [dico, solver.solution] = dict_update_KSVD_cn(dico, sig, solver.solution);
        otherwise
            error('Dictionary update is not defined');
    end

    % warm-start the MMbox solver (iterative soft thresholding) with the
    % previous coefficients
    if strcmpi(solver.toolbox, 'MMbox')
        solver.param.initcoeff = solver.solution;
    end

    % optional decorrelation of atoms (from Boris Mailhe); its interaction
    % with Mehrdad's updates still needs to be tested
    if decorrelate
        dico = dico_decorr(dico, mu, solver.solution);
    end

    % NOTE(review): the [8 8] patch size assumes 64-sample atoms — confirm
    if show_dictionary && mod(it, show_iter) == 0
        nAtomsSide = round(sqrt(size(dico, 2)));
        dictimg = SMALL_showdict(dico, [8 8], ...
            nAtomsSide, nAtomsSide, 'lines', 'highcontrast');
        figure(2); imagesc(dictimg); colormap(gray); axis off; axis image;
        pause(0.02);
    end
end

DL.D = dico;

end

function value = wrapper_mm_param(param, name, default)
% Return param.(name) when that field exists, otherwise default.
if isfield(param, name)
    value = param.(name);
else
    value = default;
end
end
ivan@155 174
function Y = colnorms_squared(X, blocksize)
%COLNORMS_SQUARED  Squared Euclidean norm of every column of X.
%
%   Y = colnorms_squared(X) returns a 1-by-size(X,2) row vector with
%   Y(k) = sum(X(:,k).^2).
%
%   Y = colnorms_squared(X, blocksize) processes BLOCKSIZE columns at a
%   time (default 2000) so that the squared copy of X never has to be held
%   in memory all at once.
%
%   The sum is taken explicitly along dimension 1. Without that, a
%   single-row X would be summed along its columns, and the whole block
%   would collapse to one scalar.

if nargin < 2 || isempty(blocksize)
    blocksize = 2000;
end

ncols = size(X, 2);
Y = zeros(1, ncols);
for first = 1:blocksize:ncols
    cols = first:min(first + blocksize - 1, ncols);
    Y(cols) = sum(X(:, cols).^2, 1);
end

end