diff DL/two-step DL/dico_update.m @ 152:485747bf39e0 ivand_dev

Two step dictonary learning - Integration of the code for dictionary update and dictionary decorrelation from Boris Mailhe
author Ivan Damnjanovic lnx <ivan.damnjanovic@eecs.qmul.ac.uk>
date Thu, 28 Jul 2011 15:49:32 +0100
parents
children 9eb5f0d4c1a4 5140b0e06c22
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/DL/two-step DL/dico_update.m	Thu Jul 28 15:49:32 2011 +0100
@@ -0,0 +1,107 @@
+function [dico, amp] = dico_update(dico, sig, amp, type, flow, rho)
+
+    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
+    % [dico, amp] = dico_update(dico, sig, amp, type, flow, rho)
+    %
+    % perform one iteration of dictionary update for dictionary learning
+    %
+    % parameters:
+    % - dico: the initial dictionary with atoms as columns
+    % - sig: the training data
+    % - amp: the amplitude coefficients as a sparse matrix
+    % - type: the algorithm can be one of the following
+    %   - ols: fixed step gradient descent
+    %   - mailhe: optimal step gradient descent (can be implemented as a
+    %   default for ols?)
+    %   - MOD: pseudo-inverse of the coefficients
+    %   - KSVD: already implemented by Elad
+    % - flow: 'sequential' or 'parallel'. If sequential, the residual is
+    % updated after each atom update. If parallel, the residual is only
+    % updated once the whole dictionary has been computed. Sequential works
+    % better, there may be no need to implement parallel. Not used with
+    % MOD.
+    % - rho: learning rate. If the type is 'ols', it is the descent step of
+    % the gradient (typical choice: 0.1). If the type is 'mailhe', the 
+    % descent step is the optimal step*rho (typical choice: 1, although 2
+    % or 3 seems to work better). Not used for MOD and KSVD.
+    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
+    if ~ exist( 'rho', 'var' ) || isempty(rho)
+        rho = 0.1;
+    end
+    
+    if ~ exist( 'flow', 'var' ) || isempty(flow)
+        flow = sequential;
+    end
+    
+    res = sig - dico*amp;
+    nb_pattern = size(dico, 2);
+    
+    switch type
+        case 'rand'
+            x = rand();
+            if x < 1/3
+                type = 'MOD';
+            elseif type < 2/3
+                type = 'mailhe';
+            else
+                type = 'KSVD';
+            end
+    end
+    
+    switch type
+        case 'MOD'
+            G = amp*amp';
+            dico2 = sig*amp'*G^-1;
+            for p = 1:nb_pattern
+                n = norm(dico2(:,p));
+                % renormalize
+                if n > 0
+                    dico(:,p) = dico2(:,p)/n;
+                    amp(p,:) = amp(p,:)*n;
+                end
+            end
+        case 'ols'
+            for p = 1:nb_pattern
+                grad = res*amp(p,:)';
+                if norm(grad) > 0
+                    pat = dico(:,p) + rho*grad;
+                    pat = pat/norm(pat);
+                    if nargin >5 && strcmp(flow, 'sequential')
+                        res = res + (dico(:,p)-pat)*amp(p,:); %#ok<*NASGU>
+                    end
+                    dico(:,p) = pat;
+                end
+            end
+        case 'mailhe'
+            for p = 1:nb_pattern
+                grad = res*amp(p,:)';
+                if norm(grad) > 0
+                    pat = (amp(p,:)*amp(p,:)')*dico(:,p) + rho*grad;
+                    pat = pat/norm(pat);
+                    if nargin >5 && strcmp(flow, 'sequential')
+                        res = res + (dico(:,p)-pat)*amp(p,:);
+                    end
+                    dico(:,p) = pat;
+                end
+            end
+        case 'KSVD'
+            for p = 1:nb_pattern
+                index = find(amp(p,:)~=0);
+                if ~isempty(index)
+                    patch = res(:,index)+dico(:,p)*amp(p,index);
+                    [U,S,V] = svd(patch);
+                    if U(:,1)'*dico(:,p) > 0
+                        dico(:,p) = U(:,1);
+                    else
+                        dico(:,p) = -U(:,1);
+                    end
+                    dico(:,p) = dico(:,p)/norm(dico(:,p));
+                    amp(p,index) = dico(:,p)'*patch;
+                    if nargin >5 && strcmp(flow, 'sequential')
+                        res(:,index) = patch-dico(:,p)*amp(p,index);
+                    end
+                end
+            end
+    end
+end
+