view DL/two-step DL/dico_update.m @ 152:485747bf39e0 ivand_dev

Two-step dictionary learning - integration of the code for dictionary update and dictionary decorrelation from Boris Mailhe
author Ivan Damnjanovic lnx <ivan.damnjanovic@eecs.qmul.ac.uk>
date Thu, 28 Jul 2011 15:49:32 +0100
parents
children 9eb5f0d4c1a4 5140b0e06c22
line wrap: on
line source
function [dico, amp] = dico_update(dico, sig, amp, type, flow, rho)

    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    % [dico, amp] = dico_update(dico, sig, amp, type, flow, rho)
    %
    % Perform one iteration of dictionary update for dictionary learning,
    % i.e. improve dico (and possibly amp) so that dico*amp better
    % approximates sig.
    %
    % parameters:
    % - dico: the initial dictionary with atoms as columns
    % - sig: the training data (approximated by dico*amp)
    % - amp: the amplitude coefficients as a sparse matrix, one row per atom
    % - type: the algorithm, one of
    %   - 'ols': fixed-step gradient descent
    %   - 'mailhe': optimal-step gradient descent
    %   - 'MOD': pseudo-inverse of the coefficients
    %   - 'KSVD': rank-1 SVD update of each atom and of its coefficients
    %   - 'rand': one of 'MOD', 'mailhe' or 'KSVD', each with probability
    %     1/3
    %   Any other value leaves dico and amp unchanged.
    % - flow: 'sequential' (default) or 'parallel'. If sequential, the
    %   residual is updated after each atom update. If parallel, the
    %   residual is only computed once, before the whole dictionary is
    %   updated. Sequential works better. Not used with MOD.
    % - rho: learning rate (default 0.1). If the type is 'ols', it is the
    %   descent step of the gradient (typical choice: 0.1). If the type is
    %   'mailhe', the descent step is the optimal step*rho (typical choice:
    %   1, although 2 or 3 seems to work better). Not used for MOD and KSVD.
    %
    % outputs:
    % - dico: the updated dictionary; every atom that gets updated is
    %   normalized to unit norm
    % - amp: the coefficients. MOD rescales them to compensate for the atom
    %   renormalization, KSVD re-estimates them on their support, and
    %   'ols'/'mailhe' leave them unchanged.
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    if ~ exist( 'rho', 'var' ) || isempty(rho)
        rho = 0.1;
    end

    if ~ exist( 'flow', 'var' ) || isempty(flow)
        % Must be the string 'sequential' (a bare identifier is undefined).
        flow = 'sequential';
    end
    % Decided only from the (possibly defaulted) flow argument, so that the
    % sequential flow is also used when rho is omitted.
    is_sequential = strcmp(flow, 'sequential');

    res = sig - dico*amp;
    nb_pattern = size(dico, 2);

    switch type
        case 'rand'
            % Compare the random draw (not the type string) to the
            % thresholds so that each algorithm gets probability 1/3.
            x = rand();
            if x < 1/3
                type = 'MOD';
            elseif x < 2/3
                type = 'mailhe';
            else
                type = 'KSVD';
            end
    end

    switch type
        case 'MOD'
            % Least-squares dictionary sig*amp'*(amp*amp')^-1, computed with
            % right division instead of an explicit inverse.
            G = amp*amp';
            dico2 = (sig*amp')/G;
            for p = 1:nb_pattern
                n = norm(dico2(:,p));
                % Renormalize the atom and rescale its coefficients so that
                % the product dico*amp is unchanged.
                if n > 0
                    dico(:,p) = dico2(:,p)/n;
                    amp(p,:) = amp(p,:)*n;
                end
            end
        case 'ols'
            for p = 1:nb_pattern
                grad = res*amp(p,:)';
                if norm(grad) > 0
                    pat = dico(:,p) + rho*grad;
                    pat = pat/norm(pat);
                    if is_sequential
                        % Account for the atom change in the residual.
                        res = res + (dico(:,p)-pat)*amp(p,:);
                    end
                    dico(:,p) = pat;
                end
            end
        case 'mailhe'
            for p = 1:nb_pattern
                grad = res*amp(p,:)';
                if norm(grad) > 0
                    % Up to normalization this is
                    % dico(:,p) + (rho/||amp(p,:)||^2)*grad, i.e. a step of
                    % rho times 1/||amp(p,:)||^2.
                    pat = (amp(p,:)*amp(p,:)')*dico(:,p) + rho*grad;
                    pat = pat/norm(pat);
                    if is_sequential
                        % Account for the atom change in the residual.
                        res = res + (dico(:,p)-pat)*amp(p,:);
                    end
                    dico(:,p) = pat;
                end
            end
        case 'KSVD'
            for p = 1:nb_pattern
                index = find(amp(p,:)~=0);
                if ~isempty(index)
                    % Residual on the signals that use atom p, with the
                    % contribution of atom p added back.
                    patch = res(:,index)+dico(:,p)*amp(p,index);
                    % Only the first left singular vector is needed; the
                    % economy-size SVD avoids building a
                    % numel(index)-by-numel(index) matrix V.
                    [U, ~, ~] = svd(full(patch), 'econ');
                    % Keep the sign of the new atom consistent with the old.
                    if U(:,1)'*dico(:,p) > 0
                        dico(:,p) = U(:,1);
                    else
                        dico(:,p) = -U(:,1);
                    end
                    dico(:,p) = dico(:,p)/norm(dico(:,p));
                    % Least-squares coefficients of the new atom on its
                    % support.
                    amp(p,index) = dico(:,p)'*patch;
                    if is_sequential
                        res(:,index) = patch-dico(:,p)*amp(p,index);
                    end
                end
            end
    end
end