Mercurial > hg > smallbox
view DL/Majorization Minimization DL/wrapper_mm_DL.m @ 216:a986ee86651e luisf_dev
Calls SMALLboxInit in the beginning of both solve and learn, in order not to lose the SMALL_path variable.
author | luisf <luis.figueira@eecs.qmul.ac.uk> |
---|---|
date | Thu, 22 Mar 2012 11:41:04 +0000 |
parents | b9b4dc87f1aa |
children | 4337e28183f1 |
line wrap: on
line source
function DL = wrapper_mm_DL(Problem, DL)
%% SMALL wrapper for Majorization Minimization Dictionary Learning Algorithm
%
%   Function gets as input Problem and Dictionary Learning (DL) structures
%   and outputs the learned Dictionary (returned in the field DL.D).
%
%   In the Problem structure, field b with the training set needs to be
%   defined (one training signal per column).
%
%   In the DL structure, field name with the name of the dictionary update
%   method needs to be present (matched case-insensitively). For the
%   original version of the MM algorithm the update method should be:
%       - 'mm_cn'   - Regularized DL with column norm constraint
%       - 'mm_fn'   - Regularized DL with Frobenius norm constraint
%   Alternatively, for comparison purposes the following dictionary update
%   methods (which do not represent the optimised version of the algorithm)
%   can be used:
%       - 'mod_cn'  - Method of Optimized Direction
%       - 'map_cn'  - Maximum a Posteriori dictionary update
%       - 'ksvd_cn' - KSVD update
%
%   DL.param.solver structure is also required. For the original version of
%   the MM algorithm, DL.param.solver.toolbox should be 'MMbox'. The
%   parameters in DL.param.solver.param should be set accordingly. Type
%   help wrapper_mm_solver for more details.
%
%   Fields of DL.param read by this function:
%       - initdict       - initial dictionary (matrix), or a vector of
%                          whole numbers indexing columns of Problem.b
%       - dictsize       - number of atoms (supersedes the size implied by
%                          initdict); one of initdict/dictsize is required
%       - iternum        - number of learning iterations (default 40)
%       - iterDictUpdate - iteration limit passed to the dictionary update
%                          routines (default 1000)
%       - epsDictUpdate  - stopping criterion passed to the dictionary
%                          update routines (default 1e-7)
%       - cvset          - dictionary constraint: 0 = non convex ||d|| = 1,
%                          1 = convex ||d|| <= 1 (default 0)
%       - coherence      - if present, dico_decorr is applied in every
%                          iteration with this value
%       - show_dict      - if present, the dictionary is displayed every
%                          show_dict iterations
%       - muMAP          - parameter of the 'map_cn' update (default 1e-4)
%
%   - MM-DL - Yaghoobi, M.; Blumensath, T,; Davies M.; , "Dictionary
%   Learning for Sparse Approximation with Majorization Method," IEEE
%   Transactions on Signal Processing, vol.57, no.6, pp.2178-2191, 2009.

%   Centre for Digital Music, Queen Mary, University of London.
%   This file copyright 2011 Ivan Damnjanovic.
%
%   This program is free software; you can redistribute it and/or
%   modify it under the terms of the GNU General Public License as
%   published by the Free Software Foundation; either version 2 of the
%   License, or (at your option) any later version.  See the file
%   COPYING included with this distribution for more information.
%%

% determine which solver is used for sparse representation %

solver = DL.param.solver;

% determine which type of update to use
% (Mehrdad Yaghoobi implementations: 'MM_cn', 'MM_fn', 'MOD_cn',
% 'MAP_cn', 'KSVD_cn')

typeUpdate = DL.name;

sig = Problem.b;

% determine dictionary size %

hasInitDict = isfield(DL.param, 'initdict');

% initdict may be a vector of whole numbers, in which case it lists the
% columns of the training set that form the initial dictionary
initDictIsIndexList = hasInitDict && ...
    any(size(DL.param.initdict) == 1) && all(iswhole(DL.param.initdict(:)));

dictsize = [];
if hasInitDict
    if initDictIsIndexList
        dictsize = length(DL.param.initdict);
    else
        dictsize = size(DL.param.initdict, 2);
    end
end

if isfield(DL.param, 'dictsize')    % this supersedes the size determined by initdict
    dictsize = DL.param.dictsize;
end

% without this check the comparison below would fail with an
% undefined/empty value instead of a meaningful message
if isempty(dictsize)
    error('Dictionary size is undefined: set DL.param.dictsize or DL.param.initdict');
end

if (size(sig, 2) < dictsize)
    error('Number of training signals is smaller than number of atoms to train');
end

% initialize the dictionary %

if hasInitDict
    if initDictIsIndexList
        dico = sig(:, DL.param.initdict(1:dictsize));
    else
        if (size(DL.param.initdict, 1) ~= size(sig, 1) || size(DL.param.initdict, 2) < dictsize)
            error('Invalid initial dictionary');
        end
        dico = DL.param.initdict(:, 1:dictsize);
    end
