changeset 202:5bb579c9874e luisf_dev

Merge
author bmailhe
date Tue, 20 Mar 2012 14:51:31 +0000
parents 5140b0e06c22 (diff) 751fa3bddd30 (current diff)
children 5fe60504a6a9 233e75809e4a
files config/SMALL_two_step_DL_config.m
diffstat 3 files changed, 235 insertions(+), 49 deletions(-) [+]
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/DL/two-step DL/dico_color_separate.m	Tue Mar 20 14:51:31 2012 +0000
@@ -0,0 +1,95 @@
+function [colors nbColors] = dico_color_separate(dico, mu)
+    % DICO_COLOR cluster several dictionaries in pairs of high correlation
+    % atoms. Called by dico_decorr.
+    %
+    % Parameters:
+    % -dico: the dictionaries
+    % -mu: the correlation threshold
+    %
+    % Result:
+    % -colors: a cell array of indices. Two atoms with the same color have
+    % a correlation greater than mu
+    
+    
+    numDico = length(dico);
+    colors = cell(numDico,1);
+    for i = 1:numDico
+        colors{i} = zeros(length(dico{i}),1);
+    end
+    
+    G = cell(numDico);
+    
+    % compute the correlations
+    for i = 1:numDico
+        for j = i+1:numDico
+            G{i,j} = abs(dico{i}'*dico{j});
+        end
+    end
+    
+    % iterate on the correlations higher than mu
+    c = 1;
+    [maxCorr, i, j, m, n] = findMaxCorr(G);
+    while maxCorr > mu
+        % find the highest correlated pair
+        
+        % color them
+        colors{i}(m) = c;
+        colors{j}(n) = c;
+        c = c+1;
+        
+        % make sure these atoms never get selected again
+        % Set to zero relevant lines in the Gram Matrix
+        for j2 = i+1:numDico
+            G{i,j2}(m,:) = 0;
+        end
+        
+        for i2 = 1:i-1
+            G{i2,i}(:,m) = 0;
+        end
+        
+        for j2 = j+1:numDico
+            G{j,j2}(n,:) = 0;
+        end
+        
+        for i2 = 1:j-1
+            G{i2,j}(:,n) = 0;
+        end
+        
+        % find the next correlation
+        [maxCorr, i, j, m, n] = findMaxCorr(G);
+    end
+    
+    % complete the coloring with singletons
+    % index = find(colors==0);
+    % colors(index) = c:c+length(index)-1;
+    nbColors = c-1;
+end
+
+function [val, i, j, m, n] = findMaxCorr(G)
+    %FINDMAXCORR find the maximal correlation in the cellular Gram matrix
+    %
+    %   Input:
+    %   -G: the Gram matrix
+    %
+    %   Output:
+    %   -val: value of the correlation
+    %   -i,j,m,n: indices of the argmax. The maximal correlation is reached
+    %   for the m^th atom of the i^th dictionary and the n^h atom of the
+    %   j^h dictionary
+    
+    val = -1;
+    for tmpI = 1:length(G)
+        for tmpJ = tmpI+1:length(G)
+            [tmpVal tmpM] = max(G{tmpI,tmpJ},[],1);
+            [tmpVal tmpN] = max(tmpVal);
+            if tmpVal > val
+                val = tmpVal; 
+                i = tmpI;
+                j = tmpJ;
+                n = tmpN;
+                m = tmpM(n);
+            end
+        end
+    end
+end
+        
\ No newline at end of file
--- a/DL/two-step DL/dico_decorr_symetric.m	Tue Mar 20 14:28:51 2012 +0000
+++ b/DL/two-step DL/dico_decorr_symetric.m	Tue Mar 20 14:51:31 2012 +0000
@@ -1,15 +1,17 @@
-function dico = dico_decorr_symetric(dico, mu, amp)
+function dico = dico_decorr_symetric(dico, mu)
     %DICO_DECORR decorrelate a dictionary
     %   Parameters:
-    %   dico: the dictionary
+    %   dico: the dictionary, either a matrix or a cell array of matrices.
     %   mu: the coherence threshold
-    %   amp: the amplitude coefficients, only used to decide which atom to
-    %   project
     %
     %   Result:
-    %   dico: a dictionary close to the input one with coherence mu.
+    %   dico: if the input dico was a matrix, then a matrix close to the 
+    %   input one with coherence mu.
+    %   If the input was a cell array, a cell array of the same size
+    %   containing matrices such that the coherence between different cells
+    %   is lower than mu.
     
-    eps = 1e-6; % define tolerance for normalisation term alpha
+    eps = 1e-3; % define tolerance for normalisation term alpha
     
     % convert mu to the to the mean direction
     theta = acos(mu)/2;
@@ -23,39 +25,115 @@
     %         rank = randperm(length(dico));
     %     end
     
-    % several decorrelation iterations might be needed to reach global
-    % coherence mu. niter can be adjusted to needs.
-    niter = 1;
-    while max(max(abs(dico'*dico -eye(length(dico))))) > mu + 0.01
-        % find pairs of high correlation atoms
-        colors = dico_color(dico, mu);
+    % if only one dictionary is provided, then decorrelate it
+    if ~iscell(dico)
+        % several decorrelation iterations might be needed to reach global
+        % coherence mu. niter can be adjusted to needs.
+        niter = 1;
+        while max(max(abs(dico'*dico -eye(length(dico))))) > mu + eps
+            % find pairs of high correlation atoms
+            colors = dico_color(dico, mu);
+            
+            % iterate on all pairs
+            nbColors = max(colors);
+            for c = 1:nbColors
+                index = find(colors==c);
+                if numel(index) == 2
+                    if dico(:,index(1))'*dico(:,index(2)) > 0
+                        %build the basis vectors
+                        v1 = dico(:,index(1))+dico(:,index(2));
+                        v1 = v1/norm(v1);
+                        v2 = dico(:,index(1))-dico(:,index(2));
+                        v2 = v2/norm(v2);
+                        
+                        dico(:,index(1)) = ctheta*v1+stheta*v2;
+                        dico(:,index(2)) = ctheta*v1-stheta*v2;
+                    else
+                        v1 = dico(:,index(1))-dico(:,index(2));
+                        v1 = v1/norm(v1);
+                        v2 = dico(:,index(1))+dico(:,index(2));
+                        v2 = v2/norm(v2);
+                        
+                        dico(:,index(1)) = ctheta*v1+stheta*v2;
+                        dico(:,index(2)) = -ctheta*v1+stheta*v2;
+                    end
+                end
+            end
+            niter = niter+1;
+        end
+        %if a cell array of dictionaries is provided, decorrelate among
+        %different dictionaries only
+    else
+        niter = 1;
+        numDicos = length(dico);
+        G = cell(numDicos);
+        maxCorr = 0;
+        for i = 1:numDicos
+            for j = i+1:numDicos
+                G{i,j} = dico{i}'*dico{j};
+                maxCorr = max(maxCorr,max(max(abs(G{i,j}))));
+            end
+        end
         
-        % iterate on all pairs
-        nbColors = max(colors);
-        for c = 1:nbColors
-            index = find(colors==c);
-            if numel(index) == 2
-                if dico(:,index(1))'*dico(:,index(2)) > 0               
+        while maxCorr > mu + eps
+            % find pairs of high correlation atoms
+            [colors nbColors] = dico_color_separate(dico, mu);
+
+            % iterate on all pairs
+            for c = 1:nbColors
+                for tmpI = 1:numDicos
+                    index = find(colors{tmpI}==c);
+                    if ~isempty(index)
+                        i = tmpI;
+                        m = index;
+                        break;
+                    end
+                end
+                for tmpJ = i+1:numDicos
+                    index = find(colors{tmpJ}==c);
+                    if ~isempty(index)
+                        j = tmpJ;
+                        n = index;
+                        break;
+                    end
+                end
+                
+                if dico{i}(:,m)'*dico{j}(:,n) > 0
                     %build the basis vectors
-                    v1 = dico(:,index(1))+dico(:,index(2));
+                    v1 = dico{i}(:,m)+dico{j}(:,n);
                     v1 = v1/norm(v1);
-                    v2 = dico(:,index(1))-dico(:,index(2));
+                    v2 = dico{i}(:,m)-dico{j}(:,n);
                     v2 = v2/norm(v2);
                     
-                    dico(:,index(1)) = ctheta*v1+stheta*v2;
-                    dico(:,index(2)) = ctheta*v1-stheta*v2;
+                    dico{i}(:,m) = ctheta*v1+stheta*v2;
+                    dico{j}(:,n) = ctheta*v1-stheta*v2;
                 else
-                    v1 = dico(:,index(1))-dico(:,index(2));
+                    v1 = dico{i}(:,m)-dico{j}(:,n);
                     v1 = v1/norm(v1);
-                    v2 = dico(:,index(1))+dico(:,index(2));
+                    v2 = dico{i}(:,m)+dico{j}(:,n);
                     v2 = v2/norm(v2);
                     
-                    dico(:,index(1)) = ctheta*v1+stheta*v2;
-                    dico(:,index(2)) = -ctheta*v1+stheta*v2;
+                    dico{i}(:,m) = ctheta*v1+stheta*v2;
+                    dico{j}(:,n) = -ctheta*v1+stheta*v2;
+                end
+            end
+            niter = niter+1;
+            
+            % Remove noegative components and renormalize
+            for i = 1:length(dico)
+                dico{i} = max(dico{i},0);
+                for m = 1:size(dico{i},2)
+                    dico{i}(:,m) = dico{i}(:,m)/norm(dico{i}(:,m));
+                end
+            end
+            
+            maxCorr = 0;
+            for i = 1:numDicos
+                for j = i+1:numDicos
+                    G{i,j} = dico{i}'*dico{j};
+                    maxCorr = max(maxCorr,max(max(abs(G{i,j}))));
                 end
             end
         end
-        niter = niter+1;
     end
 end
-
--- a/DL/two-step DL/dico_update.m	Tue Mar 20 14:28:51 2012 +0000
+++ b/DL/two-step DL/dico_update.m	Tue Mar 20 14:51:31 2012 +0000
@@ -1,5 +1,5 @@
 function [dico, amp] = dico_update(dico, sig, amp, type, flow, rho)
-
+    
     %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
     % [dico, amp] = dico_update(dico, sig, amp, type, flow, rho)
     %
@@ -10,44 +10,56 @@
     % - sig: the training data
     % - amp: the amplitude coefficients as a sparse matrix
     % - type: the algorithm can be one of the following
-    %   - ols: fixed step gradient descent
-    %   - mailhe: optimal step gradient descent (can be implemented as a
-    %   default for ols?)
-    %   - MOD: pseudo-inverse of the coefficients
-    %   - KSVD: already implemented by Elad
+    %   - ols: fixed step gradient descent, as described in Olshausen &
+    %   Field95
+    %   - opt: optimal step gradient descent, as described in Mailhe et
+    %   al.08
+    %   - MOD: pseudo-inverse of the coefficients, as described in Engan99
+    %   - KSVD: PCA update as described in Aharon06. For fast applications,
+    %   use KSVDbox rather than this code.
+    %   - LGD: large step gradient descent. Equivalent to 'opt' with
+    %          rho=2.
     % - flow: 'sequential' or 'parallel'. If sequential, the residual is
     % updated after each atom update. If parallel, the residual is only
-    % updated once the whole dictionary has been computed. Sequential works
-    % better, there may be no need to implement parallel. Not used with
+    % updated once the whole dictionary has been computed.
+    % Default: Sequential (sequential usually works better). Not used with
     % MOD.
     % - rho: learning rate. If the type is 'ols', it is the descent step of
-    % the gradient (typical choice: 0.1). If the type is 'mailhe', the 
-    % descent step is the optimal step*rho (typical choice: 1, although 2
-    % or 3 seems to work better). Not used for MOD and KSVD.
+    % the gradient (default: 0.1). If the type is 'opt', the
+    % descent step is the optimal step*rho (default: 1, although 2 works
+    % better. See LGD for more details). Not used for MOD and KSVD.
     %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
-    if ~ exist( 'rho', 'var' ) || isempty(rho)
-        rho = 0.1;
-    end
     
     if ~ exist( 'flow', 'var' ) || isempty(flow)
-        flow = sequential;
+        flow = 'sequential';
     end
     
     res = sig - dico*amp;
     nb_pattern = size(dico, 2);
     
+    % if the type is random, then randomly pick another type
     switch type
         case 'rand'
             x = rand();
             if x < 1/3
                 type = 'MOD';
             elseif type < 2/3
-                type = 'mailhe';
+                type = 'opt';
             else
                 type = 'KSVD';
             end
     end
     
+    % set the learning rate to default if not provided
+    if ~ exist( 'rho', 'var' ) || isempty(rho)
+        switch type
+            case 'ols'
+                rho = 0.1;
+            case 'opt'
+                rho = 1;
+        end
+    end
+    
     switch type
         case 'MOD'
             G = amp*amp';
@@ -72,14 +84,15 @@
                     dico(:,p) = pat;
                 end
             end
-        case 'mailhe'
+        case 'opt'
             for p = 1:nb_pattern
-                grad = res*amp(p,:)';
+                vec : amp(p,:);
+                grad = res*vec';
                 if norm(grad) > 0
-                    pat = (amp(p,:)*amp(p,:)')*dico(:,p) + rho*grad;
+                    pat = (vec*vec')*dico(:,p) + rho*grad;
                     pat = pat/norm(pat);
                     if nargin >5 && strcmp(flow, 'sequential')
-                        res = res + (dico(:,p)-pat)*amp(p,:);
+                        res = res + (dico(:,p)-pat)*vec;
                     end
                     dico(:,p) = pat;
                 end
@@ -89,7 +102,7 @@
                 index = find(amp(p,:)~=0);
                 if ~isempty(index)
                     patch = res(:,index)+dico(:,p)*amp(p,index);
-                    [U,S,V] = svd(patch);
+                    [U,~,V] = svd(patch);
                     if U(:,1)'*dico(:,p) > 0
                         dico(:,p) = U(:,1);
                     else