Mercurial > hg > smallbox
view DL/Majorization Minimization DL/wrapper_mm_DL.m @ 173:7426503fc4d1 danieleb
added ramirez_dl dictionary learning case
author | Daniele Barchiesi <daniele.barchiesi@eecs.qmul.ac.uk> |
---|---|
date | Thu, 17 Nov 2011 11:15:02 +0000 |
parents | b14209313ba4 |
children | 0c7c20f3246c |
line wrap: on
line source
function DL = wrapper_mm_DL(Problem, DL)
%WRAPPER_MM_DL  SMALLbox wrapper for the Majorization Minimization (MM)
%   family of dictionary learning updates (Mehrdad Yaghoobi's
%   implementations).
%
%   DL = WRAPPER_MM_DL(Problem, DL) alternates, for a number of iterations,
%   between sparse coding of the training signals Problem.b (through
%   SMALL_solve with the solver DL.param.solver) and a dictionary update
%   selected by DL.name. The learned dictionary is returned in DL.D.
%
%   DL.name selects the dictionary update (case-insensitive):
%     'MM_cn', 'MM_fn', 'MOD_cn', 'MAP_cn', 'KSVD_cn'
%
%   Fields of DL.param read here:
%     solver         - solver structure passed to SMALL_solve
%     initdict       - initial dictionary (matrix), or a vector of whole
%                      numbers used as column indices into Problem.b
%     dictsize       - number of atoms; overrides the size implied by
%                      initdict
%     iternum        - number of learning iterations (default 40)
%     iterDictUpdate - maximum number of iterations of the dictionary
%                      update, passed to the MM/MAP updates (default 1000)
%     epsDictUpdate  - stopping tolerance of the MM/MAP dictionary update
%                      (default 1e-7)
%     cvset          - dictionary constraint: 0 = non-convex ||d|| = 1,
%                      1 = convex ||d|| <= 1 (default 0)
%     coherence      - if present, dico_decorr is applied in every
%                      iteration with this value as its mu argument
%     muMAP          - parameter for the 'MAP_cn' update (default 1e-4)
%     show_dict      - if present, the dictionary is displayed every
%                      show_dict iterations
%
%   Errors are raised when the dictionary size cannot be determined, when
%   there are fewer training signals than atoms, when initdict is
%   inconsistent with the training data, and when DL.name is unknown.

% solver used for the sparse representation step
solver = DL.param.solver;

% type of dictionary update to use
typeUpdate = DL.name;

sig = Problem.b;

% determine dictionary size
hasInit = isfield(DL.param, 'initdict');
if hasInit
    initdict = DL.param.initdict;
    % a vector of whole numbers is interpreted as column indices into sig
    initIsIndex = any(size(initdict) == 1) && all(iswhole(initdict(:)));
    if initIsIndex
        dictsize = length(initdict);
    else
        dictsize = size(initdict, 2);
    end
end
if isfield(DL.param, 'dictsize')
    % this supersedes the size determined by initdict
    dictsize = DL.param.dictsize;
elseif ~hasInit
    % without either field dictsize would otherwise be undefined below
    error('Dictionary size is not defined: set DL.param.dictsize or DL.param.initdict');
end

if size(sig, 2) < dictsize
    error('Number of training signals is smaller than number of atoms to train');
end

% initialize the dictionary
if hasInit
    if initIsIndex
        % an overriding dictsize may ask for more atoms than indices given
        if length(initdict) < dictsize
            error('Invalid initial dictionary');
        end
        dico = sig(:, initdict(1:dictsize));
    else
        if size(initdict, 1) ~= size(sig, 1) || size(initdict, 2) < dictsize
            error('Invalid initial dictionary');
        end
        dico = initdict(:, 1:dictsize);
    end
else
    % pick random training signals, ensuring no (near-)zero columns are chosen
    data_ids = find(colnorms_squared(sig) > 1e-6);
    if length(data_ids) < dictsize
        error('Not enough non-zero training signals to initialize the dictionary');
    end
    perm = randperm(length(data_ids));
    dico = sig(:, data_ids(perm(1:dictsize)));
end

% number of learning iterations (default is 40)
iternum = param_or_default(DL.param, 'iternum', 40);