else
    % ensure no zero data elements are chosen
    data_ids = find(colnorms_squared(sig) > 1e-6);
    if (length(data_ids) < dictsize)
        error('Number of non-zero training signals is smaller than number of atoms to train');
    end
    perm = randperm(length(data_ids));
    dico = sig(:, data_ids(perm(1:dictsize)));
end

% number of learning iterations (default is 40) %

if isfield(DL.param, 'iternum')
    iternum = DL.param.iternum;
else
    iternum = 40;
end

% iteration limit handed to the dictionary update routines
% (default is 1000) %

if isfield(DL.param, 'iterDictUpdate')
    maxIT = DL.param.iterDictUpdate;
else
    maxIT = 1000;
end

% Stopping criterion for MM dictionary update (default = 1e-7)

if isfield(DL.param, 'epsDictUpdate')
    epsD = DL.param.epsDictUpdate;
else
    epsD = 1e-7;
end

% Dictionary constraint - 0 = Non convex ||d|| = 1, 1 = Convex ||d||<=1
% (default cvset is 0) %

if isfield(DL.param, 'cvset')
    cvset = DL.param.cvset;
else
    cvset = 0;
end

% determine if we should do decorrelation in every iteration %

if isfield(DL.param, 'coherence')
    decorrelate = 1;
    mu = DL.param.coherence;
else
    decorrelate = 0;
end

% show dictionary every specified number of iterations

if isfield(DL.param, 'show_dict')
    show_dictionary = 1;
    show_iter = DL.param.show_dict;
else
    show_dictionary = 0;
    show_iter = 0;
end

% This is a small patch that needs to be resolved: in dictionary learning
% we want a sparse representation of the training set, and in Problem.b1 in
% this version of the software we store the signal that needs to be
% represented (for example the whole image)

if isfield(Problem, 'b1')
    tmpTraining = Problem.b1;
    Problem.b1 = sig;
end

if isfield(Problem, 'reconstruct')
    Problem = rmfield(Problem, 'reconstruct');
end

solver.profile = 0;

% main loop %

for i = 1:iternum

    % sparse representation step with the current dictionary
    Problem.A = dico;
    solver = SMALL_solve(Problem, solver);

    % dictionary update step
    switch lower(typeUpdate)
        case 'mm_cn'
            [dico, solver.solution] = ...
                dict_update_REG_cn(dico, sig, solver.solution, maxIT, epsD, cvset);
        case 'mm_fn'
            [dico, solver.solution] = ...
                dict_update_REG_fn(dico, sig, solver.solution, maxIT, epsD, cvset);
        case 'mod_cn'
            [dico, solver.solution] = dict_update_MOD_cn(sig, solver.solution, cvset);
        case 'map_cn'
            if isfield(DL.param, 'muMAP')
                muMAP = DL.param.muMAP;
            else
                muMAP = 1e-4;
            end
            [dico, solver.solution] = ...
                dict_update_MAP_cn(dico, sig, solver.solution, muMAP, maxIT, epsD, cvset);
        case 'ksvd_cn'
            [dico, solver.solution] = dict_update_KSVD_cn(dico, sig, solver.solution);
        otherwise
            error('Dictionary update is not defined');
    end

    % Set previous solution as the best initial guess
    % for the next iteration of iterative soft thresholding

    if (strcmpi(solver.toolbox, 'MMbox'))
        solver.param.initcoeff = solver.solution;
    end

    % Optional decorrelation of atoms - this is from Boris Mailhe and
    % we need to test how it performs with Mehrdad's updates

    if (decorrelate)
        dico = dico_decorr(dico, mu, solver.solution);
    end

    if ((show_dictionary) && (mod(i, show_iter) == 0))
        dictimg = SMALL_showdict(dico, [8 8], ...
            round(sqrt(size(dico, 2))), round(sqrt(size(dico, 2))), 'lines', 'highcontrast');
        figure(2); imagesc(dictimg); colormap(gray); axis off; axis image;
        pause(0.02);
    end
end

% put back the signal that was temporarily replaced by the training set
if isfield(Problem, 'b1')
    Problem.b1 = tmpTraining;
end

DL.D = dico;

end

function Y = colnorms_squared(X)
% Squared l2 norm of every column of X, returned as a 1-by-size(X,2) row.

% compute in blocks to conserve memory
Y = zeros(1, size(X, 2));
blocksize = 2000;
for i = 1:blocksize:size(X, 2)
    blockids = i : min(i + blocksize - 1, size(X, 2));
    % sum explicitly along dimension 1: without the dimension argument a
    % single-row X would be summed along the row, giving one scalar that
    % is then assigned to every column of the block
    Y(blockids) = sum(X(:, blockids).^2, 1);
end

end