changeset 201:5140b0e06c22 luisf_dev

Added separate dictionary decorrelation
author bmailhe
date Tue, 20 Mar 2012 14:50:35 +0000
parents 69ce11724b1f
children 5bb579c9874e
files DL/two-step DL/dico_update.m
diffstat 1 files changed, 34 insertions(+), 21 deletions(-) [+]
line wrap: on
line diff
--- a/DL/two-step DL/dico_update.m	Tue Mar 20 12:25:50 2012 +0000
+++ b/DL/two-step DL/dico_update.m	Tue Mar 20 14:50:35 2012 +0000
@@ -1,5 +1,5 @@
 function [dico, amp] = dico_update(dico, sig, amp, type, flow, rho)
-
+    
     %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
     % [dico, amp] = dico_update(dico, sig, amp, type, flow, rho)
     %
@@ -10,44 +10,56 @@
     % - sig: the training data
     % - amp: the amplitude coefficients as a sparse matrix
     % - type: the algorithm can be one of the following
-    %   - ols: fixed step gradient descent
-    %   - mailhe: optimal step gradient descent (can be implemented as a
-    %   default for ols?)
-    %   - MOD: pseudo-inverse of the coefficients
-    %   - KSVD: already implemented by Elad
+    %   - ols: fixed step gradient descent, as described in Olshausen &
+    %   Field95
+    %   - opt: optimal step gradient descent, as described in Mailhe et
+    %   al.08
+    %   - MOD: pseudo-inverse of the coefficients, as described in Engan99
+    %   - KSVD: PCA update as described in Aharon06. For fast applications,
+    %   use KSVDbox rather than this code.
+    %   - LGD: large step gradient descent. Equivalent to 'opt' with
+    %          rho=2.
     % - flow: 'sequential' or 'parallel'. If sequential, the residual is
     % updated after each atom update. If parallel, the residual is only
-    % updated once the whole dictionary has been computed. Sequential works
-    % better, there may be no need to implement parallel. Not used with
+    % updated once the whole dictionary has been computed.
+    % Default: Sequential (sequential usually works better). Not used with
     % MOD.
     % - rho: learning rate. If the type is 'ols', it is the descent step of
-    % the gradient (typical choice: 0.1). If the type is 'mailhe', the 
-    % descent step is the optimal step*rho (typical choice: 1, although 2
-    % or 3 seems to work better). Not used for MOD and KSVD.
+    % the gradient (default: 0.1). If the type is 'opt', the
+    % descent step is the optimal step*rho (default: 1, although 2 works
+    % better. See LGD for more details). Not used for MOD and KSVD.
     %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
-    if ~ exist( 'rho', 'var' ) || isempty(rho)
-        rho = 0.1;
-    end
     
     if ~ exist( 'flow', 'var' ) || isempty(flow)
-        flow = sequential;
+        flow = 'sequential';
     end
     
     res = sig - dico*amp;
     nb_pattern = size(dico, 2);
     
+    % if the type is random, then randomly pick another type
     switch type
         case 'rand'
             x = rand();
             if x < 1/3
                 type = 'MOD';
             elseif type < 2/3
-                type = 'mailhe';
+                type = 'opt';
             else
                 type = 'KSVD';
             end
     end
     
+    % set the learning rate to default if not provided
+    if ~ exist( 'rho', 'var' ) || isempty(rho)
+        switch type
+            case 'ols'
+                rho = 0.1;
+            case 'opt'
+                rho = 1;
+        end
+    end
+    
     switch type
         case 'MOD'
             G = amp*amp';
@@ -72,14 +84,15 @@
                     dico(:,p) = pat;
                 end
             end
-        case 'mailhe'
+        case 'opt'
             for p = 1:nb_pattern
-                grad = res*amp(p,:)';
+                vec : amp(p,:);
+                grad = res*vec';
                 if norm(grad) > 0
-                    pat = (amp(p,:)*amp(p,:)')*dico(:,p) + rho*grad;
+                    pat = (vec*vec')*dico(:,p) + rho*grad;
                     pat = pat/norm(pat);
                     if nargin >5 && strcmp(flow, 'sequential')
-                        res = res + (dico(:,p)-pat)*amp(p,:);
+                        res = res + (dico(:,p)-pat)*vec;
                     end
                     dico(:,p) = pat;
                 end
@@ -89,7 +102,7 @@
                 index = find(amp(p,:)~=0);
                 if ~isempty(index)
                     patch = res(:,index)+dico(:,p)*amp(p,index);
-                    [U,S,V] = svd(patch);
+                    [U,~,V] = svd(patch);
                     if U(:,1)'*dico(:,p) > 0
                         dico(:,p) = U(:,1);
                     else