comparison toolboxes/FullBNT-1.0.7/KPMstats/clg_Mstep.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:e9a9cd732c1e
function [mu, Sigma, B] = clg_Mstep(w, Y, YY, YTY, X, XX, XY, varargin)
% CLG_MSTEP Compute ML/MAP estimates for a conditional linear Gaussian
% [mu, Sigma, B] = clg_Mstep(w, Y, YY, YTY, X, XX, XY, varargin)
%
% We fit P(Y|X,Q=i) = N(Y; B_i X + mu_i, Sigma_i)
% and w(i,t) = p(M(t)=i|y(t)) = posterior responsibility
% See www.ai.mit.edu/~murphyk/Papers/learncg.pdf.
%
% See process_options for how to specify the input arguments.
%
% INPUTS:
% w(i) = sum_t w(i,t) = responsibilities for each mixture component
%    If there is only one mixture component (i.e., Q does not exist),
%    then w(i) = N = nsamples, and
%    all references to i can be replaced by 1.
% Y(:,i) = sum_t w(i,t) y(:,t) = weighted observations
% YY(:,:,i) = sum_t w(i,t) y(:,t) y(:,t)' = weighted outer product
% YTY(i) = sum_t w(i,t) y(:,t)' y(:,t) = weighted inner product
%    You only need to pass in YTY if Sigma is to be estimated as spherical.
%
% In the regression context, we must also pass in the following
% X(:,i) = sum_t w(i,t) x(:,t) = weighted inputs
% XX(:,:,i) = sum_t w(i,t) x(:,t) x(:,t)' = weighted outer product
% XY(:,:,i) = sum_t w(i,t) x(:,t) y(:,t)' = weighted cross product
% If X is empty, no regression is done: B is returned as an empty
% Ysz x 0 x Q array and mu, Sigma are delegated to mixgauss_Mstep
% (which receives the same optional arguments).
%
% Optional inputs (default values in [])
%
% 'cov_type' - 'full', 'diag' or 'spherical' ['full']
% 'tied_cov' - 1 (Sigma) or 0 (Sigma_i) [0]
% 'clamped_cov' - pass in clamped value, or [] if unclamped [ [] ]
% 'clamped_mean' - pass in clamped value, or [] if unclamped [ [] ]
% 'clamped_weights' - pass in clamped value, or [] if unclamped [ [] ]
% 'cov_prior' - added to Sigma(:,:,i) to ensure psd [0.01*eye(d,d,Q)]
% 'xs', 'ys', 'post' - debugging aid: raw inputs x(:,t), outputs y(:,t)
%    and posteriors post(i,t). If xs is non-empty, the untied spherical
%    estimate is re-computed from them and checked with assert [ [] ]
%
% OUTPUTS (regression case):
% mu(:,i)      - offset of component i (Ysz x Q), unless clamped
% Sigma(:,:,i) - covariance of component i (Ysz x Ysz x Q). A tied
%                covariance is replicated into every slice. If clamped_cov
%                is given it is returned as-is (cov_prior is NOT added).
% B(:,:,i)     - regression matrix of component i (Ysz x Xsz x Q), unless clamped
%
% Diagonal and spherical covariances are represented in full size.

[cov_type, tied_cov, ...
 clamped_cov, clamped_mean, clamped_weights, cov_prior, ...
 xs, ys, post] = ...
    process_options(varargin, ...
                    'cov_type', 'full', 'tied_cov', 0, 'clamped_cov', [], 'clamped_mean', [], ...
                    'clamped_weights', [], 'cov_prior', [], ...
                    'xs', [], 'ys', [], 'post', []);

[Ysz, Q] = size(Y);

if isempty(X) % no regression
  % Empty regression matrices, one Ysz x 0 slice per component.
  B = zeros(Ysz, 0, Q);
  [mu, Sigma] = mixgauss_Mstep(w, Y, YY, YTY, varargin{:});
  return;
end

% Total weight, taken BEFORE zero weights are patched below; used by the
% tied estimators.
N = sum(w);
if isempty(cov_prior)
  cov_prior = 0.01*repmat(eye(Ysz,Ysz), [1 1 Q]);
end
%YY = YY + cov_prior; % regularize the scatter matrix

% Set any zero weights to one before dividing.
% This is valid because w(i)=0 => Y(:,i)=0, etc.
w = w + (w==0);

