function [mu, Sigma, B] = clg_Mstep(w, Y, YY, YTY, X, XX, XY, varargin)
% CLG_MSTEP Compute ML/MAP estimates for a conditional linear Gaussian
% [mu, Sigma, B] = clg_Mstep(w, Y, YY, YTY, X, XX, XY, varargin)
%
% We fit P(Y|X,Q=i) = N(Y; B_i X + mu_i, Sigma_i)
% and w(i,t) = p(M(t)=i|y(t)) = posterior responsibility
% See www.ai.mit.edu/~murphyk/Papers/learncg.pdf.
%
% See process_options for how to specify the input arguments.
%
% INPUTS:
% w(i) = sum_t w(i,t) = responsibilities for each mixture component
%  If there is only one mixture component (i.e., Q does not exist),
%  then w(i) = N = nsamples, and
%  all references to i can be replaced by 1.
% Y(:,i) = sum_t w(i,t) y(:,t) = weighted observations
% YY(:,:,i) = sum_t w(i,t) y(:,t) y(:,t)' = weighted outer product
% YTY(i) = sum_t w(i,t) y(:,t)' y(:,t) = weighted inner product
%  You only need to pass in YTY if Sigma is to be estimated as spherical.
%
% In the regression context, we must also pass in the following
% X(:,i) = sum_t w(i,t) x(:,t) = weighted inputs
% XX(:,:,i) = sum_t w(i,t) x(:,t) x(:,t)' = weighted outer product
% XY(:,:,i) = sum_t w(i,t) x(:,t) y(:,t)' = weighted outer product
%
% If X is empty there is no regression: B is returned as an empty
% Ysz x 0 x Q array and mu, Sigma are computed by mixgauss_Mstep.
%
% Optional inputs (default values in [])
%
% 'cov_type' - 'full', 'diag' or 'spherical' ['full']
%              (only the first character is inspected)
% 'tied_cov' - 1 (Sigma) or 0 (Sigma_i) [0]
% 'clamped_cov' - pass in clamped value, or [] if unclamped [ [] ]
% 'clamped_mean' - pass in clamped value, or [] if unclamped [ [] ]
% 'clamped_weights' - pass in clamped value, or [] if unclamped [ [] ]
% 'cov_prior' - added to Sigma(:,:,i) to ensure psd
%               [0.01*repmat(eye(d,d), [1 1 Q])]
% 'xs', 'ys', 'post' - debugging aid [ [] ]. If xs is non-empty, the untied
%               spherical estimate is re-derived directly from the raw
%               inputs xs(:,t), outputs ys(:,t) and responsibilities
%               post(i,t), and the two results are asserted to agree.
%
% OUTPUTS:
% mu(:,i)      = offset of component i (Ysz x Q)
% B(:,:,i)     = regression matrix of component i (Ysz x Xsz x Q)
% Sigma(:,:,i) = covariance of component i (Ysz x Ysz x Q).
%  If cov is tied, the same d*d matrix is replicated Q times.
%  Diagonal and spherical covariances are represented in full size.
%  If 'clamped_cov' is given it is returned unchanged (no prior is added).

[cov_type, tied_cov, ...
 clamped_cov, clamped_mean, clamped_weights, cov_prior, ...
 xs, ys, post] = ...
    process_options(varargin, ...
    'cov_type', 'full', 'tied_cov', 0, 'clamped_cov', [], 'clamped_mean', [], ...
    'clamped_weights', [], 'cov_prior', [], ...
    'xs', [], 'ys', [], 'post', []);

[Ysz, Q] = size(Y);

if isempty(X) % no regression
  B = zeros(Ysz, 0, Q); % empty array of size Ysz x 0 x Q
  [mu, Sigma] = mixgauss_Mstep(w, Y, YY, YTY, varargin{:});
  return;
end

% Total weight; taken before zero weights are patched below so that the
% tied estimates are normalised by the true amount of data.
N = sum(w);
if isempty(cov_prior)
  cov_prior = 0.01*repmat(eye(Ysz,Ysz), [1 1 Q]);
end
%YY = YY + cov_prior; % regularize the scatter matrix

% Set any zero weights to one before dividing
% This is valid because w(i)=0 => Y(:,i)=0, etc
w = w + (w==0);

