function DL = wrapper_mm_DL(Problem, DL)
%% SMALL wrapper for Majorization Minimization Dictionary Learning Algorithm
%
%   Function gets as input Problem and Dictionary Learning (DL) structures
%   and outputs the learned Dictionary.
%
%   In the Problem structure, field b with the training set needs to be
%   defined (training signals are taken column-wise: sig(:,k)).
%
%   In DL, fields with the name of the dictionary update method and the
%   parameters for the particular dictionary learning technique need to be
%   present. For the original version of the MM algorithm the update
%   method (DL.name, matched case-insensitively) should be:
%       - 'mm_cn'   - Regularized DL with column norm constraint
%       - 'mm_fn'   - Regularized DL with Frobenius norm constraint
%   Alternatively, for comparison purposes the following dictionary update
%   methods (which do not represent the optimised version of the
%   algorithm) can be used:
%       - 'mod_cn'  - Method of Optimized Direction
%       - 'map_cn'  - Maximum a Posteriori Dictionary update
%       - 'ksvd_cn' - KSVD update
%
%   Fields of DL.param read by this wrapper:
%       solver         - solver structure passed to SMALL_solve (required)
%       initdict       - initial dictionary matrix, or a vector of whole
%                        numbers used as column indices into Problem.b
%       dictsize       - number of atoms; supersedes the size implied by
%                        initdict (one of initdict/dictsize is required)
%       iternum        - number of learning iterations (default 40)
%       iterDictUpdate - max iterations of the dictionary update
%                        (default 1000)
%       epsDictUpdate  - stopping criterion of the dictionary update
%                        (default 1e-7)
%       cvset          - dictionary constraint: 0 = non-convex ||d|| = 1,
%                        1 = convex ||d|| <= 1 (default 0)
%       coherence      - if present, atoms are decorrelated in every
%                        iteration with this value passed to dico_decorr
%       show_dict      - if present, the dictionary is displayed every
%                        show_dict iterations
%       muMAP          - step parameter for the 'map_cn' update
%                        (default 1e-4)
%
%   Output: the input DL structure with the learned dictionary in DL.D.
%
%   - MM-DL - Yaghoobi, M.; Blumensath, T.; Davies M.; "Dictionary
%     Learning for Sparse Approximation with Majorization Method," IEEE
%     Transactions on Signal Processing, vol.57, no.6, pp.2178-2191, 2009.

%   Centre for Digital Music, Queen Mary, University of London.
%   This file copyright 2011 Ivan Damnjanovic.
%
%   This program is free software; you can redistribute it and/or
%   modify it under the terms of the GNU General Public License as
%   published by the Free Software Foundation; either version 2 of the
%   License, or (at your option) any later version.  See the file
%   COPYING included with this distribution for more information.
%%

% determine which solver is used for sparse representation %

solver = DL.param.solver;

% determine which type of update to use
% (Mehrdad Yaghoobi implementations: 'MM_cn', 'MM_fn', 'MOD_cn',
% 'MAP_cn', 'KSVD_cn')

typeUpdate = DL.name;

sig = Problem.b;

% determine dictionary size %

dictsize = [];

if (isfield(DL.param,'initdict'))
    if (any(size(DL.param.initdict)==1) && all(iswhole(DL.param.initdict(:))))
        % initdict is a vector of whole numbers: indices of training signals
        dictsize = length(DL.param.initdict);
    else
        % initdict is an explicit dictionary matrix: one atom per column
        dictsize = size(DL.param.initdict,2);
    end
end

if (isfield(DL.param,'dictsize'))    % this supersedes the size determined by initdict
    dictsize = DL.param.dictsize;
end

% Without either field the size is unknown; fail with a clear message
% instead of an "undefined variable" error further down.
if (isempty(dictsize))
    error('Dictionary size is not defined: specify DL.param.dictsize or DL.param.initdict');
end

if (size(sig,2) < dictsize)
    error('Number of training signals is smaller than number of atoms to train');
end


% initialize the dictionary %

if (isfield(DL.param,'initdict'))
    if (any(size(DL.param.initdict)==1) && all(iswhole(DL.param.initdict(:))))
        dico = sig(:,DL.param.initdict(1:dictsize));
    else
        if (size(DL.param.initdict,1)~=size(sig,1) || size(DL.param.initdict,2)<dictsize)
            error('Invalid initial dictionary');
        end
        dico = DL.param.initdict(:,1:dictsize);
    end