Xsz = size(X,1);
% Append 1 to x to get z = [x; 1], so that [B_i mu_i] * z = B_i x + mu_i.
ZZ = zeros(Xsz+1, Xsz+1, Q);
ZY = zeros(Xsz+1, Ysz, Q);
for i=1:Q
  ZZ(:,:,i) = [XX(:,:,i) X(:,i);
               X(:,i)'   w(i)];
  ZY(:,:,i) = [XY(:,:,i);
               Y(:,i)'];
end


%%% Estimate mean and regression

if ~isempty(clamped_weights) && ~isempty(clamped_mean)
  B = clamped_weights;
  mu = clamped_mean;
elseif ~isempty(clamped_weights)
  B = clamped_weights;
  % eqn 5
  mu = zeros(Ysz, Q);
  for i=1:Q
    mu(:,i) = (Y(:,i) - B(:,:,i)*X(:,i)) / w(i);
  end
elseif ~isempty(clamped_mean)
  mu = clamped_mean;
  % eqn 3
  B = zeros(Ysz, Xsz, Q);
  for i=1:Q
    tmp = XY(:,:,i)' - mu(:,i)*X(:,i)';
    % Equivalent to tmp * inv(XX(:,:,i)) since XX is symmetric, but
    % solved with backslash instead of forming the inverse.
    B(:,:,i) = (XX(:,:,i) \ tmp')';
  end
else
  % Nothing is clamped, so we must estimate B and mu jointly
  mu = zeros(Ysz, Q);
  B = zeros(Ysz, Xsz, Q);
  for i=1:Q
    % eqn 9
    if rcond(ZZ(:,:,i)) < 1e-10
      % Probably because there are too few cases for a high-dimensional input.
      % The regularized ZZ is also used by the covariance estimate below.
      warning('clg_Mstep:illConditioned', ...
              'ZZ(:,:,%d) is ill-conditioned; adding 1e-5*I', i);
      ZZ(:,:,i) = ZZ(:,:,i) + 1e-5*eye(Xsz+1);
    end
    % A = [B_i mu_i] = ZY' * inv(ZZ), solved without forming the inverse.
    A = (ZZ(:,:,i) \ ZY(:,:,i))';
    B(:,:,i) = A(:, 1:Xsz);
    mu(:,i) = A(:, Xsz+1);
  end
end

if ~isempty(clamped_cov)
  Sigma = clamped_cov;
  return;
end


%%% Estimate covariance

% Spherical
if cov_type(1)=='s'
  if ~tied_cov
    Sigma = zeros(Ysz, Ysz, Q);
    for i=1:Q
      % eqn 16: weighted residual sum of squares
      %   sum_t w(i,t) ||y_t - A z_t||^2
      %     = YTY(i) + tr(A'A ZZ_i) - 2 tr(A ZY_i)
      A = [B(:,:,i) mu(:,i)];
      s = (YTY(i) + trace(A'*A*ZZ(:,:,i)) - trace(2*A*ZY(:,:,i))) / (Ysz*w(i));
      Sigma(:,:,i) = s*eye(Ysz,Ysz);

      %%%%%%%%%%%%%%%%%%% debug
      if ~isempty(xs)
        [nx, T] = size(xs);
        zs = [xs; ones(1,T)];
        yty = 0;
        zAAz = 0;
        yAz = 0;
        for t=1:T
          yty = yty + ys(:,t)'*ys(:,t) * post(i,t);
          zAAz = zAAz + zs(:,t)'*A'*A*zs(:,t)*post(i,t);
          yAz = yAz + ys(:,t)'*A*zs(:,t)*post(i,t);
        end
        assert(approxeq(yty, YTY(i)))
        assert(approxeq(zAAz, trace(A'*A*ZZ(:,:,i))))
        assert(approxeq(yAz, trace(A*ZY(:,:,i))))
        s2 = (yty + zAAz - 2*yAz) / (Ysz*w(i));
        assert(approxeq(s,s2))
      end
      %%%%%%%%%%%%%%% end debug

    end
  else
    % eqn 18: pool the weighted residual sums of squares over components.
    % Each term is a scalar; taking trace() of the whole sum would mix an
    % (Xsz+1)x(Xsz+1) matrix with a Ysz x Ysz one.
    S = 0;
    for i=1:Q
      A = [B(:,:,i) mu(:,i)];
      S = S + YTY(i) + trace(A'*A*ZZ(:,:,i)) - trace(2*A*ZY(:,:,i));
    end
    % Represent the shared spherical covariance in full size, one slice per
    % component, so that it matches the untied case and cov_prior.
    Sigma = repmat((S / (N*Ysz)) * eye(Ysz,Ysz), [1 1 Q]);
  end
else % Full/diagonal
  if ~tied_cov
    Sigma = zeros(Ysz, Ysz, Q);
    for i=1:Q
      A = [B(:,:,i) mu(:,i)];
      % eqn 10
      SS = (YY(:,:,i) - ZY(:,:,i)'*A' - A*ZY(:,:,i) + A*ZZ(:,:,i)*A') / w(i);
      if cov_type(1)=='d'
        Sigma(:,:,i) = diag(diag(SS));
      else
        Sigma(:,:,i) = SS;
      end
    end
  else % tied
    SS = zeros(Ysz, Ysz);
    for i=1:Q
      A = [B(:,:,i) mu(:,i)];
      % eqn 13
      SS = SS + (YY(:,:,i) - ZY(:,:,i)'*A' - A*ZY(:,:,i) + A*ZZ(:,:,i)*A');
    end
    SS = SS / N;
    if cov_type(1)=='d'
      Sigma = diag(diag(SS));
    else
      Sigma = SS;
    end
    Sigma = repmat(Sigma, [1 1 Q]);
  end
end

Sigma = Sigma + cov_prior;