Mercurial > hg > smallbox
diff DL/two-step DL/dico_update.m @ 152:485747bf39e0 ivand_dev
Two step dictonary learning - Integration of the code for dictionary update and dictionary decorrelation from Boris Mailhe
author | Ivan Damnjanovic lnx <ivan.damnjanovic@eecs.qmul.ac.uk> |
---|---|
date | Thu, 28 Jul 2011 15:49:32 +0100 |
parents | |
children | 9eb5f0d4c1a4 5140b0e06c22 |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/DL/two-step DL/dico_update.m Thu Jul 28 15:49:32 2011 +0100 @@ -0,0 +1,107 @@ +function [dico, amp] = dico_update(dico, sig, amp, type, flow, rho) + + %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% + % [dico, amp] = dico_update(dico, sig, amp, type, flow, rho) + % + % perform one iteration of dictionary update for dictionary learning + % + % parameters: + % - dico: the initial dictionary with atoms as columns + % - sig: the training data + % - amp: the amplitude coefficients as a sparse matrix + % - type: the algorithm can be one of the following + % - ols: fixed step gradient descent + % - mailhe: optimal step gradient descent (can be implemented as a + % default for ols?) + % - MOD: pseudo-inverse of the coefficients + % - KSVD: already implemented by Elad + % - flow: 'sequential' or 'parallel'. If sequential, the residual is + % updated after each atom update. If parallel, the residual is only + % updated once the whole dictionary has been computed. Sequential works + % better, there may be no need to implement parallel. Not used with + % MOD. + % - rho: learning rate. If the type is 'ols', it is the descent step of + % the gradient (typical choice: 0.1). If the type is 'mailhe', the + % descent step is the optimal step*rho (typical choice: 1, although 2 + % or 3 seems to work better). Not used for MOD and KSVD. + %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% + if ~ exist( 'rho', 'var' ) || isempty(rho) + rho = 0.1; + end + + if ~ exist( 'flow', 'var' ) || isempty(flow) + flow = sequential; + end + + res = sig - dico*amp; + nb_pattern = size(dico, 2); + + switch type + case 'rand' + x = rand(); + if x < 1/3 + type = 'MOD'; + elseif type < 2/3 + type = 'mailhe'; + else + type = 'KSVD'; + end + end + + switch type + case 'MOD' + G = amp*amp'; + dico2 = sig*amp'*G^-1; + for p = 1:nb_pattern + n = norm(dico2(:,p)); + % renormalize + if n > 0 + dico(:,p) = dico2(:,p)/n; + amp(p,:) = amp(p,:)*n; + end + end + case 'ols' + for p = 1:nb_pattern + grad = res*amp(p,:)'; + if norm(grad) > 0 + pat = dico(:,p) + rho*grad; + pat = pat/norm(pat); + if nargin >5 && strcmp(flow, 'sequential') + res = res + (dico(:,p)-pat)*amp(p,:); %#ok<*NASGU> + end + dico(:,p) = pat; + end + end + case 'mailhe' + for p = 1:nb_pattern + grad = res*amp(p,:)'; + if norm(grad) > 0 + pat = (amp(p,:)*amp(p,:)')*dico(:,p) + rho*grad; + pat = pat/norm(pat); + if nargin >5 && strcmp(flow, 'sequential') + res = res + (dico(:,p)-pat)*amp(p,:); + end + dico(:,p) = pat; + end + end + case 'KSVD' + for p = 1:nb_pattern + index = find(amp(p,:)~=0); + if ~isempty(index) + patch = res(:,index)+dico(:,p)*amp(p,index); + [U,S,V] = svd(patch); + if U(:,1)'*dico(:,p) > 0 + dico(:,p) = U(:,1); + else + dico(:,p) = -U(:,1); + end + dico(:,p) = dico(:,p)/norm(dico(:,p)); + amp(p,index) = dico(:,p)'*patch; + if nargin >5 && strcmp(flow, 'sequential') + res(:,index) = patch-dico(:,p)*amp(p,index); + end + end + end + end +end +