Mercurial > hg > smallbox
comparison DL/two-step DL/dico_update.m @ 224:fd0b5d36f6ad danieleb
Updated the contents of this branch with the contents of the default branch.
author | luisf <luis.figueira@eecs.qmul.ac.uk> |
---|---|
date | Thu, 12 Apr 2012 13:52:28 +0100 |
parents | 9eb5f0d4c1a4 233e75809e4a |
children |
comparison
equal
deleted
inserted
replaced
196:82b0d3f982cb | 224:fd0b5d36f6ad |
---|---|
1 function [dico, amp] = dico_update(dico, sig, amp, type, flow, rho,mocodParams) | 1 function [dico, amp] = dico_update(dico, sig, amp, type, flow, rho) |
2 | 2 |
3 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% | 3 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% |
4 % [dico, amp] = dico_update(dico, sig, amp, type, flow, rho) | 4 % [dico, amp] = dico_update(dico, sig, amp, type, flow, rho) |
5 % | 5 % |
6 % perform one iteration of dictionary update for dictionary learning | 6 % perform one iteration of dictionary update for dictionary learning |
7 % | 7 % |
8 % parameters: | 8 % parameters: |
9 % - dico: the initial dictionary with atoms as columns | 9 % - dico: the initial dictionary with atoms as columns |
10 % - sig: the training data | 10 % - sig: the training data |
11 % - amp: the amplitude coefficients as a sparse matrix | 11 % - amp: the amplitude coefficients as a sparse matrix |
12 % - type: the algorithm can be one of the following | 12 % - type: the algorithm can be one of the following |
13 % - ols: fixed step gradient descent | 13 % - ols: fixed step gradient descent, as described in Olshausen & |
14 % - mailhe: optimal step gradient descent (can be implemented as a | 14 % Field95 |
15 % default for ols?) | 15 % - opt: optimal step gradient descent, as described in Mailhe et |
16 % - MOD: pseudo-inverse of the coefficients | 16 % al.08 |
17 % - KSVD: already implemented by Elad | 17 % - MOD: pseudo-inverse of the coefficients, as described in Engan99 |
18 % - flow: 'sequential' or 'parallel'. If sequential, the residual is | 18 % - KSVD: PCA update as described in Aharon06. For fast applications, |
19 % updated after each atom update. If parallel, the residual is only | 19 % use KSVDbox rather than this code. |
20 % updated once the whole dictionary has been computed. Sequential works | 20 % - LGD: large step gradient descent. Equivalent to 'opt' with |
21 % better, there may be no need to implement parallel. Not used with | 21 % rho=2. |
22 % MOD. | 22 % - flow: 'sequential' or 'parallel'. If sequential, the residual is |
23 % - rho: learning rate. If the type is 'ols', it is the descent step of | 23 % updated after each atom update. If parallel, the residual is only |
24 % the gradient (typical choice: 0.1). If the type is 'mailhe', the | 24 % updated once the whole dictionary has been computed. |
25 % descent step is the optimal step*rho (typical choice: 1, although 2 | 25 % Default: Sequential (sequential usually works better). Not used with |
26 % or 3 seems to work better). Not used for MOD and KSVD. | 26 % MOD. |
27 % - mocodParams: struct containing the parameters for the MOCOD dictionary | 27 % - rho: learning rate. If the type is 'ols', it is the descent step of |
28 % update (see Ramirez et Al., Sparse modeling with universal priors and | 28 % the gradient (default: 0.1). If the type is 'opt', the |
29 % learned incoherent dictionaries). The required fields are | 29 % descent step is the optimal step*rho (default: 1, although 2 works |
30 % .Dprev: dictionary at previous optimisation step | 30 % better. See LGD for more details). Not used for MOD and KSVD. |
31 % .zeta: coherence regularization factor | 31 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% |
32 % .eta: atoms norm regularisation factor | 32 |
33 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% | 33 if ~ exist( 'flow', 'var' ) || isempty(flow) |
34 if ~ exist( 'rho', 'var' ) || isempty(rho) | 34 flow = 'sequential'; |
35 rho = 0.1; | 35 end |
36 | |
37 res = sig - dico*amp; | |
38 nb_pattern = size(dico, 2); | |
39 | |
40 % if the type is random, then randomly pick another type | |
41 switch type | |
42 case 'rand' | |
43 x = rand(); | |
44 if x < 1/3 | |
45 type = 'MOD'; | |
46 elseif type < 2/3 | |
47 type = 'opt'; | |
48 else | |
49 type = 'KSVD'; | |
50 end | |
51 end | |
52 | |
53 % set the learning rate to default if not provided | |
54 if ~ exist( 'rho', 'var' ) || isempty(rho) | |
55 switch type | |
56 case 'ols' | |
57 rho = 0.1; | |
58 case 'opt' | |
59 rho = 1; | |
60 end | |
61 end | |
62 | |
63 switch type | |
64 case 'MOD' | |
65 G = amp*amp'; | |
66 dico2 = sig*amp'*G^-1; | |
67 for p = 1:nb_pattern | |
68 n = norm(dico2(:,p)); | |
69 % renormalize | |
70 if n > 0 | |
71 dico(:,p) = dico2(:,p)/n; | |
72 amp(p,:) = amp(p,:)*n; | |
73 end | |
74 end | |
75 case 'ols' | |
76 for p = 1:nb_pattern | |
77 grad = res*amp(p,:)'; | |
78 if norm(grad) > 0 | |
79 pat = dico(:,p) + rho*grad; | |
80 pat = pat/norm(pat); | |
81 if nargin >5 && strcmp(flow, 'sequential') | |
82 res = res + (dico(:,p)-pat)*amp(p,:); %#ok<*NASGU> | |
83 end | |
84 dico(:,p) = pat; | |
85 end | |
86 end | |
87 case 'opt' | |
88 for p = 1:nb_pattern | |
89 index = find(amp(p,:)~=0); | |
90 vec = amp(p,index); | |
91 grad = res(:,index)*vec'; | |
92 if norm(grad) > 0 | |
93 pat = (vec*vec')*dico(:,p) + rho*grad; | |
94 pat = pat/norm(pat); | |
95 if nargin >5 && strcmp(flow, 'sequential') | |
96 res(:,index) = res(:,index) + (dico(:,p)-pat)*vec; | |
97 end | |
98 dico(:,p) = pat; | |
99 end | |
100 end | |
101 case 'LGD' | |
102 for p = 1:nb_pattern | |
103 index = find(amp(p,:)~=0); | |
104 vec = amp(p,index); | |
105 grad = res(:,index)*vec'; | |
106 if norm(grad) > 0 | |
107 pat = (vec*vec')*dico(:,p) + 2*grad; | |
108 pat = pat/norm(pat); | |
109 if nargin >5 && strcmp(flow, 'sequential') | |
110 res(:,index) = res(:,index) + (dico(:,p)-pat)*vec; | |
111 end | |
112 dico(:,p) = pat; | |
113 end | |
114 end | |
115 case 'KSVD' | |
116 for p = 1:nb_pattern | |
117 index = find(amp(p,:)~=0); | |
118 if ~isempty(index) | |
119 patch = res(:,index)+dico(:,p)*amp(p,index); | |
120 [U,~,V] = svd(patch); | |
121 if U(:,1)'*dico(:,p) > 0 | |
122 dico(:,p) = U(:,1); | |
123 else | |
124 dico(:,p) = -U(:,1); | |
125 end | |
126 dico(:,p) = dico(:,p)/norm(dico(:,p)); | |
127 amp(p,index) = dico(:,p)'*patch; | |
128 if nargin >5 && strcmp(flow, 'sequential') | |
129 res(:,index) = patch-dico(:,p)*amp(p,index); | |
130 end | |
131 end | |
132 end | |
133 end | |
36 end | 134 end |
37 | 135 |
38 if ~ exist( 'flow', 'var' ) || isempty(flow) | |
39 flow = sequential; | |
40 end | |
41 | |
42 res = sig - dico*amp; | |
43 nb_pattern = size(dico, 2); | |
44 | |
45 switch type | |
46 case 'rand' | |
47 x = rand(); | |
48 if x < 1/3 | |
49 type = 'MOD'; | |
50 elseif type < 2/3 | |
51 type = 'mailhe'; | |
52 else | |
53 type = 'KSVD'; | |
54 end | |
55 end | |
56 | |
57 switch upper(type) | |
58 case 'MOD' | |
59 G = amp*amp'; | |
60 dico2 = sig*amp'*G^-1; | |
61 for p = 1:nb_pattern | |
62 n = norm(dico2(:,p)); | |
63 % renormalize | |
64 if n > 0 | |
65 dico(:,p) = dico2(:,p)/n; | |
66 amp(p,:) = amp(p,:)*n; | |
67 end | |
68 end | |
69 case 'OLS' | |
70 for p = 1:nb_pattern | |
71 grad = res*amp(p,:)'; | |
72 if norm(grad) > 0 | |
73 pat = dico(:,p) + rho*grad; | |
74 pat = pat/norm(pat); | |
75 if nargin >5 && strcmp(flow, 'sequential') | |
76 res = res + (dico(:,p)-pat)*amp(p,:); %#ok<*NASGU> | |
77 end | |
78 dico(:,p) = pat; | |
79 end | |
80 end | |
81 case 'MAILHE' | |
82 for p = 1:nb_pattern | |
83 grad = res*amp(p,:)'; | |
84 if norm(grad) > 0 | |
85 pat = (amp(p,:)*amp(p,:)')*dico(:,p) + rho*grad; | |
86 pat = pat/norm(pat); | |
87 if nargin >5 && strcmp(flow, 'sequential') | |
88 res = res + (dico(:,p)-pat)*amp(p,:); | |
89 end | |
90 dico(:,p) = pat; | |
91 end | |
92 end | |
93 case 'KSVD' | |
94 for p = 1:nb_pattern | |
95 index = find(amp(p,:)~=0); | |
96 if ~isempty(index) | |
97 patch = res(:,index)+dico(:,p)*amp(p,index); | |
98 [U,~,V] = svd(patch); | |
99 if U(:,1)'*dico(:,p) > 0 | |
100 dico(:,p) = U(:,1); | |
101 else | |
102 dico(:,p) = -U(:,1); | |
103 end | |
104 dico(:,p) = dico(:,p)/norm(dico(:,p)); | |
105 amp(p,index) = dico(:,p)'*patch; | |
106 if nargin >5 && strcmp(flow, 'sequential') | |
107 res(:,index) = patch-dico(:,p)*amp(p,index); | |
108 end | |
109 end | |
110 end | |
111 case 'MOCOD' | |
112 zeta = mocodParams.zeta; | |
113 eta = mocodParams.eta; | |
114 Dprev = mocodParams.Dprev; | |
115 | |
116 dico = (sig*amp' + 2*(zeta+eta)*Dprev)/... | |
117 (amp*amp' + 2*zeta*(Dprev'*Dprev) + 2*eta*diag(diag(Dprev'*Dprev))); | |
118 end | |
119 end | |
120 |