annotate DL/two-step DL/dico_update.m @ 192:f1e601cc916d danieleb

removed error check for wLength
author Daniele Barchiesi <daniele.barchiesi@eecs.qmul.ac.uk>
date Thu, 01 Mar 2012 16:57:51 +0000
parents 9eb5f0d4c1a4
children fd0b5d36f6ad
rev   line source
function [dico, amp] = dico_update(dico, sig, amp, type, flow, rho, mocodParams)

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% [dico, amp] = dico_update(dico, sig, amp, type, flow, rho, mocodParams)
%
% perform one iteration of dictionary update for dictionary learning
%
% parameters:
% - dico: the initial dictionary with atoms as columns
% - sig: the training data (the model fitted is sig ~ dico*amp)
% - amp: the amplitude coefficients as a sparse matrix
% - type: the update rule; matched case-insensitively, except 'rand'
%   which must be lower case. One of:
%   - 'ols': fixed step gradient descent
%   - 'mailhe': optimal step gradient descent (can be implemented as a
%     default for ols?)
%   - 'MOD': pseudo-inverse of the coefficients
%   - 'KSVD': rank-one SVD update of each atom and of its non-zero
%     coefficients
%   - 'MOCOD': regularised update with coherence and atom-norm penalties
%     (requires mocodParams)
%   - 'rand': draw one of 'MOD', 'mailhe' or 'KSVD' with equal
%     probability
%   Any other value raises an error.
% - flow: 'sequential' (default) or 'parallel' (case-insensitive). If
%   sequential, the residual is updated after each atom update. If
%   parallel, the residual is only updated once the whole dictionary has
%   been computed. Sequential works better, there may be no need to
%   implement parallel. Only used by 'ols', 'mailhe' and 'KSVD'.
% - rho: learning rate (default 0.1). If the type is 'ols', it is the
%   descent step of the gradient (typical choice: 0.1). If the type is
%   'mailhe', the descent step is the optimal step*rho (typical choice:
%   1, although 2 or 3 seems to work better). Not used for MOD, KSVD and
%   MOCOD.
% - mocodParams: struct containing the parameters for the MOCOD dictionary
%   update (see Ramirez et Al., Sparse modeling with universal priors and
%   learned incoherent dictionaries). The required fields are
%   .Dprev: dictionary at previous optimisation step
%   .zeta: coherence regularization factor
%   .eta: atoms norm regularisation factor
%
% outputs:
% - dico: the updated dictionary. 'MOD', 'ols', 'mailhe' and 'KSVD'
%   rescale every atom they update to unit norm; 'MOCOD' does not
%   normalise.
% - amp: the coefficients. 'MOD' rescales them to compensate for the
%   atom normalisation, 'KSVD' recomputes them on each atom's support,
%   and the other rules return them unchanged.
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
if ~exist('rho', 'var') || isempty(rho)
    rho = 0.1;
end

if ~exist('flow', 'var') || isempty(flow)
    % must be the string 'sequential' (a bare identifier is undefined)
    flow = 'sequential';
end
% flow always has a value here, so decide directly from it; checking
% nargin would wrongly disable sequential updates when rho is omitted
isSequential = strcmpi(flow, 'sequential');

res = sig - dico*amp;
nb_pattern = size(dico, 2);

if strcmp(type, 'rand')
    % draw the update rule uniformly among three of the available ones
    x = rand();
    if x < 1/3
        type = 'MOD';
    elseif x < 2/3
        type = 'mailhe';
    else
        type = 'KSVD';
    end
end

switch upper(type)
    case 'MOD'
        % least-squares solution sig*amp'*(amp*amp')^-1; mrdivide solves
        % the system instead of forming the explicit inverse
        dico2 = (sig*amp')/(amp*amp');
        for p = 1:nb_pattern
            n = norm(dico2(:,p));
            % renormalize the atom and move its scale into the coefficients
            if n > 0
                dico(:,p) = dico2(:,p)/n;
                amp(p,:) = amp(p,:)*n;
            end
        end
    case 'OLS'
        for p = 1:nb_pattern
            grad = res*amp(p,:)';
            if norm(grad) > 0
                pat = dico(:,p) + rho*grad;
                pat = pat/norm(pat);
                if isSequential
                    res = res + (dico(:,p)-pat)*amp(p,:); %#ok<*NASGU>
                end
                dico(:,p) = pat;
            end
        end
    case 'MAILHE'
        for p = 1:nb_pattern
            grad = res*amp(p,:)';
            if norm(grad) > 0
                pat = (amp(p,:)*amp(p,:)')*dico(:,p) + rho*grad;
                pat = pat/norm(pat);
                if isSequential
                    res = res + (dico(:,p)-pat)*amp(p,:);
                end
                dico(:,p) = pat;
            end
        end
    case 'KSVD'
        for p = 1:nb_pattern
            index = find(amp(p,:)~=0);
            if ~isempty(index)
                % restricted residual with atom p's contribution added back
                patch = res(:,index)+dico(:,p)*amp(p,index);
                % only the leading left singular vector is needed
                [U, ~, ~] = svd(full(patch), 'econ');
                % keep the sign closest to the previous atom
                if U(:,1)'*dico(:,p) > 0
                    dico(:,p) = U(:,1);
                else
                    dico(:,p) = -U(:,1);
                end
                dico(:,p) = dico(:,p)/norm(dico(:,p));
                amp(p,index) = dico(:,p)'*patch;
                if isSequential
                    res(:,index) = patch-dico(:,p)*amp(p,index);
                end
            end
        end
    case 'MOCOD'
        if ~exist('mocodParams', 'var') || isempty(mocodParams)
            error('dico_update:missingMocodParams', ...
                'MOCOD update requires the mocodParams argument');
        end
        zeta = mocodParams.zeta;
        eta = mocodParams.eta;
        Dprev = mocodParams.Dprev;

        dico = (sig*amp' + 2*(zeta+eta)*Dprev)/...
            (amp*amp' + 2*zeta*(Dprev'*Dprev) + 2*eta*diag(diag(Dprev'*Dprev)));
    otherwise
        error('dico_update:unknownType', ...
            'unknown dictionary update type ''%s''', type);
end
end