annotate toolboxes/FullBNT-1.0.7/KPMstats/clg_Mstep.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
rev   line source
function [mu, Sigma, B] = clg_Mstep(w, Y, YY, YTY, X, XX, XY, varargin)
% CLG_MSTEP Compute ML/MAP estimates for a conditional linear Gaussian
% [mu, Sigma, B] = clg_Mstep(w, Y, YY, YTY, X, XX, XY, varargin)
%
% We fit P(Y|X,Q=i) = N(Y; B_i X + mu_i, Sigma_i)
% and w(i,t) = p(M(t)=i|y(t)) = posterior responsibility
% See www.ai.mit.edu/~murphyk/Papers/learncg.pdf.
%
% See process_options for how to specify the input arguments.
%
% INPUTS:
% w(i) = sum_t w(i,t) = responsibilities for each mixture component
%   If there is only one mixture component (i.e., Q does not exist),
%   then w(i) = N = nsamples, and
%   all references to i can be replaced by 1.
% Y(:,i) = sum_t w(i,t) y(:,t) = weighted observations
% YY(:,:,i) = sum_t w(i,t) y(:,t) y(:,t)' = weighted outer product
% YTY(i) = sum_t w(i,t) y(:,t)' y(:,t) = weighted inner product
%   You only need to pass in YTY if Sigma is to be estimated as spherical.
%
% In the regression context, we must also pass in the following
% X(:,i) = sum_t w(i,t) x(:,t) = weighted inputs
% XX(:,:,i) = sum_t w(i,t) x(:,t) x(:,t)' = weighted outer product
% XY(:,:,i) = sum_t w(i,t) x(:,t) y(:,t)' = weighted outer product
% If X is empty, no regression is fitted: B is returned as an empty
% Ysz x 0 x Q array and mu, Sigma are delegated to mixgauss_Mstep.
%
% Optional inputs (default values in [])
%
% 'cov_type' - 'full', 'diag' or 'spherical' ['full']
%              (only the first character is inspected)
% 'tied_cov' - 1 (Sigma) or 0 (Sigma_i) [0]
% 'clamped_cov' - pass in clamped value, or [] if unclamped [ [] ]
% 'clamped_mean' - pass in clamped value, or [] if unclamped [ [] ]
% 'clamped_weights' - pass in clamped value, or [] if unclamped [ [] ]
% 'cov_prior' - added to Sigma(:,:,i) to ensure psd
%               [0.01*repmat(eye(d,d), [1 1 Q])]
% 'xs', 'ys', 'post' - debugging aid: raw inputs xs(:,t), outputs ys(:,t)
%               and responsibilities post(i,t). If xs is non-empty, the
%               untied spherical estimate is cross-checked against a
%               direct computation from these samples. [ [] ]
%
% OUTPUTS:
% mu(:,i)      - offset of component i (Ysz x Q)
% B(:,:,i)     - regression matrix of component i (Ysz x Xsz x Q)
% Sigma(:,:,i) - covariance of component i (Ysz x Ysz x Q).
%   A tied covariance is replicated across the Q components, and
%   diagonal and spherical covariances are represented in full size.
%   If 'clamped_cov' is given it is returned unchanged (cov_prior is not
%   added to it).

[cov_type, tied_cov, ...
 clamped_cov, clamped_mean, clamped_weights, cov_prior, ...
 xs, ys, post] = ...
    process_options(varargin, ...
    'cov_type', 'full', 'tied_cov', 0, 'clamped_cov', [], 'clamped_mean', [], ...
    'clamped_weights', [], 'cov_prior', [], ...
    'xs', [], 'ys', [], 'post', []);

[Ysz, Q] = size(Y);

if isempty(X) % no regression
  B = zeros(Ysz, 0, Q); % empty array of size Ysz x 0 x Q
  [mu, Sigma] = mixgauss_Mstep(w, Y, YY, YTY, varargin{:});
  return;
end

% Total weight must be taken BEFORE zero weights are patched below.
N = sum(w);
if isempty(cov_prior)
  cov_prior = 0.01*repmat(eye(Ysz,Ysz), [1 1 Q]);
end

% Set any zero weights to one before dividing
% This is valid because w(i)=0 => Y(:,i)=0, etc
w = w + (w==0);

