view DL/two-step DL/SMALL_two_step_DL.m @ 218:c38d965b5a1d luisf_dev

Moved CVX_add_const_Audio_declipping.m from solvers to the examples/AudioInpainting folder.
author Aris Gretsistas <aris.gretsistas@eecs.qmul.ac.uk>
date Thu, 22 Mar 2012 15:37:45 +0000
parents f12a476a4977
children fd0b5d36f6ad
line wrap: on
line source
function DL=SMALL_two_step_DL(Problem, DL)
    
    %%  DL=SMALL_two_step_DL(Problem, DL) learn a dictionary using two_step_DL
    %   Alternates a sparse coding step (SMALL_solve with DL.param.solver)
    %   and a dictionary update step (dico_update) for DL.param.iternum
    %   iterations, using the columns of Problem.b as training signals.
    %
    %   The specific parameters of the DL structure are:
    %   -name: can be either 'ols', 'opt', 'MOD', 'KSVD' or 'LGD'.
    %   -param.solver: solver structure passed to SMALL_solve for the
    %   sparse coding step (required).
    %   -param.initdict: either an initial dictionary (one atom per column)
    %   or a vector of whole numbers used as column indices into Problem.b
    %   to select the initial atoms. If absent, atoms are drawn at random
    %   from the non-zero training signals.
    %   -param.dictsize: number of atoms; takes precedence over the size
    %   implied by initdict. At least one of initdict/dictsize is required.
    %   -param.learningRate: a step size used by 'ols' and 'opt'. Default: 0.1
    %   for 'ols', 1 otherwise.
    %   -param.flow: can be either 'sequential' or 'parallel'. Default:
    %   'sequential'. Not used by MOD.
    %   -param.iternum: number of learning iterations. Default: 40.
    %   -param.coherence: a real number between 0 and 1. If present, then
    %   a low-coherence constraint is added to the learning.
    %   -param.show_dict: if present, the dictionary is displayed every
    %   show_dict iterations.
    %
    %   The learned dictionary is returned in DL.D.
    %
    %   See dico_update.m for more details.

% determine which solver is used for sparse representation %

solver = DL.param.solver;

% determine which type of update to use ('KSVD', 'MOD', 'ols', 'opt' or 'LGD') %

typeUpdate = DL.name;

sig = Problem.b;

% A vector of whole numbers in initdict is interpreted as a list of column
% indices into the training set; anything else is an explicit dictionary.

hasInitDict = isfield(DL.param,'initdict');
initIsIndexList = hasInitDict && ...
    any(size(DL.param.initdict)==1) && all(iswhole(DL.param.initdict(:)));

% determine dictionary size %

dictsize = [];
if (hasInitDict)
  if (initIsIndexList)
    dictsize = length(DL.param.initdict);
  else
    dictsize = size(DL.param.initdict,2);
  end
end
if (isfield(DL.param,'dictsize'))    % this supersedes the size determined by initdict
  dictsize = DL.param.dictsize;
end
if (isempty(dictsize))
  error('Dictionary size is undefined: set DL.param.dictsize or DL.param.initdict');
end

if (size(sig,2) < dictsize)
  error('Number of training signals is smaller than number of atoms to train');
end


% initialize the dictionary %

if (hasInitDict)
  if (initIsIndexList)
    % an index list shorter than dictsize would index out of range below
    if (length(DL.param.initdict) < dictsize)
      error('Invalid initial dictionary');
    end
    dico = sig(:,DL.param.initdict(1:dictsize));
  else
    if (size(DL.param.initdict,1)~=size(sig,1) || size(DL.param.initdict,2)<dictsize)
      error('Invalid initial dictionary');
    end
    dico = DL.param.initdict(:,1:dictsize);
  end
else
  data_ids = find(colnorms_squared(sig) > 1e-6);   % ensure no zero data elements are chosen
  % with fewer usable signals than atoms, perm(1:dictsize) would be out of range
  if (length(data_ids) < dictsize)
    error('Not enough non-zero training signals to initialize the dictionary');
  end
  perm = randperm(length(data_ids));
  dico = sig(:,data_ids(perm(1:dictsize)));
end

% flow: 'sequential' or 'parallel'. If sequential, the residual is updated 
% after each atom update. If parallel, the residual is only updated once 
% the whole dictionary has been computed. Sequential works better, there 
% may be no need to implement parallel. Not used with MOD.

if isfield(DL.param,'flow')
    flow =  DL.param.flow;
else
    flow = 'sequential';
end

% learningRate. If the type is 'ols', it is the descent step of
% the gradient (default: 0.1). If the type is 'opt', the 
% descent step is the optimal step*rho (default: 1, although 2 works
% better). Not used for MOD and KSVD.

if isfield(DL.param,'learningRate')
    learningRate = DL.param.learningRate;
else
    switch typeUpdate
        case 'ols'
            learningRate = 0.1;
        otherwise
            learningRate = 1;
    end
end

% number of iterations (default is 40) %

if isfield(DL.param,'iternum')
    iternum = DL.param.iternum;
else
    iternum = 40;
end
% determine if we should do decorrelation in every iteration  %

if isfield(DL.param,'coherence')
    decorrelate = 1;
    mu = DL.param.coherence;
else
    decorrelate = 0;
end

% show dictionary every specified number of iterations

if isfield(DL.param,'show_dict')
    show_dictionary=1;
    show_iter=DL.param.show_dict;
else
    show_dictionary=0;
    show_iter=0;
end

% In dictionary learning we want the sparse representation of the training
% set, while in this version of the software Problem.b1 stores the signal
% that finally needs to be represented (for example the whole image).
% Point b1 at the training set for the duration of learning. Problem is a
% local copy of the caller's struct, so the caller's b1 is not modified
% and no restore is needed.

Problem.b1 = sig;
if isfield(Problem,'reconstruct')
    Problem = rmfield(Problem, 'reconstruct');
end
solver.profile = 0;

% main loop %

for i = 1:iternum
    Problem.A = dico;
    solver = SMALL_solve(Problem, solver);
    [dico, solver.solution] = dico_update(dico, sig, solver.solution, ...
        typeUpdate, flow, learningRate);
    if (decorrelate)
        dico = dico_decorr_symetric(dico, mu, solver.solution);
    end
    
   % NOTE(review): the [8 8] block size presumably assumes 64-dimensional
   % atoms shown as 8x8 patches - confirm against SMALL_showdict.
   if ((show_dictionary)&&(mod(i,show_iter)==0))
       dictimg = SMALL_showdict(dico,[8 8],...
            round(sqrt(size(dico,2))),round(sqrt(size(dico,2))),'lines','highcontrast');  
       figure(2); imagesc(dictimg);colormap(gray);axis off; axis image;
       pause(0.02);
   end
end

DL.D = dico;

end

function Y = colnorms_squared(X)
% COLNORMS_SQUARED squared Euclidean norm of each column of X.
%   Y = colnorms_squared(X) returns a 1-by-size(X,2) row vector with
%   Y(j) = sum(abs(X(:,j)).^2). Columns are processed in blocks of 2000
%   to limit the size of temporaries.
%
%   The sum is taken explicitly along dimension 1 so that a single-row X
%   still yields one value per column (the default sum would collapse a
%   row vector to a scalar). abs() makes the result correct for complex X
%   and is exact for real X.

% compute in blocks to conserve memory
Y = zeros(1,size(X,2));
blocksize = 2000;
for i = 1:blocksize:size(X,2)
  blockids = i : min(i+blocksize-1,size(X,2));
  Y(blockids) = sum(abs(X(:,blockids)).^2, 1);
end

end