Mercurial > hg > camir-aes2014
comparison toolboxes/FullBNT-1.0.7/KPMstats/clg_Mstep.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:e9a9cd732c1e |
---|---|
function [mu, Sigma, B] = clg_Mstep(w, Y, YY, YTY, X, XX, XY, varargin)
% CLG_MSTEP Compute ML/MAP estimates for a conditional linear Gaussian
% [mu, Sigma, B] = clg_Mstep(w, Y, YY, YTY, X, XX, XY, varargin)
%
% We fit P(Y|X,Q=i) = N(Y; B_i X + mu_i, Sigma_i)
% and w(i,t) = p(M(t)=i|y(t)) = posterior responsibility
% See www.ai.mit.edu/~murphyk/Papers/learncg.pdf.
%
% See process_options for how to specify the input arguments.
%
% INPUTS:
% w(i) = sum_t w(i,t) = responsibilities for each mixture component
%  If there is only one mixture component (i.e., Q does not exist),
%  then w(i) = N = nsamples, and
%  all references to i can be replaced by 1.
% Y(:,i) = sum_t w(i,t) y(:,t) = weighted observations
% YY(:,:,i) = sum_t w(i,t) y(:,t) y(:,t)' = weighted outer product
% YTY(i) = sum_t w(i,t) y(:,t)' y(:,t) = weighted inner product
%  You only need to pass in YTY if Sigma is to be estimated as spherical.
%
% In the regression context, we must also pass in the following
% X(:,i) = sum_t w(i,t) x(:,t) = weighted inputs
% XX(:,:,i) = sum_t w(i,t) x(:,t) x(:,t)' = weighted outer product
% XY(:,:,i) = sum_t w(i,t) x(:,t) y(:,t)' = weighted outer product
% If X is empty, no regression is performed: B is returned as an empty
% Ysz x 0 x Q array and mu, Sigma are delegated to mixgauss_Mstep.
%
% Optional inputs (default values in [])
%
% 'cov_type' - 'full', 'diag' or 'spherical' ['full']
% 'tied_cov' - 1 (Sigma) or 0 (Sigma_i) [0]
% 'clamped_cov' - pass in clamped value, or [] if unclamped [ [] ]
% 'clamped_mean' - pass in clamped value, or [] if unclamped [ [] ]
% 'clamped_weights' - pass in clamped value, or [] if unclamped [ [] ]
% 'cov_prior' - added to Sigma(:,:,i) to ensure psd [0.01*eye(d,d,Q)]
% 'xs', 'ys', 'post' - raw inputs, outputs and responsibilities; if 'xs' is
%    non-empty they are used only to cross-check the sufficient statistics
%    in the untied spherical case (debugging aid) [ [] ]
%
% OUTPUTS:
% mu(:,i)     = offset of component i            (Ysz x Q)
% Sigma(:,:,i) = covariance of component i       (Ysz x Ysz x Q)
% B(:,:,i)    = regression matrix of component i (Ysz x Xsz x Q)
%
% If cov is tied, a single d*d matrix is estimated and replicated Q times
% (cov_prior is then added slice by slice).
% Diagonal and spherical covariances are represented in full size.
% If 'clamped_cov' is given it is returned as-is (cov_prior is not added).

[cov_type, tied_cov, ...
 clamped_cov, clamped_mean, clamped_weights, cov_prior, ...
 xs, ys, post] = ...
    process_options(varargin, ...
        'cov_type', 'full', 'tied_cov', 0, 'clamped_cov', [], 'clamped_mean', [], ...
        'clamped_weights', [], 'cov_prior', [], ...
        'xs', [], 'ys', [], 'post', []);

[Ysz, Q] = size(Y);

if isempty(X) % no regression
  B = zeros(Ysz, 0, Q); % empty regression matrix of size Ysz x 0 x Q
  [mu, Sigma] = mixgauss_Mstep(w, Y, YY, YTY, varargin{:});
  return;
end


% Total weight; must be computed before zero weights are patched below.
N = sum(w);
if isempty(cov_prior)
  cov_prior = 0.01*repmat(eye(Ysz,Ysz), [1 1 Q]);
end
%YY = YY + cov_prior; % regularize the scatter matrix

% Set any zero weights to one before dividing
% This is valid because w(i)=0 => Y(:,i)=0, etc
w = w + (w==0);

