comparison DL/two-step DL/dico_update.m @ 224:fd0b5d36f6ad danieleb

Updated the contents of this branch with the contents of the default branch.
author luisf <luis.figueira@eecs.qmul.ac.uk>
date Thu, 12 Apr 2012 13:52:28 +0100
parents 9eb5f0d4c1a4 233e75809e4a
children
comparison
equal deleted inserted replaced
196:82b0d3f982cb 224:fd0b5d36f6ad
1 function [dico, amp] = dico_update(dico, sig, amp, type, flow, rho,mocodParams) 1 function [dico, amp] = dico_update(dico, sig, amp, type, flow, rho)
2 2
3 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 3 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
4 % [dico, amp] = dico_update(dico, sig, amp, type, flow, rho) 4 % [dico, amp] = dico_update(dico, sig, amp, type, flow, rho)
5 % 5 %
6 % perform one iteration of dictionary update for dictionary learning 6 % perform one iteration of dictionary update for dictionary learning
7 % 7 %
8 % parameters: 8 % parameters:
9 % - dico: the initial dictionary with atoms as columns 9 % - dico: the initial dictionary with atoms as columns
10 % - sig: the training data 10 % - sig: the training data
11 % - amp: the amplitude coefficients as a sparse matrix 11 % - amp: the amplitude coefficients as a sparse matrix
12 % - type: the algorithm can be one of the following 12 % - type: the algorithm can be one of the following
13 % - ols: fixed step gradient descent 13 % - ols: fixed step gradient descent, as described in Olshausen &
14 % - mailhe: optimal step gradient descent (can be implemented as a 14 % Field95
15 % default for ols?) 15 % - opt: optimal step gradient descent, as described in Mailhe et
16 % - MOD: pseudo-inverse of the coefficients 16 % al.08
17 % - KSVD: already implemented by Elad 17 % - MOD: pseudo-inverse of the coefficients, as described in Engan99
18 % - flow: 'sequential' or 'parallel'. If sequential, the residual is 18 % - KSVD: PCA update as described in Aharon06. For fast applications,
19 % updated after each atom update. If parallel, the residual is only 19 % use KSVDbox rather than this code.
20 % updated once the whole dictionary has been computed. Sequential works 20 % - LGD: large step gradient descent. Equivalent to 'opt' with
21 % better, there may be no need to implement parallel. Not used with 21 % rho=2.
22 % MOD. 22 % - flow: 'sequential' or 'parallel'. If sequential, the residual is
23 % - rho: learning rate. If the type is 'ols', it is the descent step of 23 % updated after each atom update. If parallel, the residual is only
24 % the gradient (typical choice: 0.1). If the type is 'mailhe', the 24 % updated once the whole dictionary has been computed.
25 % descent step is the optimal step*rho (typical choice: 1, although 2 25 % Default: Sequential (sequential usually works better). Not used with
26 % or 3 seems to work better). Not used for MOD and KSVD. 26 % MOD.
27 % - mocodParams: struct containing the parameters for the MOCOD dictionary 27 % - rho: learning rate. If the type is 'ols', it is the descent step of
28 % update (see Ramirez et Al., Sparse modeling with universal priors and 28 % the gradient (default: 0.1). If the type is 'opt', the
29 % learned incoherent dictionaries). The required fields are 29 % descent step is the optimal step*rho (default: 1, although 2 works
30 % .Dprev: dictionary at previous optimisation step 30 % better. See LGD for more details). Not used for MOD and KSVD.
31 % .zeta: coherence regularization factor 31 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
32 % .eta: atoms norm regularisation factor 32
33 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 33 if ~ exist( 'flow', 'var' ) || isempty(flow)
34 if ~ exist( 'rho', 'var' ) || isempty(rho) 34 flow = 'sequential';
35 rho = 0.1; 35 end
36
37 res = sig - dico*amp;
38 nb_pattern = size(dico, 2);
39
40 % if the type is random, then randomly pick another type
41 switch type
42 case 'rand'
43 x = rand();
44 if x < 1/3
45 type = 'MOD';
46 elseif type < 2/3
47 type = 'opt';
48 else
49 type = 'KSVD';
50 end
51 end
52
53 % set the learning rate to default if not provided
54 if ~ exist( 'rho', 'var' ) || isempty(rho)
55 switch type
56 case 'ols'
57 rho = 0.1;
58 case 'opt'
59 rho = 1;
60 end
61 end
62
63 switch type
64 case 'MOD'
65 G = amp*amp';
66 dico2 = sig*amp'*G^-1;
67 for p = 1:nb_pattern
68 n = norm(dico2(:,p));
69 % renormalize
70 if n > 0
71 dico(:,p) = dico2(:,p)/n;
72 amp(p,:) = amp(p,:)*n;
73 end
74 end
75 case 'ols'
76 for p = 1:nb_pattern
77 grad = res*amp(p,:)';
78 if norm(grad) > 0
79 pat = dico(:,p) + rho*grad;
80 pat = pat/norm(pat);
81 if nargin >5 && strcmp(flow, 'sequential')
82 res = res + (dico(:,p)-pat)*amp(p,:); %#ok<*NASGU>
83 end
84 dico(:,p) = pat;
85 end
86 end
87 case 'opt'
88 for p = 1:nb_pattern
89 index = find(amp(p,:)~=0);
90 vec = amp(p,index);
91 grad = res(:,index)*vec';
92 if norm(grad) > 0
93 pat = (vec*vec')*dico(:,p) + rho*grad;
94 pat = pat/norm(pat);
95 if nargin >5 && strcmp(flow, 'sequential')
96 res(:,index) = res(:,index) + (dico(:,p)-pat)*vec;
97 end
98 dico(:,p) = pat;
99 end
100 end
101 case 'LGD'
102 for p = 1:nb_pattern
103 index = find(amp(p,:)~=0);
104 vec = amp(p,index);
105 grad = res(:,index)*vec';
106 if norm(grad) > 0
107 pat = (vec*vec')*dico(:,p) + 2*grad;
108 pat = pat/norm(pat);
109 if nargin >5 && strcmp(flow, 'sequential')
110 res(:,index) = res(:,index) + (dico(:,p)-pat)*vec;
111 end
112 dico(:,p) = pat;
113 end
114 end
115 case 'KSVD'
116 for p = 1:nb_pattern
117 index = find(amp(p,:)~=0);
118 if ~isempty(index)
119 patch = res(:,index)+dico(:,p)*amp(p,index);
120 [U,~,V] = svd(patch);
121 if U(:,1)'*dico(:,p) > 0
122 dico(:,p) = U(:,1);
123 else
124 dico(:,p) = -U(:,1);
125 end
126 dico(:,p) = dico(:,p)/norm(dico(:,p));
127 amp(p,index) = dico(:,p)'*patch;
128 if nargin >5 && strcmp(flow, 'sequential')
129 res(:,index) = patch-dico(:,p)*amp(p,index);
130 end
131 end
132 end
133 end
36 end 134 end
37 135
38 if ~ exist( 'flow', 'var' ) || isempty(flow)
39 flow = sequential;
40 end
41
42 res = sig - dico*amp;
43 nb_pattern = size(dico, 2);
44
45 switch type
46 case 'rand'
47 x = rand();
48 if x < 1/3
49 type = 'MOD';
50 elseif type < 2/3
51 type = 'mailhe';
52 else
53 type = 'KSVD';
54 end
55 end
56
57 switch upper(type)
58 case 'MOD'
59 G = amp*amp';
60 dico2 = sig*amp'*G^-1;
61 for p = 1:nb_pattern
62 n = norm(dico2(:,p));
63 % renormalize
64 if n > 0
65 dico(:,p) = dico2(:,p)/n;
66 amp(p,:) = amp(p,:)*n;
67 end
68 end
69 case 'OLS'
70 for p = 1:nb_pattern
71 grad = res*amp(p,:)';
72 if norm(grad) > 0
73 pat = dico(:,p) + rho*grad;
74 pat = pat/norm(pat);
75 if nargin >5 && strcmp(flow, 'sequential')
76 res = res + (dico(:,p)-pat)*amp(p,:); %#ok<*NASGU>
77 end
78 dico(:,p) = pat;
79 end
80 end
81 case 'MAILHE'
82 for p = 1:nb_pattern
83 grad = res*amp(p,:)';
84 if norm(grad) > 0
85 pat = (amp(p,:)*amp(p,:)')*dico(:,p) + rho*grad;
86 pat = pat/norm(pat);
87 if nargin >5 && strcmp(flow, 'sequential')
88 res = res + (dico(:,p)-pat)*amp(p,:);
89 end
90 dico(:,p) = pat;
91 end
92 end
93 case 'KSVD'
94 for p = 1:nb_pattern
95 index = find(amp(p,:)~=0);
96 if ~isempty(index)
97 patch = res(:,index)+dico(:,p)*amp(p,index);
98 [U,~,V] = svd(patch);
99 if U(:,1)'*dico(:,p) > 0
100 dico(:,p) = U(:,1);
101 else
102 dico(:,p) = -U(:,1);
103 end
104 dico(:,p) = dico(:,p)/norm(dico(:,p));
105 amp(p,index) = dico(:,p)'*patch;
106 if nargin >5 && strcmp(flow, 'sequential')
107 res(:,index) = patch-dico(:,p)*amp(p,index);
108 end
109 end
110 end
111 case 'MOCOD'
112 zeta = mocodParams.zeta;
113 eta = mocodParams.eta;
114 Dprev = mocodParams.Dprev;
115
116 dico = (sig*amp' + 2*(zeta+eta)*Dprev)/...
117 (amp*amp' + 2*zeta*(Dprev'*Dprev) + 2*eta*diag(diag(Dprev'*Dprev)));
118 end
119 end
120