function [mu, Sigma, B] = clg_Mstep(w, Y, YY, YTY, X, XX, XY, varargin)
% CLG_MSTEP Compute ML/MAP estimates for a conditional linear Gaussian
% [mu, Sigma, B] = clg_Mstep(w, Y, YY, YTY, X, XX, XY, varargin)
%
% We fit P(Y|X,Q=i) = N(Y; B_i X + mu_i, Sigma_i)
% and w(i,t) = p(M(t)=i|y(t)) = posterior responsibility
% See www.ai.mit.edu/~murphyk/Papers/learncg.pdf.
%
% See process_options for how to specify the input arguments.
%
% INPUTS:
% w(i) = sum_t w(i,t) = responsibilities for each mixture component
%   If there is only one mixture component (i.e., Q does not exist),
%   then w(i) = N = nsamples, and
%   all references to i can be replaced by 1.
% Y(:,i) = sum_t w(i,t) y(:,t) = weighted observations
% YY(:,:,i) = sum_t w(i,t) y(:,t) y(:,t)' = weighted outer product
% YTY(i) = sum_t w(i,t) y(:,t)' y(:,t) = weighted inner product
%   You only need to pass in YTY if Sigma is to be estimated as spherical.
%
% In the regression context, we must also pass in the following
% X(:,i) = sum_t w(i,t) x(:,t) = weighted inputs
% XX(:,:,i) = sum_t w(i,t) x(:,t) x(:,t)' = weighted outer product
% XY(:,:,i) = sum_t w(i,t) x(:,t) y(:,t)' = weighted cross product
% If X is empty, no regression is done: B is returned as an empty
% Ysz x 0 x Q array, and mu and Sigma are computed by mixgauss_Mstep,
% to which all optional arguments are forwarded unchanged.
%
% Optional inputs (default values in [])
%
% 'cov_type' - 'full', 'diag' or 'spherical' ['full']
% 'tied_cov' - 1 (Sigma) or 0 (Sigma_i) [0]
% 'clamped_cov' - pass in clamped value, or [] if unclamped [ [] ]
% 'clamped_mean' - pass in clamped value, or [] if unclamped [ [] ]
% 'clamped_weights' - pass in clamped value, or [] if unclamped [ [] ]
% 'cov_prior' - added to Sigma(:,:,i) to ensure psd
%               [0.01*repmat(eye(d,d), [1 1 Q])]
% 'xs', 'ys', 'post' - raw inputs, outputs and posteriors; only used to
%               cross-check the untied spherical estimate (debugging) [ [] ]
%
% OUTPUTS:
% mu(:,i)      - offset of component i (or clamped_mean, if given)
% Sigma(:,:,i) - covariance of component i. If the cov is tied, the same
%                matrix is replicated in every slice i = 1..Q. Diagonal and
%                spherical covariances are also represented in full size.
%                cov_prior is added to estimated covariances but NOT to
%                clamped_cov, which is returned as passed in.
% B(:,:,i)     - regression matrix of component i (or clamped_weights)

[cov_type, tied_cov, ...
 clamped_cov, clamped_mean, clamped_weights, cov_prior, ...
 xs, ys, post] = ...
    process_options(varargin, ...
                    'cov_type', 'full', 'tied_cov', 0, 'clamped_cov', [], 'clamped_mean', [], ...
                    'clamped_weights', [], 'cov_prior', [], ...
                    'xs', [], 'ys', [], 'post', []);

[Ysz, Q] = size(Y);

if isempty(X) % no regression
  B = zeros(Ysz, 0, Q); % an empty array of size Ysz x 0 x Q
  [mu, Sigma] = mixgauss_Mstep(w, Y, YY, YTY, varargin{:});
  return;
end

% Total weight, computed before zero weights are patched below.
N = sum(w);
if isempty(cov_prior)
  cov_prior = 0.01*repmat(eye(Ysz,Ysz), [1 1 Q]);
end

% Set any zero weights to one before dividing.
% This is valid because w(i)=0 => Y(:,i)=0, etc.
w = w + (w==0);