Xsz = size(X,1);
% Append 1 to X to get Z, so that [B mu] can be estimated as one matrix A
ZZ = zeros(Xsz+1, Xsz+1, Q);
ZY = zeros(Xsz+1, Ysz, Q);
for i=1:Q
  ZZ(:,:,i) = [XX(:,:,i)  X(:,i);
               X(:,i)'    w(i)];
  ZY(:,:,i) = [XY(:,:,i);
               Y(:,i)'];
end


%%% Estimate mean and regression

if ~isempty(clamped_weights) && ~isempty(clamped_mean)
  % Both clamped: nothing to estimate
  B = clamped_weights;
  mu = clamped_mean;
elseif ~isempty(clamped_weights)
  % Weights clamped, mean free
  B = clamped_weights;
  % eqn 5
  mu = zeros(Ysz, Q);
  for i=1:Q
    mu(:,i) = (Y(:,i) - B(:,:,i)*X(:,i)) / w(i);
  end
elseif ~isempty(clamped_mean)
  % Mean clamped, weights free
  mu = clamped_mean;
  % eqn 3
  B = zeros(Ysz, Xsz, Q);
  for i=1:Q
    tmp = XY(:,:,i)' - mu(:,i)*X(:,i)';
    %B(:,:,i) = tmp * inv(XX(:,:,i));
    B(:,:,i) = (XX(:,:,i) \ tmp')';
  end
else
  mu = zeros(Ysz, Q);
  B = zeros(Ysz, Xsz, Q);
  % Nothing is clamped, so we must estimate B and mu jointly
  for i=1:Q
    % eqn 9
    if rcond(ZZ(:,:,i)) < 1e-10
      % probably because there are too few cases for a high-dimensional input
      % (the original code built this message with sprintf and discarded it,
      % so it was never shown; the identifier lets callers turn it off)
      warning('clg_Mstep:illConditioned', ...
              'clg_Mstep warning: ZZ(:,:,%d) is ill-conditioned', i);
      ZZ(:,:,i) = ZZ(:,:,i) + 1e-5*eye(Xsz+1);
    end
    %A = ZY(:,:,i)' * inv(ZZ(:,:,i));
    A = (ZZ(:,:,i) \ ZY(:,:,i))';
    B(:,:,i) = A(:, 1:Xsz);
    mu(:,i) = A(:, Xsz+1);
  end
end

if ~isempty(clamped_cov)
  Sigma = clamped_cov;
  return;
end


%%% Estimate covariance

% Spherical
if cov_type(1)=='s'
  if ~tied_cov
    Sigma = zeros(Ysz, Ysz, Q);
    for i=1:Q
      % eqn 16
      A = [B(:,:,i) mu(:,i)];
      %s = trace(YTY(i) + A'*A*ZZ(:,:,i) - 2*A*ZY(:,:,i)) / (Ysz*w(i)); % wrong!
      s = (YTY(i) + trace(A'*A*ZZ(:,:,i)) - trace(2*A*ZY(:,:,i))) / (Ysz*w(i));
      Sigma(:,:,i) = s*eye(Ysz,Ysz);

      %%%%%%%%%%%%%%%%%%% debug
      % Recompute the statistics from the raw data and check they agree
      if ~isempty(xs)
        T = size(xs, 2);
        zs = [xs; ones(1,T)];
        yty = 0;
        zAAz = 0;
        yAz = 0;
        for t=1:T
          yty = yty + ys(:,t)'*ys(:,t) * post(i,t);
          zAAz = zAAz + zs(:,t)'*A'*A*zs(:,t)*post(i,t);
          yAz = yAz + ys(:,t)'*A*zs(:,t)*post(i,t);
        end
        assert(approxeq(yty, YTY(i)))
        assert(approxeq(zAAz, trace(A'*A*ZZ(:,:,i))))
        assert(approxeq(yAz, trace(A*ZY(:,:,i))))
        s2 = (yty + zAAz - 2*yAz) / (Ysz*w(i));
        assert(approxeq(s,s2))
      end
      %%%%%%%%%%%%%%% end debug

    end
  else
    S = 0;
    for i=1:Q
      % eqn 18
      A = [B(:,:,i) mu(:,i)];
      % The three terms have different sizes (scalar, (Xsz+1)^2, Ysz^2), so
      % the trace must be taken term by term, exactly as in the untied case.
      S = S + YTY(i) + trace(A'*A*ZZ(:,:,i)) - 2*trace(A*ZY(:,:,i));
    end
    % Spherical covariances are represented in full size: s*I, one per component
    Sigma = repmat((S / (N*Ysz)) * eye(Ysz,Ysz), [1 1 Q]);
  end
else % Full/diagonal
  if ~tied_cov
    Sigma = zeros(Ysz, Ysz, Q);
    for i=1:Q
      A = [B(:,:,i) mu(:,i)];
      % eqn 10
      SS = (YY(:,:,i) - ZY(:,:,i)'*A' - A*ZY(:,:,i) + A*ZZ(:,:,i)*A') / w(i);
      if cov_type(1)=='d'
        Sigma(:,:,i) = diag(diag(SS));
      else
        Sigma(:,:,i) = SS;
      end
    end
  else % tied
    SS = zeros(Ysz, Ysz);
    for i=1:Q
      A = [B(:,:,i) mu(:,i)];
      % eqn 13
      SS = SS + (YY(:,:,i) - ZY(:,:,i)'*A' - A*ZY(:,:,i) + A*ZZ(:,:,i)*A');
    end
    SS = SS / N;
    if cov_type(1)=='d'
      Sigma = diag(diag(SS));
    else
      Sigma = SS;
    end
    Sigma = repmat(Sigma, [1 1 Q]);
  end
end

Sigma = Sigma + cov_prior;
203 |