function [dico, amp] = dico_update(dico, sig, amp, type, flow, rho)
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% [dico, amp] = dico_update(dico, sig, amp, type, flow, rho)
%
% Perform one iteration of dictionary update for dictionary learning.
%
% parameters:
% - dico: the initial dictionary with atoms as columns
% - sig: the training data, one signal per column, approximated by
%   dico*amp
% - amp: the amplitude coefficients as a sparse matrix, one row per atom
% - type: the algorithm can be one of the following
%   - ols: fixed step gradient descent, as described in Olshausen &
%     Field95
%   - opt: optimal step gradient descent, as described in Mailhe et
%     al.08
%   - MOD: pseudo-inverse of the coefficients, as described in Engan99
%   - KSVD: PCA update as described in Aharon06. For fast applications,
%     use KSVDbox rather than this code.
%   - LGD: large step gradient descent. Equivalent to 'opt' with
%     rho=2 (the rho argument is ignored).
%   - rand: pick one of 'MOD', 'opt' or 'KSVD' uniformly at random.
% - flow: 'sequential' or 'parallel'. If sequential, the residual is
%   updated after each atom update. If parallel, every atom is updated
%   from the residual computed at the start of the call.
%   Default: 'sequential' (sequential usually works better). Not used
%   with MOD.
% - rho: learning rate. If the type is 'ols', it is the descent step of
%   the gradient (default: 0.1). If the type is 'opt', the
%   descent step is the optimal step*rho (default: 1, although 2 works
%   better. See LGD for more details). Not used for MOD, KSVD and LGD.
%
% outputs:
% - dico: the updated dictionary; every updated atom has unit norm
% - amp: the coefficients. MOD moves the norm of each updated atom into
%   its row of amp; KSVD recomputes the nonzero coefficients of each
%   updated atom; ols, opt and LGD return amp unchanged.
%
% An unknown type raises the error 'dico_update:unknownType'.
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

if ~exist('flow', 'var') || isempty(flow)
    flow = 'sequential';
end
% Decided from flow alone, so that the documented default also applies
% when rho is omitted.
sequential = strcmp(flow, 'sequential');

res = sig - dico*amp;
nb_pattern = size(dico, 2);

% if the type is random, then randomly pick another type
switch type
    case 'rand'
        x = rand();
        if x < 1/3
            type = 'MOD';
        elseif x < 2/3
            type = 'opt';
        else
            type = 'KSVD';
        end
end

% set the learning rate to default if not provided
if ~exist('rho', 'var') || isempty(rho)
    switch type
        case 'ols'
            rho = 0.1;
        case 'opt'
            rho = 1;
    end
end

switch type
    case 'MOD'
        % Least-squares dictionary for the current coefficients. Only the
        % atoms actually used enter the Gram matrix: an all-zero row of amp
        % would make amp*amp' singular and turn every atom into Inf/NaN.
        used = find(any(amp, 2));
        if ~isempty(used)
            A = amp(used, :);
            dico2 = (sig*A') / (A*A');
            for k = 1:numel(used)
                p = used(k);
                n = norm(dico2(:,k));
                % renormalize the atom and move its norm into the coefficients
                if n > 0
                    dico(:,p) = dico2(:,k)/n;
                    amp(p,:) = amp(p,:)*n;
                end
            end
        end
    case 'ols'
        for p = 1:nb_pattern
            grad = res*amp(p,:)';
            if norm(grad) > 0
                pat = dico(:,p) + rho*grad;
                pat = pat/norm(pat);
                if sequential
                    % residual of dico*amp with column p replaced by pat
                    res = res + (dico(:,p)-pat)*amp(p,:); %#ok<*NASGU>
                end
                dico(:,p) = pat;
            end
        end
    case {'opt', 'LGD'}
        % LGD is 'opt' with a fixed multiplier of 2 (rho is ignored)
        if strcmp(type, 'LGD')
            step = 2;
        else
            step = rho;
        end
        for p = 1:nb_pattern
            % only the signals that use atom p contribute to its gradient
            index = find(amp(p,:)~=0);
            vec = amp(p,index);
            grad = res(:,index)*vec';
            if norm(grad) > 0
                % vec*vec' = ||vec||^2; after normalization this is a
                % step of step/||vec||^2 along grad from the current atom
                pat = (vec*vec')*dico(:,p) + step*grad;
                pat = pat/norm(pat);
                if sequential
                    res(:,index) = res(:,index) + (dico(:,p)-pat)*vec;
                end
                dico(:,p) = pat;
            end
        end
    case 'KSVD'
        for p = 1:nb_pattern
            index = find(amp(p,:)~=0);
            if ~isempty(index)
                % signals using atom p, with atom p's contribution added back
                patch = res(:,index) + dico(:,p)*amp(p,index);
                % only the leading left singular vector is needed
                [U, ~, ~] = svd(patch, 'econ');
                u = U(:,1);
                % resolve the SVD sign ambiguity towards the previous atom
                if ~(u'*dico(:,p) > 0)
                    u = -u;
                end
                dico(:,p) = u/norm(u);
                amp(p,index) = dico(:,p)'*patch;
                if sequential
                    res(:,index) = patch - dico(:,p)*amp(p,index);
                end
            end
        end
    otherwise
        error('dico_update:unknownType', ...
            'Unknown update type ''%s''.', type);
end
end