annotate toolboxes/FullBNT-1.0.7/KPMstats/mixgauss_Mstep.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
rev   line source
function [mu, Sigma] = mixgauss_Mstep(w, Y, YY, YTY, varargin)
% MIXGAUSS_MSTEP Compute MLEs for mixture of Gaussians given expected sufficient statistics
% function [mu, Sigma] = mixgauss_Mstep(w, Y, YY, YTY, varargin)
%
% We assume P(Y|Q=i) = N(Y; mu_i, Sigma_i)
% and w(i,t) = p(Q(t)=i|y(t)) = posterior responsibility
% See www.ai.mit.edu/~murphyk/Papers/learncg.pdf.
%
% INPUTS:
% w(i) = sum_t w(i,t) = responsibilities for each mixture component
%  If there is only one mixture component (i.e., Q does not exist),
%  then w(i) = N = nsamples, and
%  all references to i can be replaced by 1.
% YY(:,:,i) = sum_t w(i,t) y(:,t) y(:,t)' = weighted outer product
% Y(:,i) = sum_t w(i,t) y(:,t) = weighted observations
% YTY(i) = sum_t w(i,t) y(:,t)' y(:,t) = weighted inner product
%  You only need to pass in YTY if Sigma is to be estimated as spherical.
%
% Optional parameters may be passed as 'param_name', param_value pairs.
% Parameter names are shown below; default values in [] - if none, argument is mandatory.
%
% 'cov_type' - 'full', 'diag' or 'spherical' ['full']
%              (only the first character is inspected)
% 'tied_cov' - 1 (Sigma) or 0 (Sigma_i) [0]
% 'clamped_cov' - pass in clamped value, or [] if unclamped [ [] ]
% 'clamped_mean' - pass in clamped value, or [] if unclamped [ [] ]
% 'cov_prior' - Lambda_i, added to the estimated Sigma(:,:,i) [0.01*eye(d) for each i]
%
% OUTPUTS:
% mu(:,i)      = estimated (or clamped) mean of component i, size d x Q
% Sigma(:,:,i) = estimated covariance of component i plus cov_prior(:,:,i), size d x d x Q.
%  If the covariance is tied, the same pooled estimate is replicated into every slice.
%  Diagonal and spherical covariances are represented in full size.
%  If 'clamped_cov' is given it is returned as-is (no prior is added).
%
% NOTE: the covariance formulas assume mu(:,i) = Y(:,i)/w(i). With a clamped mean the
% cross terms between Y and mu are not accounted for.

% The trailing 'other' output is kept so the call to process_options has the same
% number of outputs as before (presumably it collects unrecognised options -
% confirm against process_options).
[cov_type, tied_cov, clamped_cov, clamped_mean, cov_prior, other] = ...
    process_options(varargin, ...
    'cov_type', 'full', 'tied_cov', 0, 'clamped_cov', [], 'clamped_mean', [], ...
    'cov_prior', []);

[Ysz, Q] = size(Y);

% Total responsibility mass. Guard against the degenerate all-zero case in the
% same way individual zero weights are guarded below (all statistics are then
% zero, so dividing by 1 yields 0 instead of NaN).
N = sum(w(:));
if N == 0
    N = 1;
end

if isempty(cov_prior)
    cov_prior = repmat(0.01*eye(Ysz,Ysz), [1 1 Q]);
end
YY = reshape(YY, [Ysz Ysz Q]);

% Keep the untouched weights for the pooled (tied) estimates, where a component
% with zero responsibility must contribute nothing.
w_raw = w(:)';

% Set any zero weights to one before dividing
% This is valid because w(i)=0 => Y(:,i)=0, etc
w = w + (w==0);

if ~isempty(clamped_mean)
    mu = clamped_mean;
else
    % eqn 6
    mu = zeros(Ysz, Q);
    for i=1:Q
        mu(:,i) = Y(:,i) / w(i);
    end
end

if ~isempty(clamped_cov)
    Sigma = clamped_cov;
    return;
end

if ~tied_cov
    Sigma = zeros(Ysz,Ysz,Q);
    for i=1:Q
        if cov_type(1) == 's'
            % eqn 17
            s2 = (1/Ysz)*( (YTY(i)/w(i)) - mu(:,i)'*mu(:,i) );
            Sigma(:,:,i) = s2 * eye(Ysz);
        else
            % eqn 12
            SS = YY(:,:,i)/w(i) - mu(:,i)*mu(:,i)';
            if cov_type(1)=='d'
                SS = diag(diag(SS));
            end
            Sigma(:,:,i) = SS;
        end
    end
else % tied cov
    if cov_type(1) == 's'
        % eqn 19: pooled within-component squared distance,
        %   sum_i sum_t w(i,t) ||y(:,t) - mu_i||^2 = sum_i ( YTY(i) - w(i) mu_i' mu_i )
        % The mean term is subtracted, and both operands are forced to row
        % vectors so the result is a scalar regardless of input orientation.
        mu_sqnorm = sum(mu.^2, 1);                 % 1 x Q, mu_i' * mu_i
        s2 = (1/(N*Ysz)) * (sum(YTY(:)) - sum(w_raw .* mu_sqnorm));
        Sigma = s2*eye(Ysz);
    else
        SS = zeros(Ysz, Ysz);
        % eqn 15: pooled scatter, (1/N) sum_i ( YY_i - w(i) mu_i mu_i' ).
        % Each mean outer product must be weighted by its component's
        % responsibility mass; leaving it unweighted is only right for Q == 1.
        for i=1:Q
            SS = SS + (YY(:,:,i) - w_raw(i) * (mu(:,i)*mu(:,i)')) / N;
        end
        if cov_type(1) == 'd'
            Sigma = diag(diag(SS));
        else
            Sigma = SS;
        end
    end
    % Represent the tied covariance in full size, one copy per component.
    Sigma = repmat(Sigma, [1 1 Q]);
end

% Regularise every component's covariance with its prior.
Sigma = Sigma + cov_prior;