% maximum number of iterations of the dictionary update (default is 1000)
maxIT = param_or_default(DL.param, 'iterDictUpdate', 1000);

% stopping criterion for the MM dictionary update (default is 1e-7)
epsD = param_or_default(DL.param, 'epsDictUpdate', 1e-7);

% dictionary constraint: 0 = non-convex ||d|| = 1, 1 = convex ||d|| <= 1
% (default is 0)
cvset = param_or_default(DL.param, 'cvset', 0);

% parameter of the 'MAP_cn' update (default is 1e-4); loop-invariant
muMAP = param_or_default(DL.param, 'muMAP', 1e-4);

% decorrelate the dictionary in every iteration if a coherence is given
decorrelate = isfield(DL.param, 'coherence');
if decorrelate
    mu = DL.param.coherence;
end

% show the dictionary every show_iter iterations, if requested
show_dictionary = isfield(DL.param, 'show_dict');
show_iter = param_or_default(DL.param, 'show_dict', 0);

% Temporary patch: in this version of the software Problem.b1 stores the
% signal that ultimately has to be represented (for example the whole
% image), but during dictionary learning we want the sparse representation
% of the training set, so b1 is replaced by the training signals. Problem
% is a local copy and is not returned, so it does not need to be restored.
if isfield(Problem, 'b1')
    Problem.b1 = sig;
end

% NOTE(review): presumably removed so that SMALL_solve does not run a
% reconstruction step during learning -- confirm against SMALL_solve.
if isfield(Problem, 'reconstruct')
    Problem = rmfield(Problem, 'reconstruct');
end

solver.profile = 0;
updateType = lower(typeUpdate);

% main loop
for i = 1:iternum
    % sparse coding with the current dictionary
    Problem.A = dico;
    solver = SMALL_solve(Problem, solver);

    % dictionary update
    switch updateType
        case 'mm_cn'
            [dico, solver.solution] = ...
                dict_update_REG_cn(dico, sig, solver.solution, maxIT, epsD, cvset);
        case 'mm_fn'
            [dico, solver.solution] = ...
                dict_update_REG_fn(dico, sig, solver.solution, maxIT, epsD, cvset);
        case 'mod_cn'
            [dico, solver.solution] = dict_update_MOD_cn(sig, solver.solution, cvset);
        case 'map_cn'
            [dico, solver.solution] = ...
                dict_update_MAP_cn(dico, sig, solver.solution, muMAP, maxIT, epsD, cvset);
        case 'ksvd_cn'
            [dico, solver.solution] = dict_update_KSVD_cn(dico, sig, solver.solution);
        otherwise
            error('Dictionary update is not defined');
    end

    % For the MMbox solver, use the previous solution as the initial guess
    % for the next sparse coding step (iterative soft thresholding).
    if strcmpi(solver.toolbox, 'MMbox')
        solver.param.initcoeff = solver.solution;
    end

    % Optional decorrelation of atoms - this is from Boris Mailhe and
    % still needs testing with Mehrdad's updates.
    if decorrelate
        dico = dico_decorr(dico, mu, solver.solution);
    end

    if show_dictionary && mod(i, show_iter) == 0
        % NOTE(review): assumes atoms are 8x8 image patches -- confirm
        nSide = round(sqrt(size(dico, 2)));
        dictimg = SMALL_showdict(dico, [8 8], nSide, nSide, 'lines', 'highcontrast');
        figure(2); imagesc(dictimg); colormap(gray); axis off; axis image;
        pause(0.02);
    end
end

DL.D = dico;
end


function value = param_or_default(param, name, default)
% Return param.(name) if that field exists, otherwise DEFAULT.
if isfield(param, name)
    value = param.(name);
else
    value = default;
end
end


function Y = colnorms_squared(X)
% Squared Euclidean norm of each column of X, returned as a row vector.
% Computed in blocks of columns to conserve memory.
Y = zeros(1, size(X, 2));
blocksize = 2000;
for i = 1:blocksize:size(X, 2)
    blockids = i : min(i + blocksize - 1, size(X, 2));
    Y(blockids) = sum(X(:, blockids).^2);
end
end