function DL = wrapper_mm_DL(Problem, DL)
%WRAPPER_MM_DL  Alternate sparse coding and a dictionary update step.
%
%   DL = wrapper_mm_DL(Problem, DL) initialises a dictionary from the
%   training signals in Problem.b (or from DL.param.initdict), then repeats
%   DL.param.iternum times: sparse-code the training set with SMALL_solve
%   and update the dictionary with the rule selected by DL.name.
%   The learned dictionary is returned in DL.D.
%
%   Inputs
%     Problem.b        training signals, one signal per column
%     Problem.b1       (optional) replaced by Problem.b while training
%     Problem.reconstruct (optional) removed before calling SMALL_solve
%     DL.name          dictionary update, case-insensitive, one of
%                      'MM_cn', 'MM_fn', 'MOD_cn', 'MAP_cn', 'KSVD_cn'
%     DL.param.solver  solver structure handed to SMALL_solve
%     DL.param.initdict       (optional) initial dictionary matrix, or a
%                             vector of column indices into Problem.b
%     DL.param.dictsize       (optional) number of atoms; overrides the
%                             size implied by initdict
%     DL.param.iternum        (optional) outer iterations, default 40
%     DL.param.iterDictUpdate (optional) max iterations of the dictionary
%                             update, default 1000
%     DL.param.epsDictUpdate  (optional) stopping tolerance of the
%                             dictionary update, default 1e-7
%     DL.param.cvset          (optional) 0 = non-convex ||d|| = 1,
%                             1 = convex ||d|| <= 1, default 0
%     DL.param.coherence      (optional) if present, dico_decorr is called
%                             with this value after every update
%     DL.param.show_dict      (optional) if present, display the dictionary
%                             every show_dict iterations
%     DL.param.muMAP          (optional) parameter of 'MAP_cn', default 1e-4
%
%   Errors are raised when the dictionary size cannot be determined, when
%   there are fewer training signals than atoms, when the initial
%   dictionary is invalid, or when DL.name is not a known update.

% solver used for the sparse representation step
solver = DL.param.solver;

% type of dictionary update to use
% (Mehrdad Yaghoobi implementations: 'MM_cn', 'MM_fn', 'MOD_cn',
%  'MAP_cn', 'KSVD_cn')
typeUpdate = DL.name;

sig = Problem.b;

hasInitDict = isfield(DL.param, 'initdict');

% initdict may be a vector of whole numbers, in which case it is treated as
% a list of column indices into the training set rather than a dictionary
initIsIndexList = hasInitDict && ...
    any(size(DL.param.initdict) == 1) && all(iswhole(DL.param.initdict(:)));

% determine dictionary size %

if hasInitDict
    if initIsIndexList
        dictsize = length(DL.param.initdict);
    else
        dictsize = size(DL.param.initdict, 2);
    end
end

if isfield(DL.param, 'dictsize')    % this supersedes the size determined by initdict
    dictsize = DL.param.dictsize;
end

% Without this guard a missing size surfaces as an "undefined variable" error.
if ~exist('dictsize', 'var')
    error('Dictionary size is undefined: specify DL.param.dictsize or DL.param.initdict');
end

if (size(sig, 2) < dictsize)
    error('Number of training signals is smaller than number of atoms to train');
end

% initialize the dictionary %

if hasInitDict
    if initIsIndexList
        dico = sig(:, DL.param.initdict(1:dictsize));
    else
        % NOTE(review): the tail of this condition, the error message and the
        % assignment below were lost in the archived copy of this file and have
        % been reconstructed - confirm against the repository version.
        if (size(DL.param.initdict, 1) ~= size(sig, 1) || size(DL.param.initdict, 2) < dictsize)
            error('Invalid initial dictionary');
        end
        dico = DL.param.initdict(:, 1:dictsize);
    end
else
    data_ids = find(colnorms_squared(sig) > 1e-6);   % ensure no zero data elements are chosen
    if (length(data_ids) < dictsize)
        error('Number of non-zero training signals is smaller than number of atoms to train');
    end
    perm = randperm(length(data_ids));
    dico = sig(:, data_ids(perm(1:dictsize)));
end

% number of outer (sparse coding + dictionary update) iterations (default is 40) %

if isfield(DL.param, 'iternum')
    iternum = DL.param.iternum;
else
    iternum = 40;
