annotate toolboxes/FullBNT-1.0.7/KPMstats/mixgauss_em.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
function [mu, Sigma, prior] = mixgauss_em(Y, nc, varargin)
% MIXGAUSS_EM Fit the parameters of a mixture of Gaussians using EM
% function [mu, Sigma, prior] = mixgauss_em(Y, nc, varargin)
%
% Y(:, t) is the t'th data point
% nc is the number of clusters
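%
% Optional arguments (name/value pairs, parsed by process_options below) and
% their defaults:
% 'max_iter'     - maximum number of EM iterations [10]
% 'thresh'       - convergence threshold passed to em_converged [1e-2]
% 'cov_type'     - covariance type, e.g. 'full' or 'diag' ['full']
% 'mu', 'Sigma'  - initial parameters; if 'mu' is empty, mixgauss_init is called [ [] ]
% 'method'       - initialization method passed to mixgauss_init ['kmeans']
% 'cov_prior'    - covariance prior passed to mixgauss_Mstep [ [] ]
% 'verbose'      - if 1, print the log-likelihood at each iteration [0]
% 'prune_thresh' - after EM, remove components whose prior is below this value [0]
%
% Example (a minimal usage sketch on made-up data; it assumes the KPMstats and
% KPMtools helpers called below, such as mixgauss_init, mixgauss_prob,
% mixgauss_Mstep, normalize, em_converged and process_options, are on the path):
%   Y = [randn(2,100)-2, randn(2,100)+2];  % 200 two-dimensional points, two clumps
%   [mu, Sigma, prior] = mixgauss_em(Y, 2, 'cov_type', 'diag', 'verbose', 1);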

% Kevin Murphy, 13 May 2003

[max_iter, thresh, cov_type, mu, Sigma, method, ...
    cov_prior, verbose, prune_thresh] = process_options(...
    varargin, 'max_iter', 10, 'thresh', 1e-2, 'cov_type', 'full', ...
    'mu', [], 'Sigma', [], 'method', 'kmeans', ...
    'cov_prior', [], 'verbose', 0, 'prune_thresh', 0);

[ny, T] = size(Y);

if nc==1
  % No latent variable, so there is a closed-form solution
  mu = mean(Y')';
  Sigma = cov(Y');
  if strcmp(cov_type, 'diag')
    Sigma = diag(diag(Sigma));
  end
  prior = 1;
  return;
end

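% If no initial parameters were supplied, choose starting values for mu, Sigma
% and prior with mixgauss_init, using the requested 'method' ('kmeans' by default).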
if isempty(mu)
  [mu, Sigma, prior] = mixgauss_init(nc, Y, cov_type, method);
end

previous_loglik = -inf;
num_iter = 1;
converged = 0;

%if verbose, fprintf('starting em\n'); end

while (num_iter <= max_iter) && ~converged
  % E step
  probY = mixgauss_prob(Y, mu, Sigma, prior); % probY(q,t)
  [post, lik] = normalize(probY .* repmat(prior, 1, T), 1); % post(q,t)
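  % post(q,t) = p(Q=q | Y(:,t)), the responsibility of component q for sample t;
  % lik(t) = sum_q prior(q) probY(q,t) = p(Y(:,t)), the likelihood of sample t.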
  loglik = sum(log(lik)); % total data log-likelihood, sum_t log p(Y(:,t))

  % extract expected sufficient statistics
  w = sum(post,2); % w(c) = sum_t post(c,t)
  WYY = zeros(ny, ny, nc); % WYY(:,:,c) = sum_t post(c,t) Y(:,t) Y(:,t)'
  WY = zeros(ny, nc); % WY(:,c) = sum_t post(c,t) Y(:,t)
  WYTY = zeros(nc,1); % WYTY(c) = sum_t post(c,t) Y(:,t)' Y(:,t)
  for c=1:nc
    weights = repmat(post(c,:), ny, 1); % weights(:,t) = post(c,t)
    WYbig = Y .* weights; % WYbig(:,t) = post(c,t) * Y(:,t)
    WYY(:,:,c) = WYbig * Y';
    WY(:,c) = sum(WYbig, 2);
    WYTY(c) = sum(diag(WYbig' * Y));
  end
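
  % The M step below recovers the new parameters from these statistics via the
  % standard mixture-of-Gaussians updates (mixgauss_Mstep additionally handles
  % the cov_type constraint and any cov_prior regularization):
  %   prior(c)     = w(c) / sum(w)
  %   mu(:,c)      = WY(:,c) / w(c)
  %   Sigma(:,:,c) = WYY(:,:,c) / w(c) - mu(:,c) * mu(:,c)'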

  % M step
  prior = normalize(w);
  [mu, Sigma] = mixgauss_Mstep(w, WY, WYY, WYTY, 'cov_type', cov_type, 'cov_prior', cov_prior);

  if verbose, fprintf(1, 'iteration %d, loglik = %f\n', num_iter, loglik); end
  num_iter = num_iter + 1;
  converged = em_converged(loglik, previous_loglik, thresh);
  previous_loglik = loglik;

end

if prune_thresh > 0
  ndx = find(prior < prune_thresh);
  mu(:,ndx) = [];
  Sigma(:,:,ndx) = [];
  prior(ndx) = [];
end
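
% Note: after pruning, the surviving entries of prior are not renormalized, so
% they need not sum to 1; a caller that wants a proper distribution can
% renormalize them, e.g. with prior = normalize(prior).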