changeset 175:9eb5f0d4c1a4 danieleb

added MOCOD dictionary update
author Daniele Barchiesi <daniele.barchiesi@eecs.qmul.ac.uk>
date Thu, 17 Nov 2011 11:17:00 +0000
parents dc2f0fa21310
children d0645d5fca7d
files DL/two-step DL/dico_update.m
diffstat 1 files changed, 117 insertions(+), 104 deletions(-) [+]
line wrap: on
line diff
--- a/DL/two-step DL/dico_update.m	Thu Nov 17 11:16:15 2011 +0000
+++ b/DL/two-step DL/dico_update.m	Thu Nov 17 11:17:00 2011 +0000
@@ -1,107 +1,120 @@
-function [dico, amp] = dico_update(dico, sig, amp, type, flow, rho)
+function [dico, amp] = dico_update(dico, sig, amp, type, flow, rho,mocodParams)
 
-    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
-    % [dico, amp] = dico_update(dico, sig, amp, type, flow, rho)
-    %
-    % perform one iteration of dictionary update for dictionary learning
-    %
-    % parameters:
-    % - dico: the initial dictionary with atoms as columns
-    % - sig: the training data
-    % - amp: the amplitude coefficients as a sparse matrix
-    % - type: the algorithm can be one of the following
-    %   - ols: fixed step gradient descent
-    %   - mailhe: optimal step gradient descent (can be implemented as a
-    %   default for ols?)
-    %   - MOD: pseudo-inverse of the coefficients
-    %   - KSVD: already implemented by Elad
-    % - flow: 'sequential' or 'parallel'. If sequential, the residual is
-    % updated after each atom update. If parallel, the residual is only
-    % updated once the whole dictionary has been computed. Sequential works
-    % better, there may be no need to implement parallel. Not used with
-    % MOD.
-    % - rho: learning rate. If the type is 'ols', it is the descent step of
-    % the gradient (typical choice: 0.1). If the type is 'mailhe', the 
-    % descent step is the optimal step*rho (typical choice: 1, although 2
-    % or 3 seems to work better). Not used for MOD and KSVD.
-    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
-    if ~ exist( 'rho', 'var' ) || isempty(rho)
-        rho = 0.1;
-    end
-    
-    if ~ exist( 'flow', 'var' ) || isempty(flow)
-        flow = sequential;
-    end
-    
-    res = sig - dico*amp;
-    nb_pattern = size(dico, 2);
-    
-    switch type
-        case 'rand'
-            x = rand();
-            if x < 1/3
-                type = 'MOD';
-            elseif type < 2/3
-                type = 'mailhe';
-            else
-                type = 'KSVD';
-            end
-    end
-    
-    switch type
-        case 'MOD'
-            G = amp*amp';
-            dico2 = sig*amp'*G^-1;
-            for p = 1:nb_pattern
-                n = norm(dico2(:,p));
-                % renormalize
-                if n > 0
-                    dico(:,p) = dico2(:,p)/n;
-                    amp(p,:) = amp(p,:)*n;
-                end
-            end
-        case 'ols'
-            for p = 1:nb_pattern
-                grad = res*amp(p,:)';
-                if norm(grad) > 0
-                    pat = dico(:,p) + rho*grad;
-                    pat = pat/norm(pat);
-                    if nargin >5 && strcmp(flow, 'sequential')
-                        res = res + (dico(:,p)-pat)*amp(p,:); %#ok<*NASGU>
-                    end
-                    dico(:,p) = pat;
-                end
-            end
-        case 'mailhe'
-            for p = 1:nb_pattern
-                grad = res*amp(p,:)';
-                if norm(grad) > 0
-                    pat = (amp(p,:)*amp(p,:)')*dico(:,p) + rho*grad;
-                    pat = pat/norm(pat);
-                    if nargin >5 && strcmp(flow, 'sequential')
-                        res = res + (dico(:,p)-pat)*amp(p,:);
-                    end
-                    dico(:,p) = pat;
-                end
-            end
-        case 'KSVD'
-            for p = 1:nb_pattern
-                index = find(amp(p,:)~=0);
-                if ~isempty(index)
-                    patch = res(:,index)+dico(:,p)*amp(p,index);
-                    [U,S,V] = svd(patch);
-                    if U(:,1)'*dico(:,p) > 0
-                        dico(:,p) = U(:,1);
-                    else
-                        dico(:,p) = -U(:,1);
-                    end
-                    dico(:,p) = dico(:,p)/norm(dico(:,p));
-                    amp(p,index) = dico(:,p)'*patch;
-                    if nargin >5 && strcmp(flow, 'sequential')
-                        res(:,index) = patch-dico(:,p)*amp(p,index);
-                    end
-                end
-            end
-    end
+%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
+% [dico, amp] = dico_update(dico, sig, amp, type, flow, rho)
+%
+% perform one iteration of dictionary update for dictionary learning
+%
+% parameters:
+% - dico: the initial dictionary with atoms as columns
+% - sig: the training data
+% - amp: the amplitude coefficients as a sparse matrix
+% - type: the algorithm can be one of the following
+%   - ols: fixed step gradient descent
+%   - mailhe: optimal step gradient descent (can be implemented as a
+%   default for ols?)
+%   - MOD: pseudo-inverse of the coefficients
+%   - KSVD: already implemented by Elad
+% - flow: 'sequential' or 'parallel'. If sequential, the residual is
+% updated after each atom update. If parallel, the residual is only
+% updated once the whole dictionary has been computed. Sequential works
+% better, there may be no need to implement parallel. Not used with
+% MOD.
+% - rho: learning rate. If the type is 'ols', it is the descent step of
+% the gradient (typical choice: 0.1). If the type is 'mailhe', the
+% descent step is the optimal step*rho (typical choice: 1, although 2
+% or 3 seems to work better). Not used for MOD and KSVD.
+% - mocodParams: struct containing the parameters for the MOCOD dictionary
+% update (see Ramirez et Al., Sparse modeling with universal priors and
+% learned incoherent dictionaries). The required fields are
+%	.Dprev: dictionary at previous optimisation step
+%	.zeta: coherence regularization factor
+%	.eta: atoms norm regularisation factor
+%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
+if ~ exist( 'rho', 'var' ) || isempty(rho)
+	rho = 0.1;
 end
 
+if ~ exist( 'flow', 'var' ) || isempty(flow)
+	flow = sequential;
+end
+
+res = sig - dico*amp;
+nb_pattern = size(dico, 2);
+
+switch type
+	case 'rand'
+		x = rand();
+		if x < 1/3
+			type = 'MOD';
+		elseif type < 2/3
+			type = 'mailhe';
+		else
+			type = 'KSVD';
+		end
+end
+
+switch upper(type)
+	case 'MOD'
+		G = amp*amp';
+		dico2 = sig*amp'*G^-1;
+		for p = 1:nb_pattern
+			n = norm(dico2(:,p));
+			% renormalize
+			if n > 0
+				dico(:,p) = dico2(:,p)/n;
+				amp(p,:) = amp(p,:)*n;
+			end
+		end
+	case 'OLS'
+		for p = 1:nb_pattern
+			grad = res*amp(p,:)';
+			if norm(grad) > 0
+				pat = dico(:,p) + rho*grad;
+				pat = pat/norm(pat);
+				if nargin >5 && strcmp(flow, 'sequential')
+					res = res + (dico(:,p)-pat)*amp(p,:); %#ok<*NASGU>
+				end
+				dico(:,p) = pat;
+			end
+		end
+	case 'MAILHE'
+		for p = 1:nb_pattern
+			grad = res*amp(p,:)';
+			if norm(grad) > 0
+				pat = (amp(p,:)*amp(p,:)')*dico(:,p) + rho*grad;
+				pat = pat/norm(pat);
+				if nargin >5 && strcmp(flow, 'sequential')
+					res = res + (dico(:,p)-pat)*amp(p,:);
+				end
+				dico(:,p) = pat;
+			end
+		end
+	case 'KSVD'
+		for p = 1:nb_pattern
+			index = find(amp(p,:)~=0);
+			if ~isempty(index)
+				patch = res(:,index)+dico(:,p)*amp(p,index);
+				[U,~,V] = svd(patch);
+				if U(:,1)'*dico(:,p) > 0
+					dico(:,p) = U(:,1);
+				else
+					dico(:,p) = -U(:,1);
+				end
+				dico(:,p) = dico(:,p)/norm(dico(:,p));
+				amp(p,index) = dico(:,p)'*patch;
+				if nargin >5 && strcmp(flow, 'sequential')
+					res(:,index) = patch-dico(:,p)*amp(p,index);
+				end
+			end
+		end
+	case 'MOCOD'
+		zeta = mocodParams.zeta;
+		eta = mocodParams.eta;
+		Dprev = mocodParams.Dprev;
+		
+		dico = (sig*amp' + 2*(zeta+eta)*Dprev)/...
+			(amp*amp' + 2*zeta*(Dprev'*Dprev) + 2*eta*diag(diag(Dprev'*Dprev)));
+end
+end
+