Mercurial > hg > camir-aes2014
comparison toolboxes/FullBNT-1.0.7/KPMstats/mixgauss_em.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:e9a9cd732c1e |
---|---|
function [mu, Sigma, prior] = mixgauss_em(Y, nc, varargin)
% MIXGAUSS_EM Fit the parameters of a mixture of Gaussians using EM
% function [mu, Sigma, prior] = mixgauss_em(data, nc, varargin)
%
% Inputs:
%   data(:, t) is the t'th data point
%   nc         is the number of clusters
%
% Optional name/value arguments [default]:
%   'max_iter'     - maximum number of EM iterations [10]
%   'thresh'       - convergence threshold passed to em_converged [1e-2]
%   'cov_type'     - covariance type, forwarded to mixgauss_init and
%                    mixgauss_Mstep ['full']
%   'mu'           - initial means; if empty, mixgauss_init is used [ [] ]
%   'Sigma'        - initial covariances, used only when 'mu' is given [ [] ]
%   'method'       - initialisation method for mixgauss_init ['kmeans']
%   'cov_prior'    - covariance prior forwarded to mixgauss_Mstep [ [] ]
%   'verbose'      - if nonzero, print the log-likelihood each iteration [0]
%   'prune_thresh' - if > 0, drop components whose prior is below it [0]
%
% Outputs:
%   mu(:, c)       - mean of component c
%   Sigma(:, :, c) - covariance of component c
%   prior(c)       - mixing weight of component c
%
% Kevin Murphy, 13 May 2003

[max_iter, thresh, cov_type, mu, Sigma, method, ...
 cov_prior, verbose, prune_thresh] = process_options(...
    varargin, 'max_iter', 10, 'thresh', 1e-2, 'cov_type', 'full', ...
    'mu', [], 'Sigma', [], 'method', 'kmeans', ...
    'cov_prior', [], 'verbose', 0, 'prune_thresh', 0);

[ny, T] = size(Y);

if nc == 1
  % No latent variable, so there is a closed-form solution.
  % Average over the data dimension explicitly: mean(Y')' collapses to a
  % scalar when there is a single data point (Y' is then a row vector).
  mu = mean(Y, 2);
  if T == 1
    % cov of a row vector would treat the ny features as samples of one
    % variable; a single observation has zero sample covariance.
    Sigma = zeros(ny, ny);
  else
    Sigma = cov(Y');
  end
  if strcmp(cov_type, 'diag')
    Sigma = diag(diag(Sigma));
  end
  prior = 1;
  return;
end

if isempty(mu)
  [mu, Sigma, prior] = mixgauss_init(nc, Y, cov_type, method);
else
  % Caller supplied the initial means: no prior comes back from
  % mixgauss_init, so start from uniform mixing weights (nc-by-1 column,
  % as required by the repmat in the E step).
  prior = ones(nc, 1) / nc;
  if isempty(Sigma)
    % No initial covariances given either; start from identity matrices.
    Sigma = repmat(eye(ny), [1 1 nc]);
  end
end

previous_loglik = -inf;
num_iter = 1;
converged = 0;

%if verbose, fprintf('starting em\n'); end

while (num_iter <= max_iter) && ~converged
  % E step
  probY = mixgauss_prob(Y, mu, Sigma, prior); % probY(q,t)
  [post, lik] = normalize(probY .* repmat(prior, 1, T), 1); % post(q,t)
  % lik(t) is the normaliser of column t, i.e. the likelihood of point t
  % (assuming normalize returns per-column sums for dim 1), so the data
  % log-likelihood is the sum of the logs, not the log of the sum.
  loglik = sum(log(lik));

  % extract expected sufficient statistics
  w = sum(post, 2);        % w(c) = sum_t post(c,t)
  WYY = zeros(ny, ny, nc); % WYY(:,:,c) = sum_t post(c,t) Y(:,t) Y(:,t)'
  WY = zeros(ny, nc);      % WY(:,c) = sum_t post(c,t) Y(:,t)
  WYTY = zeros(nc, 1);     % WYTY(c) = sum_t post(c,t) Y(:,t)' Y(:,t)
  for c = 1:nc
    weights = repmat(post(c,:), ny, 1); % weights(:,t) = post(c,t)
    WYbig = Y .* weights;               % WYbig(:,t) = post(c,t) * Y(:,t)
    WYY(:,:,c) = WYbig * Y';
    WY(:,c) = sum(WYbig, 2);
    % trace(WYbig' * Y) without forming the T-by-T product
    WYTY(c) = sum(sum(WYbig .* Y));
  end

  % M step
  prior = normalize(w);
  [mu, Sigma] = mixgauss_Mstep(w, WY, WYY, WYTY, 'cov_type', cov_type, 'cov_prior', cov_prior);

  if verbose, fprintf(1, 'iteration %d, loglik = %f\n', num_iter, loglik); end
  num_iter = num_iter + 1;
  converged = em_converged(loglik, previous_loglik, thresh);
  previous_loglik = loglik;
end

if prune_thresh > 0
  % Remove components with negligible weight.
  % NOTE(review): the remaining priors are not renormalised here, so they
  % may sum to less than 1 after pruning -- confirm callers expect that.
  ndx = find(prior < prune_thresh);
  mu(:,ndx) = [];
  Sigma(:,:,ndx) = [];
  prior(ndx) = [];
end