smallbox: DL/two-step DL/dico_update.m @ 186:9c418bea7f6a (bug_386)
Addresses Bug #386: removed the 4th output variable (versn) in all calls of function fileparts.
author | luisf <luis.figueira@eecs.qmul.ac.uk> |
---|---|
date | Thu, 09 Feb 2012 17:25:14 +0000 |
parents | 485747bf39e0 |
children | 9eb5f0d4c1a4 5140b0e06c22 |
function [dico, amp] = dico_update(dico, sig, amp, type, flow, rho)
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% [dico, amp] = dico_update(dico, sig, amp, type, flow, rho)
%
% perform one iteration of dictionary update for dictionary learning
%
% parameters:
% - dico: the initial dictionary with atoms as columns
% - sig: the training data, one signal per column
% - amp: the amplitude coefficients as a sparse matrix (one row per atom,
%   one column per signal)
% - type: the algorithm can be one of the following
%   - ols: fixed step gradient descent
%   - mailhe: optimal step gradient descent (can be implemented as a
%     default for ols?)
%   - MOD: pseudo-inverse of the coefficients
%   - KSVD: already implemented by Elad
%   - rand: draws the update rule at random
% - flow: 'sequential' or 'parallel'. If sequential, the residual is
%   updated after each atom update. If parallel, the residual is only
%   updated once the whole dictionary has been computed. Sequential works
%   better; there may be no need to implement parallel. Not used with
%   MOD.
% - rho: learning rate (default: 0.1). If the type is 'ols', it is the
%   descent step of the gradient (typical choice: 0.1). If the type is
%   'mailhe', the descent step is the optimal step*rho (typical choice: 1,
%   although 2 or 3 seem to work better). Not used for MOD and KSVD.
%
% outputs: the updated dictionary dico (every updated atom is normalized
% to unit norm) and the coefficients amp (only modified by MOD, which
% rescales each row to match its renormalized atom, and by KSVD, which
% re-estimates the coefficients on each atom's support).
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

if ~exist('rho', 'var') || isempty(rho)
    rho = 0.1;
end
if ~exist('flow', 'var') || isempty(flow)
    flow = 'sequential';
end

res = sig - dico*amp;           % current residual
nb_pattern = size(dico, 2);     % number of atoms

% 'rand': pick MOD, mailhe or KSVD with equal probability
switch type
    case 'rand'
        x = rand();
        if x < 1/3
            type = 'MOD';
        elseif x < 2/3
            type = 'mailhe';
        else
            type = 'KSVD';
        end
end

switch type
    case 'MOD'
        % least-squares dictionary for fixed coefficients: sig*amp'*inv(amp*amp')
        G = amp*amp';
        dico2 = (sig*amp')/G;
        for p = 1:nb_pattern
            n = norm(dico2(:,p));
            % renormalize the atom and rescale its coefficients to compensate
            if n > 0
                dico(:,p) = dico2(:,p)/n;
                amp(p,:) = amp(p,:)*n;
            end
        end
    case 'ols'
        for p = 1:nb_pattern
            grad = res*amp(p,:)';
            if norm(grad) > 0
                pat = dico(:,p) + rho*grad;
                pat = pat/norm(pat);
                if strcmp(flow, 'sequential')
                    % swap the old atom's contribution for the new one
                    res = res + (dico(:,p)-pat)*amp(p,:); %#ok<*NASGU>
                end
                dico(:,p) = pat;
            end
        end
    case 'mailhe'
        for p = 1:nb_pattern
            grad = res*amp(p,:)';
            if norm(grad) > 0
                % same direction as a step of rho/||amp(p,:)||^2 along grad
                pat = (amp(p,:)*amp(p,:)')*dico(:,p) + rho*grad;
                pat = pat/norm(pat);
                if strcmp(flow, 'sequential')
                    res = res + (dico(:,p)-pat)*amp(p,:);
                end
                dico(:,p) = pat;
            end
        end
    case 'KSVD'
        for p = 1:nb_pattern
            index = find(amp(p,:)~=0);      % signals that use atom p
            if ~isempty(index)
                % residual on those signals with atom p's contribution added back
                patch = res(:,index)+dico(:,p)*amp(p,index);
                [U, ~, ~] = svd(patch, 'econ');
                % choose the sign that keeps the atom aligned with the old one
                if U(:,1)'*dico(:,p) > 0
                    dico(:,p) = U(:,1);
                else
                    dico(:,p) = -U(:,1);
                end
                dico(:,p) = dico(:,p)/norm(dico(:,p));
                amp(p,index) = dico(:,p)'*patch;
                if strcmp(flow, 'sequential')
                    res(:,index) = patch-dico(:,p)*amp(p,index);
                end
            end
        end
    otherwise
        error('dico_update:unknownType', 'Unknown update type ''%s''.', type);
end
end
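For reference, the update rules in the listing can be summarized as below, writing $Y$ for `sig`, $D$ for `dico`, $A$ for `amp`, $d_p$ for the $p$-th atom, $a_p$ for the $p$-th row of $A$, $R = Y - DA$ for the residual, $\Omega_p$ for the signals with a nonzero coefficient on atom $p$, and $u_1(E)$ for the leading left singular vector of $E$. For MOD, $A^\top(AA^\top)^{-1}$ is the pseudo-inverse of $A$ when $A$ has full row rank. The final equation is an editorial derivation, not taken from the original file: an exact line search along $g_p$ (assumed nonzero, as the code checks) gives the step $1/\|a_p\|_2^2$, which is one way to read the "optimal step" in the `mailhe` rule, since scaling $d_p$ by $\|a_p\|_2^2$ before normalizing is the same as stepping by $\rho/\|a_p\|_2^2$.

```latex
\[
\begin{aligned}
R &= Y - DA, \qquad g_p = R\,a_p^\top = -\nabla_{d_p} \tfrac{1}{2}\|Y - DA\|_F^2,\\[4pt]
\text{ols:}\quad d_p &\leftarrow \frac{d_p + \rho\, g_p}{\|d_p + \rho\, g_p\|_2},\\[4pt]
\text{mailhe:}\quad d_p &\leftarrow \frac{v_p}{\|v_p\|_2}, \qquad
  v_p = \|a_p\|_2^2\, d_p + \rho\, g_p = \|a_p\|_2^2 \bigl(d_p + t^\star \rho\, g_p\bigr),\\[4pt]
\text{sequential (ols, mailhe):}\quad R &\leftarrow R + \bigl(d_p^{\text{old}} - d_p^{\text{new}}\bigr)\, a_p,\\[4pt]
\text{MOD:}\quad \tilde D &= Y A^\top (A A^\top)^{-1}, \qquad
  d_p = \frac{\tilde d_p}{\|\tilde d_p\|_2}, \qquad a_p \leftarrow \|\tilde d_p\|_2\, a_p,\\[4pt]
\text{KSVD:}\quad E_p &= R_{:,\Omega_p} + d_p\, a_{p,\Omega_p}, \qquad
  d_p \leftarrow s\, u_1(E_p),\ s \in \{\pm 1\},\ s\, u_1^\top d_p^{\text{old}} \ge 0, \qquad
  a_{p,\Omega_p} \leftarrow d_p^\top E_p,\\[4pt]
\text{KSVD, sequential:}\quad R_{:,\Omega_p} &\leftarrow E_p - d_p\, a_{p,\Omega_p}.
\end{aligned}
\]

\[
t^\star = \arg\min_t \|R - t\, g_p a_p\|_F^2
= \frac{\langle R,\ g_p a_p\rangle_F}{\|g_p a_p\|_F^2}
= \frac{g_p^\top R\, a_p^\top}{\|g_p\|_2^2\, \|a_p\|_2^2}
= \frac{\|g_p\|_2^2}{\|g_p\|_2^2\, \|a_p\|_2^2}
= \frac{1}{\|a_p\|_2^2}.
\]
```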
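A minimal usage sketch, assuming `dico_update.m` as listed above is on the MATLAB path. The data sizes, the roughly 30% coefficient density and the per-type `rho` values are arbitrary illustrative choices, not taken from the original. Because MOD is a least-squares fit of the dictionary for fixed `amp`, and each KSVD step is a best rank-one fit on the atom's support, the Frobenius error should not increase for those two types. `ols` and `mailhe` carry no such guarantee for an arbitrary `rho`.

```matlab
% Usage sketch: one update of each type on random data.
% Sizes, sparsity and rho values are illustrative assumptions.
d = 16; K = 8; N = 200;                       % signal dim, atoms, signals
sig  = randn(d, N);                           % training signals as columns
dico = randn(d, K);
dico = bsxfun(@rdivide, dico, sqrt(sum(dico.^2, 1)));   % unit-norm atoms
amp  = sparse(randn(K, N) .* (rand(K, N) < 0.3));        % ~30% nonzeros

err0  = norm(sig - dico*amp, 'fro');          % error before the update
types = {'ols', 'mailhe', 'MOD', 'KSVD'};
rhos  = [0.1, 1, 0.1, 0.1];                   % rho is ignored by MOD and KSVD
errs     = zeros(1, numel(types));
norm_dev = zeros(1, numel(types));
for k = 1:numel(types)
    % pass flow and rho explicitly rather than relying on defaults
    [D, A] = dico_update(dico, sig, amp, types{k}, 'sequential', rhos(k));
    norm_dev(k) = max(abs(sqrt(sum(D.^2, 1)) - 1));      % atom norm check
    errs(k) = norm(sig - D*A, 'fro');
    fprintf('%-6s  max |atom norm - 1| = %.1e   error %.3f -> %.3f\n', ...
            types{k}, norm_dev(k), err0, errs(k));
end
```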
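The two identities the gradient branches rely on can also be checked numerically without calling `dico_update`. This standalone script is an editorial sketch on random data (sizes, the atom index `p` and `rho` are arbitrary assumptions). It confirms that the exact line-search step along `grad` is `1/||amp(p,:)||^2`, so the `mailhe` atom matches a normalized step of `t*rho`, and that the sequential residual update used by `ols` and `mailhe` equals recomputing `sig - dico*amp` from scratch.

```matlab
% Standalone check (does not call dico_update); sizes are illustrative.
d = 8; K = 5; N = 40; p = 2; rho = 0.1;
sig  = randn(d, N);
dico = randn(d, K);
dico = bsxfun(@rdivide, dico, sqrt(sum(dico.^2, 1)));   % unit-norm atoms
amp  = randn(K, N);                                      % dense coefficients

R = sig - dico*amp;                 % residual
a = amp(p, :);                      % p-th coefficient row
g = R*a';                           % same quantity as grad in dico_update

% (1) closed-form minimizer of ||R - t*g*a||_F^2 over the step t
t_star  = (g'*R*a') / (norm(g)^2 * (a*a'));
mailhe  = (a*a')*dico(:,p) + rho*g;   mailhe  = mailhe/norm(mailhe);
optstep = dico(:,p) + t_star*rho*g;   optstep = optstep/norm(optstep);

% (2) 'ols' atom step followed by the sequential residual update
pat   = dico(:,p) + rho*g;  pat = pat/norm(pat);
R_seq = R + (dico(:,p) - pat)*a;
dico_new = dico;  dico_new(:,p) = pat;
R_full = sig - dico_new*amp;        % residual recomputed from scratch

fprintf('t* * ||a||^2          = %.12f\n', t_star*(a*a'));
fprintf('||mailhe - optstep||  = %.1e\n', norm(mailhe - optstep));
fprintf('||R_seq - R_full||_F  = %.1e\n', norm(R_seq - R_full, 'fro'));
```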