Xsz = size(X,1);
% Append 1 to X to get Z, so that A = [B mu] acts on z = [x; 1].
ZZ = zeros(Xsz+1, Xsz+1, Q);
ZY = zeros(Xsz+1, Ysz, Q);
for i=1:Q
  ZZ(:,:,i) = [XX(:,:,i)  X(:,i);
               X(:,i)'    w(i)];
  ZY(:,:,i) = [XY(:,:,i);
               Y(:,i)'];
end


%%% Estimate mean and regression

weights_clamped = ~isempty(clamped_weights);
mean_clamped = ~isempty(clamped_mean);

if weights_clamped && mean_clamped
  B = clamped_weights;
  mu = clamped_mean;
elseif weights_clamped
  B = clamped_weights;
  % eqn 5
  mu = zeros(Ysz, Q);
  for i=1:Q
    mu(:,i) = (Y(:,i) - B(:,:,i)*X(:,i)) / w(i);
  end
elseif mean_clamped
  mu = clamped_mean;
  % eqn 3
  B = zeros(Ysz, Xsz, Q);
  for i=1:Q
    tmp = XY(:,:,i)' - mu(:,i)*X(:,i)';
    % Solve with backslash instead of forming inv(XX(:,:,i)).
    B(:,:,i) = (XX(:,:,i) \ tmp')';
  end
else
  % Nothing is clamped, so we must estimate B and mu jointly
  mu = zeros(Ysz, Q);
  B = zeros(Ysz, Xsz, Q);
  for i=1:Q
    % eqn 9
    if rcond(ZZ(:,:,i)) < 1e-10
      % probably because there are too few cases for a high-dimensional input
      warning('clg_Mstep:illConditioned', ...
              'ZZ(:,:,%d) is ill-conditioned; adding 1e-5*I before solving', i);
      % The regularised ZZ is also the one used for the covariance below.
      ZZ(:,:,i) = ZZ(:,:,i) + 1e-5*eye(Xsz+1);
    end
    % Solve with backslash instead of forming inv(ZZ(:,:,i)).
    A = (ZZ(:,:,i) \ ZY(:,:,i))';
    B(:,:,i) = A(:, 1:Xsz);
    mu(:,i) = A(:, Xsz+1);
  end
end

if ~isempty(clamped_cov)
  Sigma = clamped_cov;
  return;
end


%%% Estimate covariance

is_spherical = (cov_type(1) == 's');
is_diag = (cov_type(1) == 'd');

if is_spherical
  if ~tied_cov
    Sigma = zeros(Ysz, Ysz, Q);
    for i=1:Q
      % eqn 16
      A = [B(:,:,i) mu(:,i)];
      % YTY(i) is already a scalar, so only the matrix terms are traced.
      s = (YTY(i) + trace(A'*A*ZZ(:,:,i)) - 2*trace(A*ZY(:,:,i))) / (Ysz*w(i));
      Sigma(:,:,i) = s*eye(Ysz,Ysz);

      %%%%%%%%%%%%%%%%%%% debug
      % Cross-check the sufficient-statistic formula against a direct sum
      % over the raw data, when the caller supplied it.
      if ~isempty(xs)
        T = size(xs, 2);
        zs = [xs; ones(1,T)];
        yty = 0;
        zAAz = 0;
        yAz = 0;
        for t=1:T
          yty = yty + ys(:,t)'*ys(:,t) * post(i,t);
          zAAz = zAAz + zs(:,t)'*A'*A*zs(:,t)*post(i,t);
          yAz = yAz + ys(:,t)'*A*zs(:,t)*post(i,t);
        end
        assert(approxeq(yty, YTY(i)))
        assert(approxeq(zAAz, trace(A'*A*ZZ(:,:,i))))
        assert(approxeq(yAz, trace(A*ZY(:,:,i))))
        s2 = (yty + zAAz - 2*yAz) / (Ysz*w(i));
        assert(approxeq(s,s2))
      end
      %%%%%%%%%%%%%%% end debug

    end
  else
    S = 0;
    for i=1:Q
      % eqn 18
      A = [B(:,:,i) mu(:,i)];
      % The three terms have different sizes (scalar, (Xsz+1)^2, Ysz^2), so
      % each matrix term must be traced separately before summing.
      S = S + YTY(i) + trace(A'*A*ZZ(:,:,i)) - 2*trace(A*ZY(:,:,i));
    end
    % Spherical covariances are stored full size: s*I replicated Q times.
    Sigma = repmat((S / (N*Ysz)) * eye(Ysz,Ysz), [1 1 Q]);
  end
else % Full/diagonal
  if ~tied_cov
    Sigma = zeros(Ysz, Ysz, Q);
    for i=1:Q
      A = [B(:,:,i) mu(:,i)];
      % eqn 10
      SS = (YY(:,:,i) - ZY(:,:,i)'*A' - A*ZY(:,:,i) + A*ZZ(:,:,i)*A') / w(i);
      if is_diag
        Sigma(:,:,i) = diag(diag(SS));
      else
        Sigma(:,:,i) = SS;
      end
    end
  else % tied
    SS = zeros(Ysz, Ysz);
    for i=1:Q
      A = [B(:,:,i) mu(:,i)];
      % eqn 13
      SS = SS + (YY(:,:,i) - ZY(:,:,i)'*A' - A*ZY(:,:,i) + A*ZZ(:,:,i)*A');
    end
    SS = SS / N;
    if is_diag
      Sigma = diag(diag(SS));
    else
      Sigma = SS;
    end
    Sigma = repmat(Sigma, [1 1 Q]);
  end
end

Sigma = Sigma + cov_prior;