diff DL/Majorization Minimization DL/wrapper_mm_DL.m @ 155:b14209313ba4 ivand_dev

Integration of Majorization Minimisation Dictionary Learning
author Ivan Damnjanovic lnx <ivan.damnjanovic@eecs.qmul.ac.uk>
date Mon, 22 Aug 2011 11:46:35 +0100
parents
children 0c7c20f3246c
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/DL/Majorization Minimization DL/wrapper_mm_DL.m	Mon Aug 22 11:46:35 2011 +0100
@@ -0,0 +1,185 @@
+function DL = wrapper_mm_DL(Problem, DL)
+
+% determine which solver is used for sparse representation %
+
+solver = DL.param.solver;
+
+% determine which type of udate to use 
+%   (Mehrdad Yaghoobi implementations: 'MM_cn', MM_fn', 'MOD_cn',
+%   'MAP_cn', 'KSVD_cn')
+
+typeUpdate = DL.name;
+
+sig = Problem.b;
+
+% determine dictionary size %
+
+if (isfield(DL.param,'initdict'))
+  if (any(size(DL.param.initdict)==1) && all(iswhole(DL.param.initdict(:))))
+    dictsize = length(DL.param.initdict);
+  else
+    dictsize = size(DL.param.initdict,2);
+  end
+end
+
+if (isfield(DL.param,'dictsize'))    % this superceedes the size determined by initdict
+  dictsize = DL.param.dictsize;
+end
+
+if (size(sig,2) < dictsize)
+  error('Number of training signals is smaller than number of atoms to train');
+end
+
+
+% initialize the dictionary %
+
+if (isfield(DL.param,'initdict'))
+  if (any(size(DL.param.initdict)==1) && all(iswhole(DL.param.initdict(:))))
+    dico = sig(:,DL.param.initdict(1:dictsize));
+  else
+    if (size(DL.param.initdict,1)~=size(sig,1) || size(DL.param.initdict,2)<dictsize)
+      error('Invalid initial dictionary');
+    end
+    dico = DL.param.initdict(:,1:dictsize);
+  end
+else
+  data_ids = find(colnorms_squared(sig) > 1e-6);   % ensure no zero data elements are chosen
+  perm = randperm(length(data_ids));
+  dico = sig(:,data_ids(perm(1:dictsize)));
+end
+
+
+% number of iterations (default is 40) %
+
+if isfield(DL.param,'iternum')
+    iternum = DL.param.iternum;
+else
+    iternum = 40;
+end
+
+% number of iterations (default is 40) %
+
+if isfield(DL.param,'iterDictUpdate')
+    maxIT = DL.param.iterDictUpdate;
+else
+    maxIT = 1000;
+end
+
+% Stopping criterion for MM dictionary update (default = 1e-7)
+
+if isfield(DL.param,'epsDictUpdate')
+    epsD = DL.param.epsDictUpdate;
+else
+    epsD = 1e-7;
+end
+
+% Dictionary constraint - 0 = Non convex ||d|| = 1, 1 = Convex ||d||<=1
+% (default cvset is o) %
+
+if isfield(DL.param,'cvset')
+    cvset = DL.param.cvset;
+else
+    cvset = 0;
+end
+
+% determine if we should do decorrelation in every iteration  %
+
+if isfield(DL.param,'coherence')
+    decorrelate = 1;
+    mu = DL.param.coherence;
+else
+    decorrelate = 0;
+end
+
+% show dictonary every specified number of iterations
+
+if isfield(DL.param,'show_dict')
+    show_dictionary = 1;
+    show_iter = DL.param.show_dict;
+else
+    show_dictionary = 0;
+    show_iter = 0;
+end
+
+% This is a small patch that needs to be resolved in dictionary learning we
+% want sparse representation of training set, and in Problem.b1 in this
+% version of software we store the signal that needs to be represented
+% (for example the whole image)
+if isfield(Problem,'b1')
+    tmpTraining = Problem.b1;
+    Problem.b1 = sig;
+end
+if isfield(Problem,'reconstruct')
+    Problem = rmfield(Problem, 'reconstruct');
+end
+solver.profile = 0;
+
+% main loop %
+
+for i = 1:iternum
+    Problem.A = dico;
+    
+    solver = SMALL_solve(Problem, solver);
+    
+    switch lower(typeUpdate)
+        case 'mm_cn'
+            [dico, solver.solution] = ...
+                dict_update_REG_cn(dico, sig, solver.solution, maxIT, epsD, cvset);
+        case 'mm_fn'
+            [dico, solver.solution] = ...
+                dict_update_REG_fn(dico, sig, solver.solution, maxIT, epsD, cvset);
+        case 'mod_cn'
+            [dico, solver.solution] = dict_update_MOD_cn(sig, solver.solution, cvset);
+        case 'map_cn'
+            if isfield(DL.param,'muMAP')
+                muMAP = DL.param.muMAP;
+            else
+                muMAP = 1e-4;
+            end
+            [dico, solver.solution] = ...
+                dict_update_MAP_cn(dico, sig, solver.solution, muMAP, maxIT, epsD, cvset);
+        case 'ksvd_cn'
+            [dico, solver.solution] = dict_update_KSVD_cn(dico, sig, solver.solution);
+        otherwise
+            error('Dictionary update is not defined');
+    end
+    
+    % Set previous solution as the best initial guess
+    % for the next iteration of iterative soft tresholding
+    
+    if (strcmpi(solver.toolbox, 'MMbox'))
+        solver.param.initcoeff = solver.solution;
+    end
+    
+    % Optional decorrelation of athoms - this is from Boris Mailhe and
+    % we need to test how it preforms with Mehrdad's updates
+    
+    if (decorrelate)
+        dico = dico_decorr(dico, mu, solver.solution);
+    end
+    
+    if ((show_dictionary)&&(mod(i,show_iter)==0))
+        dictimg = SMALL_showdict(dico,[8 8],...
+            round(sqrt(size(dico,2))),round(sqrt(size(dico,2))),'lines','highcontrast');
+        figure(2); imagesc(dictimg);colormap(gray);axis off; axis image;
+        pause(0.02);
+    end
+end
+if isfield(Problem,'b1')
+    Problem.b1 = tmpTraining;
+end
+DL.D = dico;
+
+end
+
+function Y = colnorms_squared(X)
+
+% compute in blocks to conserve memory
+Y = zeros(1,size(X,2));
+blocksize = 2000;
+for i = 1:blocksize:size(X,2)
+  blockids = i : min(i+blocksize-1,size(X,2));
+  Y(blockids) = sum(X(:,blockids).^2);
+end
+
+end
\ No newline at end of file