diff DL/two-step DL/SMALL_two_step_DL.m @ 152:485747bf39e0 ivand_dev

Two step dictonary learning - Integration of the code for dictionary update and dictionary decorrelation from Boris Mailhe
author Ivan Damnjanovic lnx <ivan.damnjanovic@eecs.qmul.ac.uk>
date Thu, 28 Jul 2011 15:49:32 +0100
parents
children af307f247ac7
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/DL/two-step DL/SMALL_two_step_DL.m	Thu Jul 28 15:49:32 2011 +0100
@@ -0,0 +1,127 @@
+function DL=SMALL_two_step_DL(Problem, DL)
+
+% determine which solver is used for sparse representation %
+
+solver = DL.param.solver;
+
+% determine which type of udate to use ('KSVD', 'MOD', 'ols' or 'mailhe') %
+
+typeUpdate = DL.name;
+
+sig = Problem.b;
+
+% determine dictionary size %
+
+if (isfield(DL.param,'initdict'))
+  if (any(size(DL.param.initdict)==1) && all(iswhole(DL.param.initdict(:))))
+    dictsize = length(DL.param.initdict);
+  else
+    dictsize = size(DL.param.initdict,2);
+  end
+end
+if (isfield(DL.param,'dictsize'))    % this superceedes the size determined by initdict
+  dictsize = DL.param.dictsize;
+end
+
+if (size(sig,2) < dictsize)
+  error('Number of training signals is smaller than number of atoms to train');
+end
+
+
+% initialize the dictionary %
+
+if (isfield(DL.param,'initdict'))
+  if (any(size(DL.param.initdict)==1) && all(iswhole(DL.param.initdict(:))))
+    dico = sig(:,DL.param.initdict(1:dictsize));
+  else
+    if (size(DL.param.initdict,1)~=size(sig,1) || size(DL.param.initdict,2)<dictsize)
+      error('Invalid initial dictionary');
+    end
+    dico = DL.param.initdict(:,1:dictsize);
+  end
+else
+  data_ids = find(colnorms_squared(sig) > 1e-6);   % ensure no zero data elements are chosen
+  perm = randperm(length(data_ids));
+  dico = sig(:,data_ids(perm(1:dictsize)));
+end
+
+% flow: 'sequential' or 'parallel'. If sequential, the residual is updated 
+% after each atom update. If parallel, the residual is only updated once 
+% the whole dictionary has been computed. Sequential works better, there 
+% may be no need to implement parallel. Not used with MOD.
+
+if isfield(DL.param,'flow')
+    flow =  DL.param.flow;
+else
+    flow = 'sequential';
+end
+
+% learningRate. If the type is 'ols', it is the descent step of
+% the gradient (typical choice: 0.1). If the type is 'mailhe', the 
+% descent step is the optimal step*rho (typical choice: 1, although 2
+% or 3 seems to work better). Not used for MOD and KSVD.
+
+if isfield(DL.param,'learningRate')
+    learningRate = DL.param.learningRate;
+else
+    learningRate = 0.1;
+end
+
+% number of iterations (default is 40) %
+
+if isfield(DL.param,'iternum')
+    iternum = DL.param.iternum;
+else
+    iternum = 40;
+end
+% determine if we should do decorrelation in every iteration  %
+
+if isfield(DL.param,'coherence')
+    decorrelate = 1;
+    mu = DL.param.coherence;
+else
+    decorrelate = 0;
+end
+
+% show dictonary every specified number of iterations
+
+if (isfield(DL.param,'show_dict'))
+    show_dictionary=1;
+    show_iter=DL.param.show_dict;
+else
+    show_dictionary=0;
+    show_iter=0;
+end
+
+% This is a small patch that needs to be resolved in dictionary learning we
+% want sparse representation of training set, and in Problem.b1 in this
+% version of software we store the signal that needs to be represented
+% (for example the whole image)
+
+tmpTraining = Problem.b1;
+Problem.b1 = sig;
+Problem = rmfield(Problem, 'reconstruct');
+solver.profile = 0;
+
+% main loop %
+
+for i = 1:iternum
+    solver = SMALL_solve(Problem, solver);
+    [dico, solver.solution] = dico_update(dico, sig, solver.solution, ...
+        typeUpdate, flow, learningRate);
+    if (decorrelate)
+        dico = dico_decorr(dico, mu, solver.solution);
+    end
+    Problem.A = dico;
+   if ((show_dictionary)&&(mod(i,show_iter)==0))
+       dictimg = SMALL_showdict(dico,[8 8],...
+            round(sqrt(size(dico,2))),round(sqrt(size(dico,2))),'lines','highcontrast');  
+       figure(2); imagesc(dictimg);colormap(gray);axis off; axis image;
+       pause(0.02);
+   end
+end
+
+Problem.b1 = tmpTraining;
+DL.D = dico;
+
+end
\ No newline at end of file