function [dico, amp] = dico_update(dico, sig, amp, type, flow, rho, mocodParams)
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% [dico, amp] = dico_update(dico, sig, amp, type, flow, rho, mocodParams)
%
% Perform one iteration of dictionary update for dictionary learning.
%
% Parameters:
% - dico: the initial dictionary with atoms as columns
% - sig: the training data, one signal per column (modelled as dico*amp)
% - amp: the amplitude coefficients (one row per atom), typically stored
%   as a sparse matrix
% - type: the update algorithm (case-insensitive), one of
%   - 'ols': fixed step gradient descent
%   - 'mailhe': optimal step gradient descent
%   - 'MOD': least-squares fit of the dictionary to the coefficients
%     (pseudo-inverse of the coefficients). Atoms whose coefficients are
%     all zero cannot be estimated and are left unchanged.
%   - 'KSVD': K-SVD, rank-1 SVD update of each atom together with its
%     nonzero coefficients
%   - 'MOCOD': MOCOD update (see mocodParams)
%   - 'rand': one of 'MOD', 'mailhe' or 'KSVD', picked uniformly at random
% - flow: 'sequential' (default) or 'parallel'. If sequential, the
%   residual is updated after each atom update. If parallel, the residual
%   is only updated once the whole dictionary has been computed.
%   Sequential works better. Not used with MOD and MOCOD.
% - rho: learning rate (default 0.1). If the type is 'ols', it is the
%   descent step of the gradient (typical choice: 0.1). If the type is
%   'mailhe', the descent step is the optimal step*rho (typical choice: 1,
%   although 2 or 3 seems to work better). Not used for MOD, KSVD and
%   MOCOD.
% - mocodParams: struct containing the parameters for the MOCOD dictionary
%   update (see Ramirez et al., Sparse modeling with universal priors and
%   learned incoherent dictionaries). Only required when type is 'MOCOD'.
%   The required fields are
%     .Dprev: dictionary at previous optimisation step
%     .zeta: coherence regularisation factor
%     .eta: atoms norm regularisation factor
%
% Outputs:
% - dico: the updated dictionary. 'ols', 'mailhe', 'MOD' and 'KSVD'
%   return unit-norm atoms.
% - amp: the coefficients. 'MOD' rescales them to compensate the atom
%   renormalisation and 'KSVD' updates them on their support; the other
%   updates return them unchanged.
%
% Errors:
% - dico_update:unknownType if type is not one of the values above
% - dico_update:mocodParams if type is 'MOCOD' and mocodParams is missing
%   or lacks a required field
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
if ~exist('rho', 'var') || isempty(rho)
    rho = 0.1;
end

if ~exist('flow', 'var') || isempty(flow)
    flow = 'sequential';
end
% Decided once here, so that the default above applies regardless of how
% many arguments were passed.
isSequential = strcmpi(flow, 'sequential');

nb_pattern = size(dico, 2);

% 'rand' picks one of three update rules with equal probability.
if strcmpi(type, 'rand')
    x = rand();
    if x < 1/3
        type = 'MOD';
    elseif x < 2/3
        type = 'mailhe';
    else
        type = 'KSVD';
    end
end

switch upper(type)
    case 'MOD'
        % An atom with no nonzero coefficient makes amp*amp' singular, so
        % only the atoms actually in use are estimated.
        used = find(any(amp, 2));
        if ~isempty(used)
            ampUsed = amp(used, :);
            % Least-squares solution via mrdivide rather than an explicit
            % inverse of the Gram matrix.
            dico2 = (sig*ampUsed') / (ampUsed*ampUsed');
            for k = 1:numel(used)
                p = used(k);
                n = norm(dico2(:,k));
                % Renormalise the atom and move its norm into the
                % coefficients so that dico*amp is preserved.
                if n > 0
                    dico(:,p) = dico2(:,k)/n;
                    amp(p,:) = amp(p,:)*n;
                end
            end
        end

    case {'OLS', 'MAILHE'}
        useOptimalStep = strcmpi(type, 'mailhe');
        res = sig - dico*amp;
        for p = 1:nb_pattern
            % Gradient of the residual energy with respect to atom p.
            grad = res*amp(p,:)';
            if norm(grad) > 0
                if useOptimalStep
                    % After renormalisation this equals a step of
                    % rho/||amp(p,:)||^2 from the current atom.
                    pat = (amp(p,:)*amp(p,:)')*dico(:,p) + rho*grad;
                else
                    pat = dico(:,p) + rho*grad;
                end
                pat = pat/norm(pat);
                if isSequential
                    % Account for the change of atom p in the residual.
                    res = res + (dico(:,p)-pat)*amp(p,:); %#ok<*NASGU>
                end
                dico(:,p) = pat;
            end
        end

    case 'KSVD'
        res = sig - dico*amp;
        for p = 1:nb_pattern
            index = find(amp(p,:) ~= 0);
            if ~isempty(index)
                % Residual of the signals using atom p, with the
                % contribution of atom p added back.
                patch = res(:,index) + dico(:,p)*amp(p,index);
                % Only the leading left singular vector is needed, so the
                % economy-size SVD suffices.
                [U, ~, ~] = svd(patch, 'econ');
                % Keep the orientation of the previous atom.
                if U(:,1)'*dico(:,p) > 0
                    dico(:,p) = U(:,1);
                else
                    dico(:,p) = -U(:,1);
                end
                dico(:,p) = dico(:,p)/norm(dico(:,p));
                amp(p,index) = dico(:,p)'*patch;
                if isSequential
                    res(:,index) = patch - dico(:,p)*amp(p,index);
                end
            end
        end

    case 'MOCOD'
        if ~exist('mocodParams', 'var') || ~isstruct(mocodParams) || ...
                ~all(isfield(mocodParams, {'Dprev', 'zeta', 'eta'}))
            error('dico_update:mocodParams', ...
                'MOCOD update requires mocodParams with fields Dprev, zeta and eta.');
        end
        zeta = mocodParams.zeta;
        eta = mocodParams.eta;
        Dprev = mocodParams.Dprev;

        gramPrev = Dprev'*Dprev;
        dico = (sig*amp' + 2*(zeta+eta)*Dprev) / ...
            (amp*amp' + 2*zeta*gramPrev + 2*eta*diag(diag(gramPrev)));

    otherwise
        error('dico_update:unknownType', ...
            'Unknown dictionary update type ''%s''.', type);
end
end