Xsz = size(X,1);
% Append 1 to X to get Z, i.e. build the weighted statistics of z = [x; 1]
ZZ = zeros(Xsz+1, Xsz+1, Q);
ZY = zeros(Xsz+1, Ysz, Q);
for i=1:Q
  ZZ(:,:,i) = [XX(:,:,i)  X(:,i);
               X(:,i)'    w(i)];
  ZY(:,:,i) = [XY(:,:,i);
               Y(:,i)'];
end


%%% Estimate mean and regression

weights_clamped = ~isempty(clamped_weights);
mean_clamped = ~isempty(clamped_mean);

if weights_clamped && mean_clamped
  B = clamped_weights;
  mu = clamped_mean;
elseif weights_clamped
  B = clamped_weights;
  % eqn 5
  mu = zeros(Ysz, Q);
  for i=1:Q
    mu(:,i) = (Y(:,i) - B(:,:,i)*X(:,i)) / w(i);
  end
elseif mean_clamped
  mu = clamped_mean;
  % eqn 3
  B = zeros(Ysz, Xsz, Q);
  for i=1:Q
    tmp = XY(:,:,i)' - mu(:,i)*X(:,i)';
    % Solve B*XX = tmp without forming inv(XX)
    B(:,:,i) = (XX(:,:,i) \ tmp')';
  end
else
  % Nothing is clamped, so we must estimate B and mu jointly
  mu = zeros(Ysz, Q);
  B = zeros(Ysz, Xsz, Q);
  for i=1:Q
    % eqn 9
    if rcond(ZZ(:,:,i)) < 1e-10
      % probably because there are too few cases for a high-dimensional input
      warning('clg_Mstep:illConditioned', ...
              'clg_Mstep warning: ZZ(:,:,%d) is ill-conditioned', i);
      % The regularized ZZ is also used by the covariance estimate below.
      ZZ(:,:,i) = ZZ(:,:,i) + 1e-5*eye(Xsz+1);
    end
    % Solve A*ZZ = ZY' without forming inv(ZZ)
    A = (ZZ(:,:,i) \ ZY(:,:,i))';
    B(:,:,i) = A(:, 1:Xsz);
    mu(:,i) = A(:, Xsz+1);
  end
end

if ~isempty(clamped_cov)
  Sigma = clamped_cov;
  return;
end


%%% Estimate covariance

is_spherical = (cov_type(1)=='s');
is_diag = (cov_type(1)=='d');

if is_spherical
  if ~tied_cov
    Sigma = zeros(Ysz, Ysz, Q);
    for i=1:Q
      % eqn 16
      A = [B(:,:,i) mu(:,i)];
      s = (YTY(i) + trace(A'*A*ZZ(:,:,i)) - 2*trace(A*ZY(:,:,i))) / (Ysz*w(i));
      Sigma(:,:,i) = s*eye(Ysz,Ysz);
      if ~isempty(xs)
        check_spherical_stats(s, A, i, w(i), YTY(i), ZZ(:,:,i), ZY(:,:,i), ...
                              xs, ys, post, Ysz);
      end
    end
  else
    S = 0;
    for i=1:Q
      % eqn 18
      % Each term is reduced to a scalar separately: the two matrix
      % products have different sizes ((Xsz+1)^2 vs Ysz^2) and YTY(i) is
      % already a scalar, so they must not be summed inside one trace().
      A = [B(:,:,i) mu(:,i)];
      S = S + YTY(i) + trace(A'*A*ZZ(:,:,i)) - 2*trace(A*ZY(:,:,i));
    end
    % Shared variance on the diagonal, replicated for every component
    Sigma = repmat((S / (N*Ysz)) * eye(Ysz,Ysz), [1 1 Q]);
  end
else % Full/diagonal
  if ~tied_cov
    Sigma = zeros(Ysz, Ysz, Q);
    for i=1:Q
      A = [B(:,:,i) mu(:,i)];
      % eqn 10
      SS = (YY(:,:,i) - ZY(:,:,i)'*A' - A*ZY(:,:,i) + A*ZZ(:,:,i)*A') / w(i);
      if is_diag
        Sigma(:,:,i) = diag(diag(SS));
      else
        Sigma(:,:,i) = SS;
      end
    end
  else % tied
    SS = zeros(Ysz, Ysz);
    for i=1:Q
      A = [B(:,:,i) mu(:,i)];
      % eqn 13
      SS = SS + (YY(:,:,i) - ZY(:,:,i)'*A' - A*ZY(:,:,i) + A*ZZ(:,:,i)*A');
    end
    SS = SS / N;
    if is_diag
      SS = diag(diag(SS));
    end
    Sigma = repmat(SS, [1 1 Q]);
  end
end

Sigma = Sigma + cov_prior;

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

function check_spherical_stats(s, A, i, wi, YTYi, ZZi, ZYi, xs, ys, post, Ysz)
% Debug check: recompute the spherical variance of component i directly
% from the raw samples and assert it matches the sufficient-statistic form.

T = size(xs, 2);
zs = [xs; ones(1,T)];
yty = 0;
zAAz = 0;
yAz = 0;
for t=1:T
  yty = yty + ys(:,t)'*ys(:,t) * post(i,t);
  zAAz = zAAz + zs(:,t)'*A'*A*zs(:,t)*post(i,t);
  yAz = yAz + ys(:,t)'*A*zs(:,t)*post(i,t);
end
assert(approxeq(yty, YTYi))
assert(approxeq(zAAz, trace(A'*A*ZZi)))
assert(approxeq(yAz, trace(A*ZYi)))
s2 = (yty + zAAz - 2*yAz) / (Ysz*wi);
assert(approxeq(s,s2))