Mercurial > hg > camir-aes2014
view toolboxes/FullBNT-1.0.7/KPMstats/mixgauss_em.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
line wrap: on
line source
function [mu, Sigma, prior] = mixgauss_em(Y, nc, varargin) % MIXGAUSS_EM Fit the parameters of a mixture of Gaussians using EM % function [mu, Sigma, prior] = mixgauss_em(data, nc, varargin) % % data(:, t) is the t'th data point % nc is the number of clusters % Kevin Murphy, 13 May 2003 [max_iter, thresh, cov_type, mu, Sigma, method, ... cov_prior, verbose, prune_thresh] = process_options(... varargin, 'max_iter', 10, 'thresh', 1e-2, 'cov_type', 'full', ... 'mu', [], 'Sigma', [], 'method', 'kmeans', ... 'cov_prior', [], 'verbose', 0, 'prune_thresh', 0); [ny T] = size(Y); if nc==1 % No latent variable, so there is a closed-form solution mu = mean(Y')'; Sigma = cov(Y'); if strcmp(cov_type, 'diag') Sigma = diag(diag(Sigma)); end prior = 1; return; end if isempty(mu) [mu, Sigma, prior] = mixgauss_init(nc, Y, cov_type, method); end previous_loglik = -inf; num_iter = 1; converged = 0; %if verbose, fprintf('starting em\n'); end while (num_iter <= max_iter) & ~converged % E step probY = mixgauss_prob(Y, mu, Sigma, prior); % probY(q,t) [post, lik] = normalize(probY .* repmat(prior, 1, T), 1); % post(q,t) loglik = log(sum(lik)); % extract expected sufficient statistics w = sum(post,2); % w(c) = sum_t post(c,t) WYY = zeros(ny, ny, nc); % WYY(:,:,c) = sum_t post(c,t) Y(:,t) Y(:,t)' WY = zeros(ny, nc); % WY(:,c) = sum_t post(c,t) Y(:,t) WYTY = zeros(nc,1); % WYTY(c) = sum_t post(c,t) Y(:,t)' Y(:,t) for c=1:nc weights = repmat(post(c,:), ny, 1); % weights(:,t) = post(c,t) WYbig = Y .* weights; % WYbig(:,t) = post(c,t) * Y(:,t) WYY(:,:,c) = WYbig * Y'; WY(:,c) = sum(WYbig, 2); WYTY(c) = sum(diag(WYbig' * Y)); end % M step prior = normalize(w); [mu, Sigma] = mixgauss_Mstep(w, WY, WYY, WYTY, 'cov_type', cov_type, 'cov_prior', cov_prior); if verbose, fprintf(1, 'iteration %d, loglik = %f\n', num_iter, loglik); end num_iter = num_iter + 1; converged = em_converged(loglik, previous_loglik, thresh); previous_loglik = loglik; end if prune_thresh > 0 ndx = find(prior < prune_thresh); mu(:,ndx) = []; Sigma(:,:,ndx) = []; prior(ndx) = []; end