Xsz = size(X,1);
% Append 1 to x to get z = [x; 1], so that B_i*x + mu_i = [B_i mu_i]*z
% and the weighted statistics of z are ZZ and ZY.
ZZ = zeros(Xsz+1, Xsz+1, Q);
ZY = zeros(Xsz+1, Ysz, Q);
for i=1:Q
  ZZ(:,:,i) = [XX(:,:,i) X(:,i);
               X(:,i)'   w(i)];
  ZY(:,:,i) = [XY(:,:,i);
               Y(:,i)'];
end


%%% Estimate mean and regression

weights_clamped = ~isempty(clamped_weights);
mean_clamped = ~isempty(clamped_mean);

if weights_clamped && mean_clamped
  B = clamped_weights;
  mu = clamped_mean;
elseif weights_clamped
  B = clamped_weights;
  % eqn 5
  mu = zeros(Ysz, Q);
  for i=1:Q
    mu(:,i) = (Y(:,i) - B(:,:,i)*X(:,i)) / w(i);
  end
elseif mean_clamped
  mu = clamped_mean;
  % eqn 3
  B = zeros(Ysz, Xsz, Q);
  for i=1:Q
    tmp = XY(:,:,i)' - mu(:,i)*X(:,i)';
    % Equivalent to tmp * inv(XX(:,:,i)) (XX is symmetric), but avoids
    % forming the explicit inverse.
    B(:,:,i) = (XX(:,:,i) \ tmp')';
  end
else
  % Nothing is clamped, so we must estimate B and mu jointly
  mu = zeros(Ysz, Q);
  B = zeros(Ysz, Xsz, Q);
  for i=1:Q
    % eqn 9
    if rcond(ZZ(:,:,i)) < 1e-10
      % Probably because there are too few cases for a high-dimensional
      % input. The regularized ZZ is also used by the covariance estimate.
      warning('clg_Mstep:illConditioned', ...
              'clg_Mstep warning: ZZ(:,:,%d) is ill-conditioned', i);
      ZZ(:,:,i) = ZZ(:,:,i) + 1e-5*eye(Xsz+1);
    end
    % Equivalent to ZY(:,:,i)' * inv(ZZ(:,:,i)), without the explicit inverse.
    A = (ZZ(:,:,i) \ ZY(:,:,i))';
    B(:,:,i) = A(:, 1:Xsz);
    mu(:,i) = A(:, Xsz+1);
  end
end

if ~isempty(clamped_cov)
  Sigma = clamped_cov;
  return;
end


%%% Estimate covariance

if cov_type(1)=='s' % Spherical
  if ~tied_cov
    Sigma = zeros(Ysz, Ysz, Q);
    for i=1:Q
      % eqn 16. YTY(i) is a scalar, so it must be added outside the
      % traces; the two trace terms have different sizes
      % ((Xsz+1)x(Xsz+1) and Ysz x Ysz) and cannot be summed as matrices.
      A = [B(:,:,i) mu(:,i)];
      s = (YTY(i) + trace(A'*A*ZZ(:,:,i)) - trace(2*A*ZY(:,:,i))) / (Ysz*w(i));
      Sigma(:,:,i) = s*eye(Ysz,Ysz);

      %%%%%%%%%%%%%%%%%%% debug: recompute s from the raw data, if given
      if ~isempty(xs)
        T = size(xs, 2);
        zs = [xs; ones(1,T)];
        yty = 0;
        zAAz = 0;
        yAz = 0;
        for t=1:T
          yty = yty + ys(:,t)'*ys(:,t) * post(i,t);
          zAAz = zAAz + zs(:,t)'*A'*A*zs(:,t)*post(i,t);
          yAz = yAz + ys(:,t)'*A*zs(:,t)*post(i,t);
        end
        assert(approxeq(yty, YTY(i)))
        assert(approxeq(zAAz, trace(A'*A*ZZ(:,:,i))))
        assert(approxeq(yAz, trace(A*ZY(:,:,i))))
        s2 = (yty + zAAz - 2*yAz) / (Ysz*w(i));
        assert(approxeq(s,s2))
      end
      %%%%%%%%%%%%%%% end debug
    end
  else
    S = 0;
    for i=1:Q
      % eqn 18. Same per-component quantity as eqn 16 (before scaling);
      % YTY(i) is a scalar and must stay outside the traces.
      A = [B(:,:,i) mu(:,i)];
      S = S + YTY(i) + trace(A'*A*ZZ(:,:,i)) - trace(2*A*ZY(:,:,i));
    end
    % One shared spherical covariance, in full Ysz x Ysz form, in every slice.
    Sigma = repmat((S / (N*Ysz)) * eye(Ysz,Ysz), [1 1 Q]);
  end
else % Full/diagonal
  if ~tied_cov
    Sigma = zeros(Ysz, Ysz, Q);
    for i=1:Q
      A = [B(:,:,i) mu(:,i)];
      % eqn 10
      SS = (YY(:,:,i) - ZY(:,:,i)'*A' - A*ZY(:,:,i) + A*ZZ(:,:,i)*A') / w(i);
      if cov_type(1)=='d'
        Sigma(:,:,i) = diag(diag(SS));
      else
        Sigma(:,:,i) = SS;
      end
    end
  else % tied
    SS = zeros(Ysz, Ysz);
    for i=1:Q
      A = [B(:,:,i) mu(:,i)];
      % eqn 13
      SS = SS + (YY(:,:,i) - ZY(:,:,i)'*A' - A*ZY(:,:,i) + A*ZZ(:,:,i)*A');
    end
    SS = SS / N;
    if cov_type(1)=='d'
      Sigma = diag(diag(SS));
    else
      Sigma = SS;
    end
    Sigma = repmat(Sigma, [1 1 Q]);
  end
end

Sigma = Sigma + cov_prior;