view DL/Majorization Minimization DL/wrapper_mm_DL.m @ 211:0c7c20f3246c luisf_dev

Added comments to ~/DL/Majorization Minimization DL/wrapper_mm_DL.m and ~/DL/Majorization Minimization DL/wrapper_mm_solver.m files.
author Aris Gretsistas <aris.gretsistas@elec.qmul.ac.uk>
date Wed, 21 Mar 2012 17:48:39 +0000
parents b14209313ba4
children b9b4dc87f1aa
line wrap: on
line source
function DL = wrapper_mm_DL(Problem, DL)
%% SMALL wrapper for Majorization Minimization Dictionary Learning Algorithm
%
%   DL = wrapper_mm_DL(Problem, DL) alternates sparse coding of the
%   training set (via SMALL_solve) with a dictionary update step, and
%   returns the DL structure with the learned dictionary stored in DL.D.
%
%   In the Problem structure, field b holding the training set (one
%   training signal per column) needs to be defined.
%
%   In the DL structure, the name of the dictionary update method (DL.name)
%   and the parameters of the dictionary learning technique (DL.param)
%   need to be present. For the original version of the MM algorithm the
%   update method should be:
%       -   'mm_cn'   - Regularized DL with column norm constraint
%       -   'mm_fn'   - Regularized DL with Frobenius norm constraint
%   Alternatively, for comparison purposes, the following dictionary update
%   methods (which do not represent the optimised version of the algorithm)
%   can be used:
%       -   'mod_cn'  - Method of Optimal Directions
%       -   'map_cn'  - Maximum a Posteriori dictionary update
%       -   'ksvd_cn' - K-SVD update
%   The method name is matched case-insensitively.
%
%   Fields read from DL.param:
%       solver         - solver structure passed to SMALL_solve (required)
%       initdict       - initial dictionary: either a matrix with one atom
%                        per column (same number of rows as Problem.b), or
%                        a vector of whole numbers used as column indices
%                        into Problem.b
%       dictsize       - number of atoms; supersedes the size implied by
%                        initdict. At least one of initdict / dictsize must
%                        be given. Without initdict, atoms are initialized
%                        with randomly chosen non-zero training signals.
%       iternum        - number of outer iterations (default 40)
%       iterDictUpdate - maximum number of iterations of the MM dictionary
%                        update (default 1000)
%       epsDictUpdate  - stopping criterion of the MM dictionary update
%                        (default 1e-7)
%       cvset          - dictionary constraint: 0 = non-convex ||d|| = 1,
%                        1 = convex ||d|| <= 1 (default 0)
%       muMAP          - parameter passed to dict_update_MAP_cn when
%                        DL.name is 'map_cn' (default 1e-4)
%       coherence      - if present, dico_decorr is applied after every
%                        iteration with this value as its mu argument
%       show_dict      - if present, the dictionary is displayed every
%                        show_dict iterations
%
%   Errors:
%       'wrapper_mm_DL:noDictSize'       - neither initdict nor dictsize set
%       'wrapper_mm_DL:notEnoughSignals' - fewer non-zero training signals
%                                          than atoms for random init
%       'Invalid initial dictionary'     - initdict incompatible with the
%                                          training set or dictsize
%       'Number of training signals is smaller than number of atoms to train'
%       'Dictionary update is not defined' - unknown DL.name
%
%   -   MM-DL - Yaghoobi, M.; Blumensath, T.; Davies M.; , "Dictionary
%   Learning for Sparse Approximation with Majorization Method," IEEE
%   Transactions on Signal Processing, vol.57, no.6, pp.2178-2191, 2009.

%   Centre for Digital Music, Queen Mary, University of London.
%   This file copyright 2011 Ivan Damnjanovic.
%
%   This program is free software; you can redistribute it and/or
%   modify it under the terms of the GNU General Public License as
%   published by the Free Software Foundation; either version 2 of the
%   License, or (at your option) any later version.  See the file
%   COPYING included with this distribution for more information.
%%

% solver used for the sparse representation step
solver = DL.param.solver;

% type of dictionary update (Mehrdad Yaghoobi implementations:
% 'MM_cn', 'MM_fn', 'MOD_cn', 'MAP_cn', 'KSVD_cn'; case-insensitive)
typeUpdate = DL.name;

% training set, one signal per column
sig = Problem.b;

%% determine dictionary size

hasInitDict = isfield(DL.param, 'initdict');
initIsIndex = false;
dictsize = [];
if hasInitDict
    initdict = DL.param.initdict;
    % a vector of whole numbers is interpreted as column indices into sig
    initIsIndex = any(size(initdict) == 1) && all(iswhole(initdict(:)));
    if initIsIndex
        dictsize = length(initdict);
    else
        dictsize = size(initdict, 2);
    end
end

if isfield(DL.param, 'dictsize')    % this supersedes the size determined by initdict
    dictsize = DL.param.dictsize;
end

if isempty(dictsize)
    error('wrapper_mm_DL:noDictSize', ...
        'Dictionary size is undefined: set DL.param.dictsize or DL.param.initdict');
end

if (size(sig,2) < dictsize)
    error('Number of training signals is smaller than number of atoms to train');
end

%% initialize the dictionary