else
    data_ids = find(colnorms_squared(sig) > 1e-6);   % ensure no zero data elements are chosen
    if (length(data_ids) < dictsize)
        error('Number of non-zero training signals is smaller than number of atoms to train');
    end
    perm = randperm(length(data_ids));
    dico = sig(:,data_ids(perm(1:dictsize)));
end


% number of dictionary learning iterations (default is 40) %

if isfield(DL.param,'iternum')
    iternum = DL.param.iternum;
else
    iternum = 40;
end

% maximum number of iterations of the dictionary update (default is 1000) %

if isfield(DL.param,'iterDictUpdate')
    maxIT = DL.param.iterDictUpdate;
else
    maxIT = 1000;
end

% Stopping criterion for MM dictionary update (default = 1e-7)

if isfield(DL.param,'epsDictUpdate')
    epsD = DL.param.epsDictUpdate;
else
    epsD = 1e-7;
end

% Dictionary constraint - 0 = Non convex ||d|| = 1, 1 = Convex ||d||<=1
% (default cvset is 0) %

if isfield(DL.param,'cvset')
    cvset = DL.param.cvset;
else
    cvset = 0;
end

% step parameter of the MAP dictionary update (default = 1e-4);
% only used when the update type is 'map_cn'

if isfield(DL.param,'muMAP')
    muMAP = DL.param.muMAP;
else
    muMAP = 1e-4;
end

% determine if we should do decorrelation in every iteration %

if isfield(DL.param,'coherence')
    decorrelate = 1;
    mu = DL.param.coherence;
else
    decorrelate = 0;
end

% show dictionary every specified number of iterations

if isfield(DL.param,'show_dict')
    show_dictionary = 1;
    show_iter = DL.param.show_dict;
else
    show_dictionary = 0;
    show_iter = 0;
end

% This is a small patch that needs to be resolved: in dictionary learning
% we want a sparse representation of the training set, and in Problem.b1
% in this version of the software we store the signal that needs to be
% represented (for example the whole image)
if isfield(Problem,'b1')
    tmpTraining = Problem.b1;
    Problem.b1 = sig;
end
if isfield(Problem,'reconstruct')
    Problem = rmfield(Problem, 'reconstruct');
end
solver.profile = 0;

% main loop %

for i = 1:iternum
    Problem.A = dico;

    % sparse coding step with the current dictionary
    solver = SMALL_solve(Problem, solver);

    % dictionary update step
    switch lower(typeUpdate)
        case 'mm_cn'
            [dico, solver.solution] = ...
                dict_update_REG_cn(dico, sig, solver.solution, maxIT, epsD, cvset);
        case 'mm_fn'
            [dico, solver.solution] = ...
                dict_update_REG_fn(dico, sig, solver.solution, maxIT, epsD, cvset);
        case 'mod_cn'
            [dico, solver.solution] = dict_update_MOD_cn(sig, solver.solution, cvset);
        case 'map_cn'
            [dico, solver.solution] = ...
                dict_update_MAP_cn(dico, sig, solver.solution, muMAP, maxIT, epsD, cvset);
        case 'ksvd_cn'
            [dico, solver.solution] = dict_update_KSVD_cn(dico, sig, solver.solution);
        otherwise
            error('Dictionary update is not defined');
    end

    % Set previous solution as the best initial guess
    % for the next iteration of iterative soft thresholding

    if (strcmpi(solver.toolbox, 'MMbox'))
        solver.param.initcoeff = solver.solution;
    end

    % Optional decorrelation of atoms - this is from Boris Mailhe and
    % we need to test how it performs with Mehrdad's updates

    if (decorrelate)
        dico = dico_decorr(dico, mu, solver.solution);
    end

    if ((show_dictionary)&&(mod(i,show_iter)==0))
        dictimg = SMALL_showdict(dico,[8 8],...
            round(sqrt(size(dico,2))),round(sqrt(size(dico,2))),'lines','highcontrast');
        figure(2); imagesc(dictimg);colormap(gray);axis off; axis image;
        pause(0.02);
    end
end

% Put the original Problem.b1 back. Problem is not an output of this
% function, so this only affects the local copy.
if isfield(Problem,'b1')
    Problem.b1 = tmpTraining;
end
DL.D = dico;

end

function Y = colnorms_squared(X)
% Squared Euclidean norm of every column of X, returned as a row vector.

% compute in blocks to conserve memory
Y = zeros(1,size(X,2));
blocksize = 2000;
for i = 1:blocksize:size(X,2)
    blockids = i : min(i+blocksize-1,size(X,2));
    Y(blockids) = sum(X(:,blockids).^2);
end

end