changeset 152:485747bf39e0 ivand_dev

Two step dictonary learning - Integration of the code for dictionary update and dictionary decorrelation from Boris Mailhe
author Ivan Damnjanovic lnx <ivan.damnjanovic@eecs.qmul.ac.uk>
date Thu, 28 Jul 2011 15:49:32 +0100
parents fec205ec6ef6
children af307f247ac7
files DL/two-step DL/SMALL_two_step_DL.m DL/two-step DL/dico_color.m DL/two-step DL/dico_decorr.m DL/two-step DL/dico_update.m examples/Image Denoising/SMALL_ImgDenoise_DL_test_KSVDvsRLSDLAvsTwoStepMOD.m util/SMALL_init_DL.m util/SMALL_learn.m
diffstat 7 files changed, 654 insertions(+), 6 deletions(-) [+]
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/DL/two-step DL/SMALL_two_step_DL.m	Thu Jul 28 15:49:32 2011 +0100
@@ -0,0 +1,127 @@
+function DL=SMALL_two_step_DL(Problem, DL)
+
+% determine which solver is used for sparse representation %
+
+solver = DL.param.solver;
+
+% determine which type of udate to use ('KSVD', 'MOD', 'ols' or 'mailhe') %
+
+typeUpdate = DL.name;
+
+sig = Problem.b;
+
+% determine dictionary size %
+
+if (isfield(DL.param,'initdict'))
+  if (any(size(DL.param.initdict)==1) && all(iswhole(DL.param.initdict(:))))
+    dictsize = length(DL.param.initdict);
+  else
+    dictsize = size(DL.param.initdict,2);
+  end
+end
+if (isfield(DL.param,'dictsize'))    % this superceedes the size determined by initdict
+  dictsize = DL.param.dictsize;
+end
+
+if (size(sig,2) < dictsize)
+  error('Number of training signals is smaller than number of atoms to train');
+end
+
+
+% initialize the dictionary %
+
+if (isfield(DL.param,'initdict'))
+  if (any(size(DL.param.initdict)==1) && all(iswhole(DL.param.initdict(:))))
+    dico = sig(:,DL.param.initdict(1:dictsize));
+  else
+    if (size(DL.param.initdict,1)~=size(sig,1) || size(DL.param.initdict,2)<dictsize)
+      error('Invalid initial dictionary');
+    end
+    dico = DL.param.initdict(:,1:dictsize);
+  end
+else
+  data_ids = find(colnorms_squared(sig) > 1e-6);   % ensure no zero data elements are chosen
+  perm = randperm(length(data_ids));
+  dico = sig(:,data_ids(perm(1:dictsize)));
+end
+
+% flow: 'sequential' or 'parallel'. If sequential, the residual is updated 
+% after each atom update. If parallel, the residual is only updated once 
+% the whole dictionary has been computed. Sequential works better, there 
+% may be no need to implement parallel. Not used with MOD.
+
+if isfield(DL.param,'flow')
+    flow =  DL.param.flow;
+else
+    flow = 'sequential';
+end
+
+% learningRate. If the type is 'ols', it is the descent step of
+% the gradient (typical choice: 0.1). If the type is 'mailhe', the 
+% descent step is the optimal step*rho (typical choice: 1, although 2
+% or 3 seems to work better). Not used for MOD and KSVD.
+
+if isfield(DL.param,'learningRate')
+    learningRate = DL.param.learningRate;
+else
+    learningRate = 0.1;
+end
+
+% number of iterations (default is 40) %
+
+if isfield(DL.param,'iternum')
+    iternum = DL.param.iternum;
+else
+    iternum = 40;
+end
+% determine if we should do decorrelation in every iteration  %
+
+if isfield(DL.param,'coherence')
+    decorrelate = 1;
+    mu = DL.param.coherence;
+else
+    decorrelate = 0;
+end
+
+% show dictonary every specified number of iterations
+
+if (isfield(DL.param,'show_dict'))
+    show_dictionary=1;
+    show_iter=DL.param.show_dict;
+else
+    show_dictionary=0;
+    show_iter=0;
+end
+
+% This is a small patch that needs to be resolved in dictionary learning we
+% want sparse representation of training set, and in Problem.b1 in this
+% version of software we store the signal that needs to be represented
+% (for example the whole image)
+
+tmpTraining = Problem.b1;
+Problem.b1 = sig;
+Problem = rmfield(Problem, 'reconstruct');
+solver.profile = 0;
+
+% main loop %
+
+for i = 1:iternum
+    solver = SMALL_solve(Problem, solver);
+    [dico, solver.solution] = dico_update(dico, sig, solver.solution, ...
+        typeUpdate, flow, learningRate);
+    if (decorrelate)
+        dico = dico_decorr(dico, mu, solver.solution);
+    end
+    Problem.A = dico;
+   if ((show_dictionary)&&(mod(i,show_iter)==0))
+       dictimg = SMALL_showdict(dico,[8 8],...
+            round(sqrt(size(dico,2))),round(sqrt(size(dico,2))),'lines','highcontrast');  
+       figure(2); imagesc(dictimg);colormap(gray);axis off; axis image;
+       pause(0.02);
+   end
+end
+
+Problem.b1 = tmpTraining;
+DL.D = dico;
+
+end
\ No newline at end of file
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/DL/two-step DL/dico_color.m	Thu Jul 28 15:49:32 2011 +0100
@@ -0,0 +1,46 @@
+function colors = dico_color(dico, mu)
+    % DICO_COLOR cluster a dictionary in pairs of high correlation atoms.
+    % Called by dico_decorr.
+    %
+    % Parameters:
+    % -dico: the dictionary
+    % -mu: the correlation threshold
+    %
+    % Result:
+    % -colors: a vector of indices. Two atoms with the same color have a 
+    % correlation greater than mu 
+    
+    numAtoms = length(dico);
+    colors = zeros(numAtoms, 1);
+    
+    % compute the correlations
+    G = abs(dico'*dico);
+    G = G-eye(size(G));
+    
+    % iterate on the correlations higher than mu
+    c = 1;   
+    maxCorr = max(max(G));
+    while maxCorr > mu
+        % find the highest correlated pair
+        x = find(max(G)==maxCorr, 1);
+        y = find(G(x,:)==maxCorr, 1);
+        
+        % color them
+        colors(x) = c;
+        colors(y) = c;
+        c = c+1;
+        
+        % make sure these atoms never get selected again
+        G(x,:) = 0;
+        G(:,x) = 0;
+        G(y,:) = 0;
+        G(:,y) = 0;
+        
+        % find the next correlation
+        maxCorr = max(max(G));
+    end
+    
+    % complete the coloring with singletons
+    index = find(colors==0);
+    colors(index) = c:c+length(index)-1;
+end
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/DL/two-step DL/dico_decorr.m	Thu Jul 28 15:49:32 2011 +0100
@@ -0,0 +1,48 @@
+function dico = dico_decorr(dico, mu, amp)
+    %DICO_DECORR decorrelate a dictionary
+    %   Parameters:
+    %   dico: the dictionary
+    %   mu: the coherence threshold
+    %   amp: the amplitude coefficients, only used to decide which atom to
+    %   project
+    %
+    %   Result:
+    %   dico: a dictionary close to the input one with coherence mu.
+    
+    % compute atom weights
+    if nargin > 2
+        rank = sum(amp.*amp, 2);
+    else
+        rank = randperm(length(dico));
+    end
+    
+    % several decorrelation iterations might be needed to reach global
+    % coherence mu. niter can be adjusted to needs.
+    niter = 1;
+    while niter < 5 && ...
+            max(max(abs(dico'*dico -eye(length(dico))))) > mu + 10^-6
+        % find pairs of high correlation atoms
+        colors = dico_color(dico, mu);
+        
+        % iterate on all pairs
+        nbColors = max(colors);
+        for c = 1:nbColors
+            index = find(colors==c);
+            if numel(index) == 2
+                % decide which atom to change (the one with lowest weight)
+                if rank(index(1)) < rank(index(2))
+                    index = fliplr(index);
+                end
+                
+                % update the atom
+                corr = dico(:,index(1))'*dico(:,index(2));
+                alpha = sqrt((1-mu*mu)/(1-corr*corr));
+                beta = corr*alpha-mu*sign(corr);
+                dico(:,index(2)) = alpha*dico(:,index(2))...
+                    -beta*dico(:,index(1));
+            end
+        end
+        niter = niter+1;
+    end
+end
+
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/DL/two-step DL/dico_update.m	Thu Jul 28 15:49:32 2011 +0100
@@ -0,0 +1,107 @@
+function [dico, amp] = dico_update(dico, sig, amp, type, flow, rho)
+
+    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
+    % [dico, amp] = dico_update(dico, sig, amp, type, flow, rho)
+    %
+    % perform one iteration of dictionary update for dictionary learning
+    %
+    % parameters:
+    % - dico: the initial dictionary with atoms as columns
+    % - sig: the training data
+    % - amp: the amplitude coefficients as a sparse matrix
+    % - type: the algorithm can be one of the following
+    %   - ols: fixed step gradient descent
+    %   - mailhe: optimal step gradient descent (can be implemented as a
+    %   default for ols?)
+    %   - MOD: pseudo-inverse of the coefficients
+    %   - KSVD: already implemented by Elad
+    % - flow: 'sequential' or 'parallel'. If sequential, the residual is
+    % updated after each atom update. If parallel, the residual is only
+    % updated once the whole dictionary has been computed. Sequential works
+    % better, there may be no need to implement parallel. Not used with
+    % MOD.
+    % - rho: learning rate. If the type is 'ols', it is the descent step of
+    % the gradient (typical choice: 0.1). If the type is 'mailhe', the 
+    % descent step is the optimal step*rho (typical choice: 1, although 2
+    % or 3 seems to work better). Not used for MOD and KSVD.
+    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
+    if ~ exist( 'rho', 'var' ) || isempty(rho)
+        rho = 0.1;
+    end
+    
+    if ~ exist( 'flow', 'var' ) || isempty(flow)
+        flow = sequential;
+    end
+    
+    res = sig - dico*amp;
+    nb_pattern = size(dico, 2);
+    
+    switch type
+        case 'rand'
+            x = rand();
+            if x < 1/3
+                type = 'MOD';
+            elseif type < 2/3
+                type = 'mailhe';
+            else
+                type = 'KSVD';
+            end
+    end
+    
+    switch type
+        case 'MOD'
+            G = amp*amp';
+            dico2 = sig*amp'*G^-1;
+            for p = 1:nb_pattern
+                n = norm(dico2(:,p));
+                % renormalize
+                if n > 0
+                    dico(:,p) = dico2(:,p)/n;
+                    amp(p,:) = amp(p,:)*n;
+                end
+            end
+        case 'ols'
+            for p = 1:nb_pattern
+                grad = res*amp(p,:)';
+                if norm(grad) > 0
+                    pat = dico(:,p) + rho*grad;
+                    pat = pat/norm(pat);
+                    if nargin >5 && strcmp(flow, 'sequential')
+                        res = res + (dico(:,p)-pat)*amp(p,:); %#ok<*NASGU>
+                    end
+                    dico(:,p) = pat;
+                end
+            end
+        case 'mailhe'
+            for p = 1:nb_pattern
+                grad = res*amp(p,:)';
+                if norm(grad) > 0
+                    pat = (amp(p,:)*amp(p,:)')*dico(:,p) + rho*grad;
+                    pat = pat/norm(pat);
+                    if nargin >5 && strcmp(flow, 'sequential')
+                        res = res + (dico(:,p)-pat)*amp(p,:);
+                    end
+                    dico(:,p) = pat;
+                end
+            end
+        case 'KSVD'
+            for p = 1:nb_pattern
+                index = find(amp(p,:)~=0);
+                if ~isempty(index)
+                    patch = res(:,index)+dico(:,p)*amp(p,index);
+                    [U,S,V] = svd(patch);
+                    if U(:,1)'*dico(:,p) > 0
+                        dico(:,p) = U(:,1);
+                    else
+                        dico(:,p) = -U(:,1);
+                    end
+                    dico(:,p) = dico(:,p)/norm(dico(:,p));
+                    amp(p,index) = dico(:,p)'*patch;
+                    if nargin >5 && strcmp(flow, 'sequential')
+                        res(:,index) = patch-dico(:,p)*amp(p,index);
+                    end
+                end
+            end
+    end
+end
+
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/examples/Image Denoising/SMALL_ImgDenoise_DL_test_KSVDvsRLSDLAvsTwoStepMOD.m	Thu Jul 28 15:49:32 2011 +0100
@@ -0,0 +1,288 @@
+%%  Dictionary Learning for Image Denoising - KSVD vs Recursive Least Squares
+%
+%   This file contains an example of how SMALLbox can be used to test different
+%   dictionary learning techniques in Image Denoising problem.
+%   It calls generateImageDenoiseProblem that will let you to choose image,
+%   add noise and use noisy image to generate training set for dictionary
+%   learning.
+%   Two dictionary learning techniques were compared:
+%   -   KSVD - M. Elad, R. Rubinstein, and M. Zibulevsky, "Efficient
+%              Implementation of the K-SVD Algorithm using Batch Orthogonal
+%              Matching Pursuit", Technical Report - CS, Technion, April 2008.
+%   -   RLS-DLA - Skretting, K.; Engan, K.; , "Recursive Least Squares
+%       Dictionary Learning Algorithm," Signal Processing, IEEE Transactions on,
+%       vol.58, no.4, pp.2121-2130, April 2010
+%
+
+
+%   Centre for Digital Music, Queen Mary, University of London.
+%   This file copyright 2011 Ivan Damnjanovic.
+%
+%   This program is free software; you can redistribute it and/or
+%   modify it under the terms of the GNU General Public License as
+%   published by the Free Software Foundation; either version 2 of the
+%   License, or (at your option) any later version.  See the file
+%   COPYING included with this distribution for more information.
+%   
+%%
+
+
+
+%   If you want to load the image outside of generateImageDenoiseProblem
+%   function uncomment following lines. This can be useful if you want to
+%   denoise more then one image for example.
+%   Here we are loading test_image.mat that contains structure with 5 images : lena,
+%   barbara,boat, house and peppers.
+clear;
+TMPpath=pwd;
+FS=filesep;
+[pathstr1, name, ext, versn] = fileparts(which('SMALLboxSetup.m'));
+cd([pathstr1,FS,'data',FS,'images']);
+load('test_image.mat');
+cd(TMPpath);
+
+%   Deffining the noise levels that we want to test
+
+noise_level=[10 20 25 50 100];
+
+%   Here we loop through different noise levels and images 
+
+for noise_ind=2:2
+for im_num=1:1
+
+% Defining Image Denoising Problem as Dictionary Learning
+% Problem. As an input we set the number of training patches.
+
+SMALL.Problem = generateImageDenoiseProblem(test_image(im_num).i, 40000, '',256, noise_level(noise_ind));
+SMALL.Problem.name=int2str(im_num);
+
+Edata=sqrt(prod(SMALL.Problem.blocksize)) * SMALL.Problem.sigma * SMALL.Problem.gain;
+maxatoms = floor(prod(SMALL.Problem.blocksize)/2);
+
+%   results structure is to store all results
+
+results(noise_ind,im_num).noisy_psnr=SMALL.Problem.noisy_psnr;
+
+%%
+%   Use KSVD Dictionary Learning Algorithm to Learn overcomplete dictionary
+
+%   Initialising Dictionary structure
+%   Setting Dictionary structure fields (toolbox, name, param, D and time)
+%   to zero values
+
+SMALL.DL(1)=SMALL_init_DL();
+
+% Defining the parameters needed for dictionary learning
+
+SMALL.DL(1).toolbox = 'KSVD';
+SMALL.DL(1).name = 'ksvd';
+
+%   Defining the parameters for KSVD
+%   In this example we are learning 256 atoms in 20 iterations, so that
+%   every patch in the training set can be represented with target error in
+%   L2-norm (Edata)
+%   Type help ksvd in MATLAB prompt for more options.
+
+
+SMALL.DL(1).param=struct(...
+    'Edata', Edata,...
+    'initdict', SMALL.Problem.initdict,...
+    'dictsize', SMALL.Problem.p,...
+    'exact', 1, ...
+    'iternum', 20,...
+    'memusage', 'high');
+
+%   Learn the dictionary
+
+SMALL.DL(1) = SMALL_learn(SMALL.Problem, SMALL.DL(1));
+
+%   Set SMALL.Problem.A dictionary
+%   (backward compatiblity with SPARCO: solver structure communicate
+%   only with Problem structure, ie no direct communication between DL and
+%   solver structures)
+
+SMALL.Problem.A = SMALL.DL(1).D;
+SMALL.Problem.reconstruct = @(x) ImgDenoise_reconstruct(x, SMALL.Problem);
+
+%%
+%   Initialising solver structure
+%   Setting solver structure fields (toolbox, name, param, solution,
+%   reconstructed and time) to zero values
+
+SMALL.solver(1)=SMALL_init_solver;
+
+% Defining the parameters needed for image denoising
+
+SMALL.solver(1).toolbox='ompbox';
+SMALL.solver(1).name='omp2';
+SMALL.solver(1).param=struct(...
+    'epsilon',Edata,...
+    'maxatoms', maxatoms); 
+
+%   Denoising the image - find the sparse solution in the learned
+%   dictionary for all patches in the image and the end it uses
+%   reconstruction function to reconstruct the patches and put them into a
+%   denoised image
+
+SMALL.solver(1)=SMALL_solve(SMALL.Problem, SMALL.solver(1));
+
+%   Show PSNR after reconstruction
+
+SMALL.solver(1).reconstructed.psnr
+
+%%
+%   For comparison purposes we will denoise image with overcomplete DCT
+%   here
+%   Set SMALL.Problem.A dictionary to be oDCT (i.e. Problem.initdict -
+%   since initial dictionaruy is already set to be oDCT when generating the
+%   denoising problem
+
+
+%   Initialising solver structure
+%   Setting solver structure fields (toolbox, name, param, solution,
+%   reconstructed and time) to zero values
+
+SMALL.solver(2)=SMALL_init_solver;
+
+% Defining the parameters needed for image denoising
+
+SMALL.solver(2).toolbox='ompbox';
+SMALL.solver(2).name='omp2';
+SMALL.solver(2).param=struct(...
+    'epsilon',Edata,...
+    'maxatoms', maxatoms); 
+
+%   Initialising Dictionary structure
+%   Setting Dictionary structure fields (toolbox, name, param, D and time)
+%   to zero values
+
+SMALL.DL(2)=SMALL_init_DL('TwoStepDL', 'MOD', '', 1);
+
+
+%   Defining the parameters for MOD
+%   In this example we are learning 256 atoms in 20 iterations, so that
+%   every patch in the training set can be represented with target error in
+%   L2-norm (EData)
+%   Type help ksvd in MATLAB prompt for more options.
+
+
+SMALL.DL(2).param=struct(...
+    'solver', SMALL.solver(2),...
+    'initdict', SMALL.Problem.initdict,...
+    'dictsize', SMALL.Problem.p,...
+    'iternum', 40,...
+    'mu', 0.7,...
+    'show_dict', 1);
+
+%   Learn the dictionary
+
+SMALL.DL(2) = SMALL_learn(SMALL.Problem, SMALL.DL(2));
+
+%   Set SMALL.Problem.A dictionary
+%   (backward compatiblity with SPARCO: solver structure communicate
+%   only with Problem structure, ie no direct communication between DL and
+%   solver structures)
+
+SMALL.Problem.A = SMALL.DL(2).D;
+SMALL.Problem.reconstruct = @(x) ImgDenoise_reconstruct(x, SMALL.Problem);
+
+%   Denoising the image - find the sparse solution in the learned
+%   dictionary for all patches in the image and the end it uses
+%   reconstruction function to reconstruct the patches and put them into a
+%   denoised image
+
+SMALL.solver(2)=SMALL_solve(SMALL.Problem, SMALL.solver(2));
+
+%%
+% In the b1 field all patches from the image are stored. For RLS-DLA we
+% will first exclude all the patches that have l2 norm smaller then
+% threshold and then take min(40000, number_of_remaining_patches) in
+% ascending order as our training set (SMALL.Problem.b)
+
+X=SMALL.Problem.b1;
+X_norm=sqrt(sum(X.^2, 1));
+[X_norm_sort, p]=sort(X_norm);
+p1=p(X_norm_sort>Edata);
+if size(p1,2)>40000
+    p2 = randperm(size(p1,2));
+    p2=sort(p2(1:40000));
+    size(p2,2)
+    SMALL.Problem.b=X(:,p1(p2));
+else 
+    size(p1,2)
+    SMALL.Problem.b=X(:,p1);
+
+end
+
+%   Forgetting factor for RLS-DLA algorithm, in this case we are using
+%   fixed value
+
+lambda=0.9998
+
+%   Use Recursive Least Squares
+%   to Learn overcomplete dictionary 
+
+%   Initialising Dictionary structure
+%   Setting Dictionary structure fields (toolbox, name, param, D and time)
+%   to zero values
+
+SMALL.DL(3)=SMALL_init_DL();
+
+%   Defining fields needed for dictionary learning
+
+SMALL.DL(3).toolbox = 'SMALL';
+SMALL.DL(3).name = 'SMALL_rlsdla';
+SMALL.DL(3).param=struct(...
+    'Edata', Edata,...
+    'initdict', SMALL.Problem.initdict,...
+    'dictsize', SMALL.Problem.p,...
+    'forgettingMode', 'FIX',...
+    'forgettingFactor', lambda,...
+    'show_dict', 1000);
+
+
+SMALL.DL(3) = SMALL_learn(SMALL.Problem, SMALL.DL(3));
+
+%   Initialising solver structure
+%   Setting solver structure fields (toolbox, name, param, solution,
+%   reconstructed and time) to zero values
+
+SMALL.Problem.A = SMALL.DL(3).D;
+SMALL.Problem.reconstruct = @(x) ImgDenoise_reconstruct(x, SMALL.Problem);
+
+SMALL.solver(3)=SMALL_init_solver;
+
+% Defining the parameters needed for image denoising
+
+SMALL.solver(3).toolbox='ompbox';
+SMALL.solver(3).name='omp2';
+SMALL.solver(3).param=struct(...
+    'epsilon',Edata,...
+    'maxatoms', maxatoms); 
+
+
+SMALL.solver(3)=SMALL_solve(SMALL.Problem, SMALL.solver(3));
+
+SMALL.solver(3).reconstructed.psnr
+
+
+% show results %
+
+SMALL_ImgDeNoiseResult(SMALL);
+
+results(noise_ind,im_num).psnr.ksvd=SMALL.solver(1).reconstructed.psnr;
+results(noise_ind,im_num).psnr.odct=SMALL.solver(2).reconstructed.psnr;
+results(noise_ind,im_num).psnr.rlsdla=SMALL.solver(3).reconstructed.psnr;
+results(noise_ind,im_num).vmrse.ksvd=SMALL.solver(1).reconstructed.vmrse;
+results(noise_ind,im_num).vmrse.odct=SMALL.solver(2).reconstructed.vmrse;
+results(noise_ind,im_num).vmrse.rlsdla=SMALL.solver(3).reconstructed.vmrse;
+results(noise_ind,im_num).ssim.ksvd=SMALL.solver(1).reconstructed.ssim;
+results(noise_ind,im_num).ssim.odct=SMALL.solver(2).reconstructed.ssim;
+results(noise_ind,im_num).ssim.rlsdla=SMALL.solver(3).reconstructed.ssim;
+
+results(noise_ind,im_num).time.ksvd=SMALL.solver(1).time+SMALL.DL(1).time;
+results(noise_ind,im_num).time.rlsdla.time=SMALL.solver(3).time+SMALL.DL(3).time;
+clear SMALL;
+end
+end
+% save results.mat results
--- a/util/SMALL_init_DL.m	Tue Jul 26 16:01:17 2011 +0100
+++ b/util/SMALL_init_DL.m	Thu Jul 28 15:49:32 2011 +0100
@@ -1,4 +1,4 @@
-function DL = SMALL_init_DL(varargin)
+function DL = SMALL_init_DL(toolbox, name, param, profile)
 %%   Function initialise SMALL structure for Dictionary Learning.
 %   Optional input variables:
 %       toolbox - name of Dictionary Learning toolbox you want to use
@@ -17,9 +17,27 @@
 %
 %%
 
-DL.toolbox=[];
-DL.name=[];
-DL.param=[];
+if ~ exist( 'toolbox', 'var' ) || isempty(toolbox) 
+    DL.toolbox = []; 
+else
+    DL.toolbox = toolbox;
+end
+if ~ exist( 'name', 'var' ) || isempty(name) 
+    DL.name = [];
+else
+    DL.name = name;
+end
+if ~ exist( 'param', 'var' ) || isempty(param) 
+    DL.param = [];
+else
+    DL.param = param;
+end
+if ~ exist( 'profile', 'var' ) || isempty(profile) 
+    DL.profile = 1;
+else
+    DL.profile = profile;
+end
+
 DL.D=[];
 DL.time=[];
 end
\ No newline at end of file
--- a/util/SMALL_learn.m	Tue Jul 26 16:01:17 2011 +0100
+++ b/util/SMALL_learn.m	Thu Jul 28 15:49:32 2011 +0100
@@ -18,8 +18,9 @@
 %   License, or (at your option) any later version.  See the file
 %   COPYING included with this distribution for more information.
 %%
- 
+if (DL.profile)
   fprintf('\nStarting Dictionary Learning %s... \n', DL.name);
+end
   start=cputime;
   tStart=tic;
   if strcmpi(DL.toolbox,'KSVD')
@@ -58,6 +59,17 @@
         D(:,i)=D(:,i)/norm(D(:,i));
     end
     
+   elseif strcmpi(DL.toolbox,'TwoStepDL')
+        
+    DL=SMALL_two_step_DL(Problem, DL);
+    
+    %   we need to make sure that columns are normalised to
+    %   unit lenght.
+    
+    for i = 1: size(DL.D,2)
+        DL.D(:,i)=DL.D(:,i)/norm(DL.D(:,i));
+    end
+    D = DL.D; 
 %   To introduce new dictionary learning technique put the files in
 %   your Matlab path. Next, unique name <TolboxID> for your toolbox needs 
 %   to be defined and also prefferd API for toolbox functions <Preffered_API>
@@ -83,9 +95,11 @@
 %%
 %   Dictionary Learning time
 tElapsed=toc(tStart);
-  DL.time = cputime - start;
+DL.time = cputime - start;
+if (DL.profile)
   fprintf('\n%s finished task in %2f seconds (cpu time). \n', DL.name, DL.time);
   fprintf('\n%s finished task in %2f seconds (tic-toc time). \n', DL.name, tElapsed);
+end
 DL.time=tElapsed;
 %   If dictionary is given as a sparse matrix change it to full