if hasInitDict
    if initIsIndex
        % dictsize may have been enlarged by DL.param.dictsize
        if numel(initdict) < dictsize
            error('Invalid initial dictionary');
        end
        dico = sig(:, initdict(1:dictsize));
    else
        if size(initdict, 1) ~= size(sig, 1) || size(initdict, 2) < dictsize
            error('Invalid initial dictionary');
        end
        dico = initdict(:, 1:dictsize);
    end
else
    % use distinct random training signals, skipping (near-)zero ones
    data_ids = find(colnorms_squared(sig) > 1e-6);
    if length(data_ids) < dictsize
        error('wrapper_mm_DL:notEnoughSignals', ...
            'Only %d non-zero training signals available to initialize %d atoms', ...
            length(data_ids), dictsize);
    end
    perm = randperm(length(data_ids));
    dico = sig(:, data_ids(perm(1:dictsize)));
end

%% optional parameters

% number of outer iterations (default is 40)
iternum = 40;
if isfield(DL.param, 'iternum')
    iternum = DL.param.iternum;
end

% maximum number of iterations of the MM dictionary update (default is 1000)
maxIT = 1000;
if isfield(DL.param, 'iterDictUpdate')
    maxIT = DL.param.iterDictUpdate;
end

% stopping criterion for the MM dictionary update (default is 1e-7)
epsD = 1e-7;
if isfield(DL.param, 'epsDictUpdate')
    epsD = DL.param.epsDictUpdate;
end

% dictionary constraint - 0 = non-convex ||d|| = 1, 1 = convex ||d|| <= 1
% (default is 0)
cvset = 0;
if isfield(DL.param, 'cvset')
    cvset = DL.param.cvset;
end

% parameter of the MAP update (only used by 'map_cn'; default is 1e-4)
muMAP = 1e-4;
if isfield(DL.param, 'muMAP')
    muMAP = DL.param.muMAP;
end

% decorrelate atoms in every iteration if a target coherence is given
decorrelate = isfield(DL.param, 'coherence');
if decorrelate
    mu = DL.param.coherence;
end

% show the dictionary every show_iter iterations
show_dictionary = isfield(DL.param, 'show_dict');
show_iter = 0;
if show_dictionary
    show_iter = DL.param.show_dict;
end

%% prepare the sparse coding problem

% This is a small patch that needs to be resolved: in dictionary learning
% we want a sparse representation of the training set, while in this
% version of the software Problem.b1 stores the signal that finally needs
% to be represented (for example the whole image). Problem is a local copy
% of the caller's structure, so overwriting b1 here does not affect it.
if isfield(Problem, 'b1')
    Problem.b1 = sig;
end
% NOTE(review): presumably removed so SMALL_solve skips the reconstruction
% step during learning - confirm against SMALL_solve
if isfield(Problem, 'reconstruct')
    Problem = rmfield(Problem, 'reconstruct');
end
solver.profile = 0;

%% main loop

for iter = 1:iternum
    % sparse coding of the training set with the current dictionary
    Problem.A = dico;
    solver = SMALL_solve(Problem, solver);

    % dictionary update
    switch lower(typeUpdate)
        case 'mm_cn'
            [dico, solver.solution] = ...
                dict_update_REG_cn(dico, sig, solver.solution, maxIT, epsD, cvset);
        case 'mm_fn'
            [dico, solver.solution] = ...
                dict_update_REG_fn(dico, sig, solver.solution, maxIT, epsD, cvset);
        case 'mod_cn'
            [dico, solver.solution] = dict_update_MOD_cn(sig, solver.solution, cvset);
        case 'map_cn'
            [dico, solver.solution] = ...
                dict_update_MAP_cn(dico, sig, solver.solution, muMAP, maxIT, epsD, cvset);
        case 'ksvd_cn'
            [dico, solver.solution] = dict_update_KSVD_cn(dico, sig, solver.solution);
        otherwise
            error('Dictionary update is not defined');
    end

    % Set the previous solution as the initial guess for the next
    % iteration of iterative soft thresholding (MMbox solver only)
    if strcmpi(solver.toolbox, 'MMbox')
        solver.param.initcoeff = solver.solution;
    end

    % Optional decorrelation of atoms - this is from Boris Mailhe and
    % we need to test how it performs with Mehrdad's updates
    if decorrelate
        dico = dico_decorr(dico, mu, solver.solution);
    end

    % NOTE(review): the [8 8] argument looks like a hard-coded patch size
    % for display - confirm it matches the training signals
    if show_dictionary && mod(iter, show_iter) == 0
        nSide = round(sqrt(size(dico, 2)));
        dictimg = SMALL_showdict(dico, [8 8], nSide, nSide, 'lines', 'highcontrast');
        figure(2); imagesc(dictimg); colormap(gray); axis off; axis image;
        pause(0.02);
    end
end

DL.D = dico;

end

function Y = colnorms_squared(X)
% COLNORMS_SQUARED  Squared Euclidean norm of every column of a real matrix.
%
%   Y = colnorms_squared(X) returns a 1-by-size(X,2) row vector with
%   Y(j) = sum(X(:,j).^2). The columns are processed in blocks of 2000
%   to limit the size of the temporary squared array.
%
%   The sum is taken explicitly along dimension 1: without it, a
%   single-row X would be summed along its row, and that one total
%   would be assigned to every column.

Y = zeros(1, size(X, 2));
blocksize = 2000;
for i = 1:blocksize:size(X, 2)
    blockids = i : min(i + blocksize - 1, size(X, 2));
    Y(blockids) = sum(X(:, blockids).^2, 1);
end

end