view DL/two-step DL/dico_update.m @ 175:9eb5f0d4c1a4 danieleb

added MOCOD dictionary update
author Daniele Barchiesi <daniele.barchiesi@eecs.qmul.ac.uk>
date Thu, 17 Nov 2011 11:17:00 +0000
parents 485747bf39e0
children fd0b5d36f6ad
line wrap: on
line source
function [dico, amp] = dico_update(dico, sig, amp, type, flow, rho,mocodParams)

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% [dico, amp] = dico_update(dico, sig, amp, type, flow, rho, mocodParams)
%
% perform one iteration of dictionary update for dictionary learning
%
% parameters:
% - dico: the initial dictionary with atoms as columns
% - sig: the training data
% - amp: the amplitude coefficients as a sparse matrix
% - type: the algorithm (case-insensitive) can be one of the following
%   - ols: fixed step gradient descent
%   - mailhe: optimal step gradient descent (can be implemented as a
%   default for ols?)
%   - MOD: pseudo-inverse of the coefficients
%   - KSVD: rank-one SVD update of each atom and of its coefficients
%   - MOCOD: regularised least-squares update (see mocodParams)
%   - rand (lower case only): picks MOD, mailhe or KSVD with equal
%   probability
% - flow: 'sequential' or 'parallel'. If sequential, the residual is
% updated after each atom update. If parallel, the residual is only
% updated once the whole dictionary has been computed. Sequential works
% better, there may be no need to implement parallel. Not used with
% MOD and MOCOD. Defaults to 'sequential' when omitted or empty.
% - rho: learning rate. If the type is 'ols', it is the descent step of
% the gradient (typical choice: 0.1). If the type is 'mailhe', the
% descent step is the optimal step*rho (typical choice: 1, although 2
% or 3 seems to work better). Not used for MOD and KSVD. Defaults to
% 0.1 when omitted or empty.
% - mocodParams: struct containing the parameters for the MOCOD dictionary
% update (see Ramirez et Al., Sparse modeling with universal priors and
% learned incoherent dictionaries). The required fields are
%	.Dprev: dictionary at previous optimisation step
%	.zeta: coherence regularization factor
%	.eta: atoms norm regularisation factor
%
% returns:
% - dico: the updated dictionary. Atoms are renormalised to unit norm by
% every update type except MOCOD.
% - amp: the coefficients; rescaled by MOD, recomputed on their support
% by KSVD, returned unchanged by the other types.
%
% An unrecognised type leaves dico and amp unchanged.
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
if ~ exist( 'rho', 'var' ) || isempty(rho)
	rho = 0.1;
end

if ~ exist( 'flow', 'var' ) || isempty(flow)
	% was the bare identifier `sequential`, which is undefined and made
	% every call relying on the default fail
	flow = 'sequential';
end

% The residual is refreshed after each atom only when at least six
% arguments are passed AND flow is 'sequential'.
% NOTE(review): the nargin test means that a call without rho behaves as
% 'parallel' even though flow defaults to 'sequential'. Kept as is to
% preserve existing behaviour - confirm whether this is intended.
update_res = nargin > 5 && strcmp(flow, 'sequential'); %#ok<*NASGU>

res = sig - dico*amp;
nb_pattern = size(dico, 2);

% Resolve the 'rand' pseudo-type into a concrete algorithm.
if strcmp(type, 'rand')
	x = rand();
	if x < 1/3
		type = 'MOD';
	elseif x < 2/3	% was `type < 2/3`, which never selected 'mailhe'
		type = 'mailhe';
	else
		type = 'KSVD';
	end
end

switch upper(type)
	case 'MOD'
		G = amp*amp';
		% solve the linear system rather than forming G^-1 explicitly
		dico2 = (sig*amp')/G;
		for p = 1:nb_pattern
			n = norm(dico2(:,p));
			% renormalize the atom and move its scale to the coefficients
			if n > 0
				dico(:,p) = dico2(:,p)/n;
				amp(p,:) = amp(p,:)*n;
			end
		end
	case 'OLS'
		for p = 1:nb_pattern
			grad = res*amp(p,:)';
			if norm(grad) > 0
				% fixed-step gradient move followed by renormalisation
				pat = dico(:,p) + rho*grad;
				pat = pat/norm(pat);
				if update_res
					res = res + (dico(:,p)-pat)*amp(p,:);
				end
				dico(:,p) = pat;
			end
		end
	case 'MAILHE'
		for p = 1:nb_pattern
			grad = res*amp(p,:)';
			if norm(grad) > 0
				% the current atom is weighted by the energy of its
				% coefficients before the gradient is added
				pat = (amp(p,:)*amp(p,:)')*dico(:,p) + rho*grad;
				pat = pat/norm(pat);
				if update_res
					res = res + (dico(:,p)-pat)*amp(p,:);
				end
				dico(:,p) = pat;
			end
		end
	case 'KSVD'
		for p = 1:nb_pattern
			index = find(amp(p,:)~=0);
			if ~isempty(index)
				% residual restricted to the signals using atom p, with
				% the contribution of atom p added back
				patch = res(:,index)+dico(:,p)*amp(p,index);
				% only the leading left singular vector is needed, so the
				% economy-size decomposition is sufficient
				[U,~,~] = svd(patch, 'econ');
				% choose the sign that keeps the new atom aligned with
				% the old one
				if U(:,1)'*dico(:,p) > 0
					dico(:,p) = U(:,1);
				else
					dico(:,p) = -U(:,1);
				end
				dico(:,p) = dico(:,p)/norm(dico(:,p));
				amp(p,index) = dico(:,p)'*patch;
				if update_res
					res(:,index) = patch-dico(:,p)*amp(p,index);
				end
			end
		end
	case 'MOCOD'
		if ~ exist( 'mocodParams', 'var' ) || isempty(mocodParams)
			error('dico_update:missingMocodParams', ...
				'The MOCOD update requires the mocodParams argument.');
		end
		zeta = mocodParams.zeta;
		eta = mocodParams.eta;
		Dprev = mocodParams.Dprev;

		% Gram matrix of the previous dictionary, used twice below
		gramPrev = Dprev'*Dprev;
		dico = (sig*amp' + 2*(zeta+eta)*Dprev)/...
			(amp*amp' + 2*zeta*gramPrev + 2*eta*diag(diag(gramPrev)));
end
end