end

% maximum number of iterations of the dictionary update (default is 1000) %

if isfield(DL.param, 'iterDictUpdate')
    maxIT = DL.param.iterDictUpdate;
else
    maxIT = 1000;
end

% stopping criterion for MM dictionary update (default = 1e-7)

if isfield(DL.param, 'epsDictUpdate')
    epsD = DL.param.epsDictUpdate;
else
    epsD = 1e-7;
end

% dictionary constraint - 0 = non-convex ||d|| = 1, 1 = convex ||d|| <= 1
% (default cvset is 0) %

if isfield(DL.param, 'cvset')
    cvset = DL.param.cvset;
else
    cvset = 0;
end

% parameter of the 'MAP_cn' update (default = 1e-4); loop invariant

if isfield(DL.param, 'muMAP')
    muMAP = DL.param.muMAP;
else
    muMAP = 1e-4;
end

% determine if we should do decorrelation in every iteration %

if isfield(DL.param, 'coherence')
    decorrelate = 1;
    mu = DL.param.coherence;
else
    decorrelate = 0;
end

% show dictionary every specified number of iterations

if isfield(DL.param, 'show_dict')
    show_dictionary = 1;
    show_iter = DL.param.show_dict;
else
    show_dictionary = 0;
    show_iter = 0;
end

% This is a small patch that needs to be resolved: in dictionary learning we
% want the sparse representation of the training set, while in this version
% of the software Problem.b1 stores the signal that needs to be represented
% (for example the whole image).
if isfield(Problem, 'b1')
    tmpTraining = Problem.b1;
    Problem.b1 = sig;
end
if isfield(Problem, 'reconstruct')
    Problem = rmfield(Problem, 'reconstruct');
end
solver.profile = 0;

% main loop %

for i = 1:iternum
    Problem.A = dico;

    % sparse coding with the current dictionary
    solver = SMALL_solve(Problem, solver);

    % dictionary update
    switch lower(typeUpdate)
        case 'mm_cn'
            [dico, solver.solution] = ...
                dict_update_REG_cn(dico, sig, solver.solution, maxIT, epsD, cvset);
        case 'mm_fn'
            [dico, solver.solution] = ...
                dict_update_REG_fn(dico, sig, solver.solution, maxIT, epsD, cvset);
        case 'mod_cn'
            [dico, solver.solution] = dict_update_MOD_cn(sig, solver.solution, cvset);
        case 'map_cn'
            [dico, solver.solution] = ...
                dict_update_MAP_cn(dico, sig, solver.solution, muMAP, maxIT, epsD, cvset);
        case 'ksvd_cn'
            [dico, solver.solution] = dict_update_KSVD_cn(dico, sig, solver.solution);
        otherwise
            error('Dictionary update is not defined');
    end

    % Set previous solution as the best initial guess
    % for the next iteration of iterative soft thresholding

    if (strcmpi(solver.toolbox, 'MMbox'))
        solver.param.initcoeff = solver.solution;
    end

    % Optional decorrelation of atoms - this is from Boris Mailhe and
    % we need to test how it performs with Mehrdad's updates

    if (decorrelate)
        dico = dico_decorr(dico, mu, solver.solution);
    end

    % periodic display; the [8 8] block size passed to SMALL_showdict is fixed
    if ((show_dictionary) && (mod(i, show_iter) == 0))
        dictimg = SMALL_showdict(dico, [8 8], ...
            round(sqrt(size(dico, 2))), round(sqrt(size(dico, 2))), 'lines', 'highcontrast');
        figure(2); imagesc(dictimg); colormap(gray); axis off; axis image;
        pause(0.02);
    end
end

% Restore the original Problem.b1. Problem is not an output of this
% function, so this only affects the local copy.
if isfield(Problem, 'b1')
    Problem.b1 = tmpTraining;
end

DL.D = dico;

end

function Y = colnorms_squared(X)
%COLNORMS_SQUARED  Squared l2 norm of every column of X, as a row vector.

% compute in blocks to conserve memory
Y = zeros(1, size(X, 2));
blocksize = 2000;
for i = 1:blocksize:size(X, 2)
    blockids = i : min(i + blocksize - 1, size(X, 2));
    % sum along dimension 1 explicitly so a single-row X is handled per column
    Y(blockids) = sum(X(:, blockids).^2, 1);
end

end