Mercurial > hg > smallbox
comparison DL/two-step DL/dico_update.m @ 175:9eb5f0d4c1a4 danieleb
added MOCOD dictionary update
author | Daniele Barchiesi <daniele.barchiesi@eecs.qmul.ac.uk> |
---|---|
date | Thu, 17 Nov 2011 11:17:00 +0000 |
parents | 485747bf39e0 |
children | fd0b5d36f6ad |
comparison
equal
deleted
inserted
replaced
174:dc2f0fa21310 | 175:9eb5f0d4c1a4 |
---|---|
function [dico, amp] = dico_update(dico, sig, amp, type, flow, rho, mocodParams)

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% [dico, amp] = dico_update(dico, sig, amp, type, flow, rho, mocodParams)
%
% perform one iteration of dictionary update for dictionary learning
%
% parameters:
% - dico: the initial dictionary with atoms as columns
% - sig: the training data, one signal per column
% - amp: the amplitude coefficients as a sparse matrix, one row per atom
%   and one column per signal (sig is approximated by dico*amp)
% - type: the algorithm (case-insensitive), one of the following
%   - ols: fixed step gradient descent
%   - mailhe: optimal step gradient descent (can be implemented as a
%     default for ols?)
%   - MOD: least-squares update through the pseudo-inverse of the
%     coefficients
%   - KSVD: rank-one SVD update of each atom and of its non-zero
%     coefficients
%   - MOCOD: regularised least-squares update (see mocodParams)
%   - rand: pick MOD, mailhe or KSVD with equal probability
%   Any other value raises an error.
% - flow: 'sequential' (default) or 'parallel'. If sequential, the
%   residual is updated after each atom update. If parallel, the residual
%   computed on entry is reused for every atom. Sequential works better,
%   there may be no need to implement parallel. Only used with ols,
%   mailhe and KSVD.
% - rho: learning rate (default 0.1). If the type is 'ols', it is the
%   descent step of the gradient (typical choice: 0.1). If the type is
%   'mailhe', the descent step is the optimal step*rho (typical choice: 1,
%   although 2 or 3 seems to work better). Not used for MOD, KSVD and
%   MOCOD.
% - mocodParams: struct containing the parameters for the MOCOD dictionary
%   update (see Ramirez et al., Sparse modeling with universal priors and
%   learned incoherent dictionaries). Only required when type is MOCOD.
%   The required fields are
%   .Dprev: dictionary at previous optimisation step
%   .zeta: coherence regularisation factor
%   .eta: atoms norm regularisation factor
%
% outputs:
% - dico: the updated dictionary. With ols, mailhe, MOD and KSVD every
%   atom that gets updated is normalised to unit norm; MOCOD does not
%   renormalise.
% - amp: the coefficients. MOD rescales the rows of the atoms it
%   normalises and KSVD recomputes the non-zero coefficients; the other
%   types return amp unchanged.
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
if ~exist('rho', 'var') || isempty(rho)
    rho = 0.1;
end

if ~exist('flow', 'var') || isempty(flow)
    flow = 'sequential';
end
% flow always has a value here, so no nargin test is needed (the former
% "nargin > 5" test ignored a 'sequential' flow passed as 5th argument)
isSequential = strcmpi(flow, 'sequential');

if strcmpi(type, 'rand')
    x = rand();
    if x < 1/3
        type = 'MOD';
    elseif x < 2/3
        type = 'mailhe';
    else
        type = 'KSVD';
    end
end

nb_pattern = size(dico, 2);

switch upper(type)
    case 'MOD'
        % dico2 = sig*pinv(amp), computed as sig*amp'*pinv(amp*amp') so
        % that only a K-by-K pseudo-inverse is needed. Unlike G^-1, pinv
        % stays finite when an atom is unused (zero row of amp): the
        % matching column of dico2 is then zero and that atom is kept.
        G = full(amp*amp');
        dico2 = (sig*amp')*pinv(G);
        for p = 1:nb_pattern
            n = norm(dico2(:,p));
            if n > 0
                % renormalise the atom and rescale its coefficients so
                % that the product dico*amp is not changed by it
                dico(:,p) = dico2(:,p)/n;
                amp(p,:) = amp(p,:)*n;
            end
        end
    case {'OLS', 'MAILHE'}
        res = sig - dico*amp;
        isMailhe = strcmpi(type, 'mailhe');
        for p = 1:nb_pattern
            grad = res*amp(p,:)';
            if norm(grad) > 0
                if isMailhe
                    % optimal step: weight the current atom by the energy
                    % of its coefficients
                    weight = full(amp(p,:)*amp(p,:)');
                else
                    weight = 1;
                end
                pat = weight*dico(:,p) + rho*grad;
                pat = pat/norm(pat);
                if isSequential
                    res = res + (dico(:,p)-pat)*amp(p,:);
                end
                dico(:,p) = pat;
            end
        end
    case 'KSVD'
        res = sig - dico*amp;
        for p = 1:nb_pattern
            index = find(amp(p,:));
            if ~isempty(index)
                % residual restricted to the signals using atom p, with
                % the contribution of atom p added back
                patch = res(:,index) + dico(:,p)*amp(p,index);
                [U, ~] = svd(patch, 'econ');
                atom = U(:,1);
                % the SVD sign is arbitrary: keep the orientation of the
                % previous atom
                if atom'*dico(:,p) <= 0
                    atom = -atom;
                end
                dico(:,p) = atom/norm(atom);
                amp(p,index) = dico(:,p)'*patch;
                if isSequential
                    res(:,index) = patch - dico(:,p)*amp(p,index);
                end
            end
        end
    case 'MOCOD'
        if ~exist('mocodParams', 'var') || ~isstruct(mocodParams) || ...
                ~all(isfield(mocodParams, {'Dprev', 'zeta', 'eta'}))
            error('dico_update:mocodParams', ...
                'MOCOD requires mocodParams with fields Dprev, zeta and eta.');
        end
        zeta = mocodParams.zeta;
        eta = mocodParams.eta;
        Dprev = mocodParams.Dprev;
        gram = Dprev'*Dprev;

        dico = (sig*amp' + 2*(zeta+eta)*Dprev)/...
            (amp*amp' + 2*zeta*gram + 2*eta*diag(diag(gram)));
    otherwise
        error('dico_update:unknownType', ...
            'Unknown dictionary update type ''%s''.', type);
end
end