smallbox (Mercurial): DL/two-step DL/dico_update.m @ 175:9eb5f0d4c1a4 danieleb
added MOCOD dictionary update
author | Daniele Barchiesi <daniele.barchiesi@eecs.qmul.ac.uk> |
---|---|
date | Thu, 17 Nov 2011 11:17:00 +0000 |
parents | 485747bf39e0 |
children | fd0b5d36f6ad |
function [dico, amp] = dico_update(dico, sig, amp, type, flow, rho, mocodParams)
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% [dico, amp] = dico_update(dico, sig, amp, type, flow, rho, mocodParams)
%
% Perform one iteration of the dictionary update step of dictionary
% learning.
%
% Parameters:
% - dico: the initial dictionary, with atoms as columns
% - sig: the training data, with signals as columns
% - amp: the amplitude coefficients, as a sparse matrix
% - type: the update algorithm (case-insensitive), one of the following:
%   - ols: fixed-step gradient descent
%   - mailhe: optimal-step gradient descent (could be implemented as the
%     default for ols?)
%   - MOD: pseudo-inverse of the coefficients
%   - KSVD: already implemented by Elad
%   - MOCOD: regularised least squares with coherence and atom-norm
%     penalties (requires mocodParams)
%   - rand: one of MOD, mailhe or KSVD, drawn at random with equal
%     probability
% - flow: 'sequential' or 'parallel' (default: 'sequential'). If
%   sequential, the residual is updated after each atom update. If
%   parallel, the residual is only updated once the whole dictionary has
%   been computed. Sequential works better, so there may be no need to
%   implement parallel. Not used with MOD or MOCOD.
% - rho: learning rate (default: 0.1). If the type is 'ols', it is the
%   descent step of the gradient (typical choice: 0.1). If the type is
%   'mailhe', the descent step is the optimal step times rho (typical
%   choice: 1, although 2 or 3 seem to work better). Not used for MOD,
%   KSVD or MOCOD.
% - mocodParams: struct containing the parameters for the MOCOD
%   dictionary update (see Ramirez et al., "Sparse modeling with universal
%   priors and learned incoherent dictionaries"). The required fields are:
%   .Dprev: dictionary at the previous optimisation step
%   .zeta: coherence regularisation factor
%   .eta: atom norm regularisation factor
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

if ~exist('rho', 'var') || isempty(rho)
    rho = 0.1;
end
if ~exist('flow', 'var') || isempty(flow)
    flow = 'sequential';
end

res = sig - dico*amp;          % current residual
nb_pattern = size(dico, 2);    % number of atoms

% 'rand': pick MOD, mailhe or KSVD with equal probability
if strcmpi(type, 'rand')
    x = rand();
    if x < 1/3
        type = 'MOD';
    elseif x < 2/3
        type = 'mailhe';
    else
        type = 'KSVD';
    end
end

switch upper(type)
    case 'MOD'
        % least-squares dictionary for fixed coefficients
        G = amp*amp';
        dico2 = (sig*amp')/G;
        for p = 1:nb_pattern
            n = norm(dico2(:,p));
            % renormalise the atom and move its norm into the coefficients
            if n > 0
                dico(:,p) = dico2(:,p)/n;
                amp(p,:) = amp(p,:)*n;
            end
        end
    case 'OLS'
        for p = 1:nb_pattern
            % negative gradient of 0.5*||sig - dico*amp||^2 w.r.t. atom p
            grad = res*amp(p,:)';
            if norm(grad) > 0
                pat = dico(:,p) + rho*grad;
                pat = pat/norm(pat);
                if strcmp(flow, 'sequential')
                    res = res + (dico(:,p)-pat)*amp(p,:); %#ok<*NASGU>
                end
                dico(:,p) = pat;
            end
        end
    case 'MAILHE'
        for p = 1:nb_pattern
            grad = res*amp(p,:)';
            if norm(grad) > 0
                % same direction as a gradient step of size
                % rho/||amp(p,:)||^2, before normalisation
                pat = (amp(p,:)*amp(p,:)')*dico(:,p) + rho*grad;
                pat = pat/norm(pat);
                if strcmp(flow, 'sequential')
                    res = res + (dico(:,p)-pat)*amp(p,:);
                end
                dico(:,p) = pat;
            end
        end
    case 'KSVD'
        for p = 1:nb_pattern
            index = find(amp(p,:)~=0);   % signals that use atom p
            if ~isempty(index)
                % residual of these signals without the contribution of atom p
                patch = res(:,index) + dico(:,p)*amp(p,index);
                % best rank-1 approximation of the patch
                [U,~,~] = svd(full(patch), 'econ');
                % keep the orientation of the previous atom
                if U(:,1)'*dico(:,p) > 0
                    dico(:,p) = U(:,1);
                else
                    dico(:,p) = -U(:,1);
                end
                dico(:,p) = dico(:,p)/norm(dico(:,p));
                amp(p,index) = dico(:,p)'*patch;
                if strcmp(flow, 'sequential')
                    res(:,index) = patch - dico(:,p)*amp(p,index);
                end
            end
        end
    case 'MOCOD'
        zeta = mocodParams.zeta;
        eta = mocodParams.eta;
        Dprev = mocodParams.Dprev;
        % closed-form update (see Ramirez et al.)
        dico = (sig*amp' + 2*(zeta+eta)*Dprev)/...
            (amp*amp' + 2*zeta*(Dprev'*Dprev) + 2*eta*diag(diag(Dprev'*Dprev)));
end
end
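The MOD branch solves a least-squares problem for the whole dictionary with the coefficients held fixed, then renormalises each atom and moves its norm into the corresponding row of amp. The following is a minimal stand-alone sketch of that step on random data, not part of the original file: the sizes n, K, N and the names X, A, D0, D_ls, A_new are illustrative assumptions, and A is kept full for simplicity, whereas dico_update expects a sparse amp. It checks that the atoms end up with unit norm, that renormalisation leaves the product dico*amp unchanged, and that the fit is no worse than with the initial dictionary.

```matlab
% Hypothetical sizes: signal dimension n, number of atoms K, training signals N
n = 8; K = 12; N = 200;
X  = randn(n, N);                               % training data (plays the role of sig)
A  = full(sprandn(K, N, 0.3));                  % coefficients (amp), kept full for simplicity
D0 = randn(n, K);
D0 = D0 ./ repmat(sqrt(sum(D0.^2, 1)), n, 1);   % unit-norm initial dictionary

% MOD: least-squares dictionary for fixed coefficients, X*A'*(A*A')^-1
D_ls = (X*A') / (A*A');

% Renormalise each atom and move its norm into the matching row of A
D = D_ls;
A_new = A;
for p = 1:K
    s = norm(D_ls(:, p));
    if s > 0
        D(:, p) = D_ls(:, p) / s;
        A_new(p, :) = A(p, :) * s;
    end
end

err_before = norm(X - D0*A, 'fro');
err_after  = norm(X - D*A_new, 'fro');
fprintf('Frobenius residual: %.4f -> %.4f\n', err_before, err_after);
```

Because the MOD dictionary is the exact minimiser over all dictionaries for fixed coefficients, the last check holds whatever the starting dictionary is, not only for this random one.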
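In the ols and mailhe branches, grad = res*amp(p,:)' is the negative gradient of 0.5*||sig - dico*amp||_F^2 with respect to atom p, which is why adding rho*grad is a descent step. A small sketch checking this against central finite differences, assuming random dense data; the objective handle f, the step h, the atom index p = 2 and the sizes are hypothetical choices made only for the check.

```matlab
% Hypothetical sizes and random data for the check (not part of dico_update)
n = 5; K = 4; N = 30;
X = randn(n, N);      % training data (sig)
D = randn(n, K);      % dictionary (dico)
A = randn(K, N);      % coefficients (amp), dense for simplicity
p = 2;                % atom whose gradient is checked

res  = X - D*A;
grad = res * A(p, :)';              % same expression as in dico_update

% Central finite differences of f(D) = 0.5*||X - D*A||_F^2 w.r.t. column p
f  = @(Dm) 0.5 * norm(X - Dm*A, 'fro')^2;
h  = 1e-6;
fd = zeros(n, 1);
for i = 1:n
    Dplus  = D;  Dplus(i, p)  = Dplus(i, p)  + h;
    Dminus = D;  Dminus(i, p) = Dminus(i, p) - h;
    fd(i) = (f(Dplus) - f(Dminus)) / (2*h);
end

% grad should equal minus the numerical gradient
rel_err = norm(grad + fd) / norm(grad);
fprintf('relative error: %.2e\n', rel_err);
```

Since the objective is quadratic in each atom, central differences are exact up to rounding here, so the relative error should be far below the tolerance used.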
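With rho = 1, the mailhe step (amp(p,:)*amp(p,:)')*dico(:,p) + grad equals (res + dico(:,p)*amp(p,:))*amp(p,:)', which is, up to scale, the least-squares atom for the residual that excludes atom p. This is one way to read the "optimal step" in the header comment. The sketch below, on random data with illustrative sizes and names (E, target and max_dev are not in the original), checks that identity atom by atom and also checks that the sequential residual update keeps res equal to sig - dico*amp.

```matlab
% Hypothetical sizes and random data (not part of dico_update)
n = 6; K = 10; N = 150;
X = randn(n, N);                            % training data (sig)
A = full(sprandn(K, N, 0.3));               % coefficients (amp), kept full
D = randn(n, K);
D = D ./ repmat(sqrt(sum(D.^2, 1)), n, 1);  % unit-norm dictionary (dico)
rho = 1;

res = X - D*A;
max_dev = 0;
for p = 1:K
    a = A(p, :);
    grad = res * a';
    if norm(grad) > 0
        pat = (a*a') * D(:, p) + rho*grad;  % mailhe step
        pat = pat / norm(pat);
        E = res + D(:, p)*a;                % residual without atom p
        target = E*a' / norm(E*a');         % normalised least-squares atom
        max_dev = max(max_dev, norm(pat - target));
        res = res + (D(:, p) - pat)*a;      % sequential residual update
        D(:, p) = pat;
    end
end
fprintf('max deviation from the least-squares atom: %.2e\n', max_dev);
```

With rho other than 1 the identity no longer holds, and the step interpolates between the previous atom and this least-squares atom (rho < 1) or extrapolates past it (rho > 1).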
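The KSVD branch updates each atom together with its nonzero coefficients by taking the best rank-1 approximation of the error restricted to the signals that use that atom. Since the previous atom and coefficients are themselves a rank-1 candidate, each update cannot increase the total error, and coefficients outside the original support stay zero. A minimal sketch reproducing the branch with sequential flow on random data; the sizes and the names support, err_before and err_after are illustrative assumptions.

```matlab
% Hypothetical sizes and random data (not part of dico_update)
n = 6; K = 10; N = 150;
X = randn(n, N);                            % training data (sig)
A = full(sprandn(K, N, 0.3));               % coefficients (amp), kept full
D = randn(n, K);
D = D ./ repmat(sqrt(sum(D.^2, 1)), n, 1);  % unit-norm dictionary (dico)

res = X - D*A;
err_before = norm(res, 'fro');
support = (A ~= 0);                         % sparsity pattern before the update

for p = 1:K
    index = find(A(p, :) ~= 0);             % signals that use atom p
    if ~isempty(index)
        patch = res(:, index) + D(:, p)*A(p, index);  % error without atom p
        [U, ~, ~] = svd(patch, 'econ');               % leading left singular vector
        d = U(:, 1);
        if d' * D(:, p) < 0
            d = -d;                                   % keep orientation of the old atom
        end
        D(:, p) = d;
        A(p, index) = d' * patch;                     % optimal coefficients on the support
        res(:, index) = patch - d*A(p, index);        % sequential residual update
    end
end

err_after = norm(X - D*A, 'fro');
fprintf('Frobenius residual: %.4f -> %.4f\n', err_before, err_after);
```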
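The MOCOD branch replaces the whole dictionary with a closed-form expression in sig, amp and mocodParams.Dprev. Two sanity checks follow directly from that expression: with zeta = eta = 0 it reduces to the unnormalised MOD solution sig*amp'*(amp*amp')^-1, and with zeta = 0, a unit-norm Dprev and a very large eta it stays close to Dprev. The sketch below checks both on random data; the anonymous function mocod, the value eta = 1e8 and the sizes are illustrative assumptions, not part of the original file.

```matlab
% Hypothetical sizes and random data (not part of dico_update)
n = 6; K = 10; N = 150;
X = randn(n, N);                                          % training data (sig)
A = full(sprandn(K, N, 0.3));                             % coefficients (amp)
Dprev = randn(n, K);
Dprev = Dprev ./ repmat(sqrt(sum(Dprev.^2, 1)), n, 1);    % unit-norm previous dictionary

% Same expression as the MOCOD branch, as a function of zeta and eta
mocod = @(zeta, eta) (X*A' + 2*(zeta + eta)*Dprev) / ...
    (A*A' + 2*zeta*(Dprev'*Dprev) + 2*eta*diag(diag(Dprev'*Dprev)));

D_mod   = (X*A') / (A*A');      % unnormalised MOD solution
D_plain = mocod(0, 0);          % no regularisation
D_tight = mocod(0, 1e8);        % very strong atom-norm regularisation

fprintf('||mocod(0,0) - MOD||     = %.2e\n', norm(D_plain - D_mod, 'fro'));
fprintf('||mocod(0,1e8) - Dprev|| = %.2e\n', norm(D_tight - Dprev, 'fro'));
```

Note that, unlike the MOD branch, the MOCOD branch does not renormalise the atoms or rescale amp afterwards.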