view DL/two-step DL/SMALL_two_step_DL.m @ 152:485747bf39e0 ivand_dev

Two-step dictionary learning: integration of Boris Mailhe's code for dictionary update and dictionary decorrelation
author Ivan Damnjanovic lnx <ivan.damnjanovic@eecs.qmul.ac.uk>
date Thu, 28 Jul 2011 15:49:32 +0100
parents
children af307f247ac7
line wrap: on
line source
function DL=SMALL_two_step_DL(Problem, DL)
%SMALL_TWO_STEP_DL  Two-step dictionary learning (sparse coding + update).
%
%   DL = SMALL_two_step_DL(Problem, DL) runs DL.param.iternum iterations
%   that alternate between
%     1. sparse coding of the training set Problem.b in the current
%        dictionary, using the solver structure DL.param.solver (SMALL_solve);
%     2. a dictionary update with dico_update, using the update type named
%        by DL.name, optionally followed by dico_decorr when
%        DL.param.coherence is present.
%   The learned dictionary is returned in DL.D.
%
%   Inputs
%     Problem.b  - training signals, one per column.
%     DL.name    - update type passed to dico_update ('KSVD', 'MOD', 'ols'
%                  or 'mailhe').
%     DL.param   - structure with the fields
%       solver       - solver structure passed to SMALL_solve (required).
%       dictsize     - number of atoms; supersedes the size implied by
%                      initdict. Required when initdict is absent.
%       initdict     - initial dictionary (a matrix whose columns are
%                      atoms), or a vector of whole numbers giving the
%                      indices of training signals to use as initial atoms.
%                      If absent, dictsize training signals whose squared
%                      norm exceeds 1e-6 are picked at random.
%       flow         - 'sequential' (default) or 'parallel'; passed to
%                      dico_update (not used with MOD).
%       learningRate - step parameter for the 'ols'/'mailhe' updates
%                      (default 0.1).
%       iternum      - number of iterations (default 40).
%       coherence    - if present, value passed as mu to dico_decorr after
%                      every update (presumably the target coherence).
%       show_dict    - if present, the dictionary is displayed every
%                      show_dict iterations.
%
%   Output
%     DL.D       - learned dictionary (columns are atoms).

% determine which solver is used for sparse representation %

solver = DL.param.solver;

% determine which type of update to use ('KSVD', 'MOD', 'ols' or 'mailhe') %

typeUpdate = DL.name;

sig = Problem.b;

% An initdict that is a vector of whole numbers is a list of column
% indices into the training set, not a dictionary matrix.
hasInitDict = isfield(DL.param,'initdict');
initIsIndex = hasInitDict && any(size(DL.param.initdict)==1) && ...
    all(iswhole(DL.param.initdict(:)));

% determine dictionary size (dictsize supersedes the size implied by initdict) %

if (isfield(DL.param,'dictsize'))
  dictsize = DL.param.dictsize;
elseif (initIsIndex)
  dictsize = length(DL.param.initdict);
elseif (hasInitDict)
  dictsize = size(DL.param.initdict,2);
else
  error('SMALL_two_step_DL:noDictSize', ...
      'DL.param must define either ''dictsize'' or ''initdict''.');
end

if (size(sig,2) < dictsize)
  error('Number of training signals is smaller than number of atoms to train');
end


% initialize the dictionary %

if (initIsIndex)
  if (length(DL.param.initdict) < dictsize)
    error('Invalid initial dictionary');
  end
  dico = sig(:,DL.param.initdict(1:dictsize));
elseif (hasInitDict)
  if (size(DL.param.initdict,1)~=size(sig,1) || size(DL.param.initdict,2)<dictsize)
    error('Invalid initial dictionary');
  end
  dico = DL.param.initdict(:,1:dictsize);
else
  data_ids = find(colnorms_squared(sig) > 1e-6);   % ensure no zero data elements are chosen
  if (length(data_ids) < dictsize)
    error('Number of non-zero training signals is smaller than number of atoms to train');
  end
  perm = randperm(length(data_ids));
  dico = sig(:,data_ids(perm(1:dictsize)));
end

% Normalize the initial atoms to unit l2 norm, as ksvd.m does with its
% initial dictionary. Zero atoms are left as they are (avoids 0/0).
atomNorms = sqrt(sum(dico.^2, 1));
atomNorms(atomNorms == 0) = 1;
dico = dico ./ repmat(atomNorms, size(dico,1), 1);

% flow: 'sequential' or 'parallel'. If sequential, the residual is updated
% after each atom update. If parallel, the residual is only updated once
% the whole dictionary has been computed. Sequential works better, there
% may be no need to implement parallel. Not used with MOD.

if isfield(DL.param,'flow')
    flow = DL.param.flow;
else
    flow = 'sequential';
end

% learningRate. If the type is 'ols', it is the descent step of
% the gradient (typical choice: 0.1). If the type is 'mailhe', the
% descent step is the optimal step*rho (typical choice: 1, although 2
% or 3 seems to work better). Not used for MOD and KSVD.

if isfield(DL.param,'learningRate')
    learningRate = DL.param.learningRate;
else
    learningRate = 0.1;
end

% number of iterations (default is 40) %

if isfield(DL.param,'iternum')
    iternum = DL.param.iternum;
else
    iternum = 40;
end

% determine if we should do decorrelation in every iteration %

decorrelate = isfield(DL.param,'coherence');
if (decorrelate)
    mu = DL.param.coherence;
end

% show dictionary every specified number of iterations

if (isfield(DL.param,'show_dict'))
    show_dictionary = 1;
    show_iter = DL.param.show_dict;
else
    show_dictionary = 0;
    show_iter = 0;
end

% In this version of SMALL, Problem.b1 holds the signal to be represented
% (for example the whole image); during learning we want the sparse
% representation of the training set, so point b1 at it. Problem is a
% local copy (MATLAB structs are passed by value and Problem is not
% returned), so the caller's Problem.b1 is unaffected and needs no restore.
Problem.b1 = sig;
% NOTE(review): 'reconstruct' is removed presumably so SMALL_solve does not
% run the problem's reconstruction on the training set - confirm.
if isfield(Problem,'reconstruct')
    Problem = rmfield(Problem, 'reconstruct');
end
solver.profile = 0;

% main loop %

for i = 1:iternum
    % Sparse-code with the current dictionary. This must happen before the
    % solve so that the first iteration uses the initial dictionary rather
    % than whatever Problem.A held on entry.
    Problem.A = dico;
    solver = SMALL_solve(Problem, solver);
    [dico, solver.solution] = dico_update(dico, sig, solver.solution, ...
        typeUpdate, flow, learningRate);
    if (decorrelate)
        dico = dico_decorr(dico, mu, solver.solution);
    end
    if ((show_dictionary)&&(mod(i,show_iter)==0))
        % NOTE(review): [8 8] assumes 64-dimensional atoms (8x8 patches).
        dictimg = SMALL_showdict(dico,[8 8],...
            round(sqrt(size(dico,2))),round(sqrt(size(dico,2))),'lines','highcontrast');
        figure(2); imagesc(dictimg);colormap(gray);axis off; axis image;
        pause(0.02);
    end
end

DL.D = dico;

end