Mercurial > hg > smallbox
diff DL/Majorization Minimization DL/wrapper_mm_DL.m @ 155:b14209313ba4 ivand_dev
Integration of Majorization Minimisation Dictionary Learning
author | Ivan Damnjanovic lnx <ivan.damnjanovic@eecs.qmul.ac.uk> |
---|---|
date | Mon, 22 Aug 2011 11:46:35 +0100 |
parents | |
children | 0c7c20f3246c |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/DL/Majorization Minimization DL/wrapper_mm_DL.m Mon Aug 22 11:46:35 2011 +0100 @@ -0,0 +1,185 @@ +function DL = wrapper_mm_DL(Problem, DL) + +% determine which solver is used for sparse representation % + +solver = DL.param.solver; + +% determine which type of udate to use +% (Mehrdad Yaghoobi implementations: 'MM_cn', MM_fn', 'MOD_cn', +% 'MAP_cn', 'KSVD_cn') + +typeUpdate = DL.name; + +sig = Problem.b; + +% determine dictionary size % + +if (isfield(DL.param,'initdict')) + if (any(size(DL.param.initdict)==1) && all(iswhole(DL.param.initdict(:)))) + dictsize = length(DL.param.initdict); + else + dictsize = size(DL.param.initdict,2); + end +end + +if (isfield(DL.param,'dictsize')) % this superceedes the size determined by initdict + dictsize = DL.param.dictsize; +end + +if (size(sig,2) < dictsize) + error('Number of training signals is smaller than number of atoms to train'); +end + + +% initialize the dictionary % + +if (isfield(DL.param,'initdict')) + if (any(size(DL.param.initdict)==1) && all(iswhole(DL.param.initdict(:)))) + dico = sig(:,DL.param.initdict(1:dictsize)); + else + if (size(DL.param.initdict,1)~=size(sig,1) || size(DL.param.initdict,2)<dictsize) + error('Invalid initial dictionary'); + end + dico = DL.param.initdict(:,1:dictsize); + end +else + data_ids = find(colnorms_squared(sig) > 1e-6); % ensure no zero data elements are chosen + perm = randperm(length(data_ids)); + dico = sig(:,data_ids(perm(1:dictsize))); +end + + +% number of iterations (default is 40) % + +if isfield(DL.param,'iternum') + iternum = DL.param.iternum; +else + iternum = 40; +end + +% number of iterations (default is 40) % + +if isfield(DL.param,'iterDictUpdate') + maxIT = DL.param.iterDictUpdate; +else + maxIT = 1000; +end + +% Stopping criterion for MM dictionary update (default = 1e-7) + +if isfield(DL.param,'epsDictUpdate') + epsD = DL.param.epsDictUpdate; +else + epsD = 1e-7; +end + +% Dictionary constraint - 0 = Non convex ||d|| = 1, 1 = Convex ||d||<=1 +% (default cvset is o) % + +if isfield(DL.param,'cvset') + cvset = DL.param.cvset; +else + cvset = 0; +end + +% determine if we should do decorrelation in every iteration % + +if isfield(DL.param,'coherence') + decorrelate = 1; + mu = DL.param.coherence; +else + decorrelate = 0; +end + +% show dictonary every specified number of iterations + +if isfield(DL.param,'show_dict') + show_dictionary = 1; + show_iter = DL.param.show_dict; +else + show_dictionary = 0; + show_iter = 0; +end + +% This is a small patch that needs to be resolved in dictionary learning we +% want sparse representation of training set, and in Problem.b1 in this +% version of software we store the signal that needs to be represented +% (for example the whole image) +if isfield(Problem,'b1') + tmpTraining = Problem.b1; + Problem.b1 = sig; +end +if isfield(Problem,'reconstruct') + Problem = rmfield(Problem, 'reconstruct'); +end +solver.profile = 0; + +% main loop % + +for i = 1:iternum + Problem.A = dico; + + solver = SMALL_solve(Problem, solver); + + switch lower(typeUpdate) + case 'mm_cn' + [dico, solver.solution] = ... + dict_update_REG_cn(dico, sig, solver.solution, maxIT, epsD, cvset); + case 'mm_fn' + [dico, solver.solution] = ... + dict_update_REG_fn(dico, sig, solver.solution, maxIT, epsD, cvset); + case 'mod_cn' + [dico, solver.solution] = dict_update_MOD_cn(sig, solver.solution, cvset); + case 'map_cn' + if isfield(DL.param,'muMAP') + muMAP = DL.param.muMAP; + else + muMAP = 1e-4; + end + [dico, solver.solution] = ... + dict_update_MAP_cn(dico, sig, solver.solution, muMAP, maxIT, epsD, cvset); + case 'ksvd_cn' + [dico, solver.solution] = dict_update_KSVD_cn(dico, sig, solver.solution); + otherwise + error('Dictionary update is not defined'); + end + + % Set previous solution as the best initial guess + % for the next iteration of iterative soft tresholding + + if (strcmpi(solver.toolbox, 'MMbox')) + solver.param.initcoeff = solver.solution; + end + + % Optional decorrelation of athoms - this is from Boris Mailhe and + % we need to test how it preforms with Mehrdad's updates + + if (decorrelate) + dico = dico_decorr(dico, mu, solver.solution); + end + + if ((show_dictionary)&&(mod(i,show_iter)==0)) + dictimg = SMALL_showdict(dico,[8 8],... + round(sqrt(size(dico,2))),round(sqrt(size(dico,2))),'lines','highcontrast'); + figure(2); imagesc(dictimg);colormap(gray);axis off; axis image; + pause(0.02); + end +end +if isfield(Problem,'b1') + Problem.b1 = tmpTraining; +end +DL.D = dico; + +end + +function Y = colnorms_squared(X) + +% compute in blocks to conserve memory +Y = zeros(1,size(X,2)); +blocksize = 2000; +for i = 1:blocksize:size(X,2) + blockids = i : min(i+blocksize-1,size(X,2)); + Y(blockids) = sum(X(:,blockids).^2); +end + +end \ No newline at end of file