annotate DL/two-step DL/dico_update.m @ 152:485747bf39e0 ivand_dev

Two-step dictionary learning: integration of Boris Mailhe's code for dictionary update and dictionary decorrelation
author Ivan Damnjanovic lnx <ivan.damnjanovic@eecs.qmul.ac.uk>
date Thu, 28 Jul 2011 15:49:32 +0100
parents
children 9eb5f0d4c1a4 5140b0e06c22
rev   line source
function [dico, amp] = dico_update(dico, sig, amp, type, flow, rho)

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% [dico, amp] = dico_update(dico, sig, amp, type, flow, rho)
%
% Perform one iteration of dictionary update for dictionary learning,
% i.e. improve dico (and possibly amp) so that dico*amp better fits sig.
%
% parameters:
% - dico: the initial dictionary with atoms as columns
% - sig: the training data, one signal per column
% - amp: the amplitude coefficients as a sparse matrix, one row per atom
%   and one column per signal (sig is approximated by dico*amp)
% - type: the algorithm, one of the following
%   - 'ols': fixed step gradient descent
%   - 'mailhe': optimal step gradient descent
%   - 'MOD': least-squares fit of the dictionary to the coefficients
%     (solves dico*(amp*amp') = sig*amp'; atoms whose coefficient row
%     is entirely zero are left unchanged)
%   - 'KSVD': rank-1 SVD update of each atom and of its coefficients
%   - 'rand': pick 'MOD', 'mailhe' or 'KSVD' with equal probability
%   Any other value raises an error.
% - flow: 'sequential' (default) or 'parallel'. If sequential, the
%   residual is updated after each atom update. If parallel, the residual
%   is only computed once, before the whole dictionary is updated.
%   Sequential works better. Not used with MOD.
% - rho: learning rate (default 0.1). If the type is 'ols', it is the
%   descent step of the gradient (typical choice: 0.1). If the type is
%   'mailhe', the descent step is the optimal step*rho (typical choice: 1,
%   although 2 or 3 seems to work better). Not used for MOD and KSVD.
%
% returns:
% - dico: the updated dictionary; every updated atom has unit l2 norm
% - amp: the coefficients. MOD rescales each row to compensate for the
%   atom normalisation, KSVD recomputes the nonzero entries of each row,
%   and ols/mailhe return them unchanged.
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
if ~exist('rho', 'var') || isempty(rho)
    rho = 0.1;
end

if ~exist('flow', 'var') || isempty(flow)
    flow = 'sequential';
end
% Honour the flow argument whenever it is given (or defaulted), instead
% of only when rho is also passed.
is_sequential = strcmp(flow, 'sequential');

res = sig - dico*amp;
nb_pattern = size(dico, 2);

switch type
    case 'rand'
        % Draw once and split [0,1) into three equal ranges.
        x = rand();
        if x < 1/3
            type = 'MOD';
        elseif x < 2/3
            type = 'mailhe';
        else
            type = 'KSVD';
        end
end

switch type
    case 'MOD'
        % An atom whose coefficient row is all zero makes amp*amp'
        % singular, so solve only for the atoms that are actually used.
        used = full(any(amp ~= 0, 2));
        amp_used = amp(used, :);
        dico2 = zeros(size(dico));
        % mrdivide solves X*G = B without forming inv(G) explicitly.
        dico2(:, used) = (sig*amp_used') / (amp_used*amp_used');
        for p = 1:nb_pattern
            n = norm(dico2(:,p));
            % Renormalise the atom and move its scale into the
            % coefficients so that dico*amp is unchanged. Unused atoms
            % (n == 0) keep their previous value.
            if n > 0
                dico(:,p) = dico2(:,p)/n;
                amp(p,:) = amp(p,:)*n;
            end
        end
    case 'ols'
        for p = 1:nb_pattern
            grad = res*amp(p,:)';
            if norm(grad) > 0
                pat = dico(:,p) + rho*grad;
                pat = pat/norm(pat);
                if is_sequential
                    % Replace the old atom's contribution by the new one.
                    res = res + (dico(:,p)-pat)*amp(p,:); %#ok<*NASGU>
                end
                dico(:,p) = pat;
            end
        end
    case 'mailhe'
        for p = 1:nb_pattern
            grad = res*amp(p,:)';
            if norm(grad) > 0
                % Scaling the current atom by ||amp(p,:)||^2 is equivalent,
                % up to the final normalisation, to a step of
                % rho/||amp(p,:)||^2 along the gradient.
                pat = (amp(p,:)*amp(p,:)')*dico(:,p) + rho*grad;
                pat = pat/norm(pat);
                if is_sequential
                    res = res + (dico(:,p)-pat)*amp(p,:);
                end
                dico(:,p) = pat;
            end
        end
    case 'KSVD'
        for p = 1:nb_pattern
            index = find(amp(p,:) ~= 0);
            if ~isempty(index)
                % Residual of the signals that use atom p, with atom p's
                % own contribution added back.
                patch = res(:,index) + dico(:,p)*amp(p,index);
                % Only the leading left singular vector is needed; the
                % economy-size SVD avoids building a full V.
                [U, ~, ~] = svd(patch, 'econ');
                % Keep the sign consistent with the previous atom.
                if U(:,1)'*dico(:,p) > 0
                    dico(:,p) = U(:,1);
                else
                    dico(:,p) = -U(:,1);
                end
                dico(:,p) = dico(:,p)/norm(dico(:,p));
                amp(p,index) = dico(:,p)'*patch;
                if is_sequential
                    res(:,index) = patch - dico(:,p)*amp(p,index);
                end
            end
        end
    otherwise
        error('dico_update:unknownType', ...
            'Unknown dictionary update type ''%s''.', type);
end
end
ivan@152 107