comparison toolboxes/FullBNT-1.0.7/KPMstats/mixgauss_em.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:e9a9cd732c1e
function [mu, Sigma, prior] = mixgauss_em(Y, nc, varargin)
% MIXGAUSS_EM Fit the parameters of a mixture of Gaussians using EM
% function [mu, Sigma, prior] = mixgauss_em(data, nc, varargin)
%
% INPUTS:
% data(:, t) is the t'th data point; any trailing dimensions
%   (e.g. data(:, t, seq)) are flattened into the t index
% nc is the number of clusters (mixture components)
%
% Optional arguments (name/value pairs) [default]:
% 'max_iter'     - maximum number of EM iterations [10]
% 'thresh'       - convergence threshold passed to em_converged [1e-2]
% 'cov_type'     - covariance structure, passed to mixgauss_init and
%                  mixgauss_Mstep ['full']; 'diag' and 'spherical' are also
%                  honoured by the closed-form nc==1 shortcut
% 'mu'           - initial means mu(:,c) [] (empty => use mixgauss_init)
% 'Sigma'        - initial covariances Sigma(:,:,c), only used when 'mu' is
%                  given [] (empty => diagonal of the data covariance)
% 'prior'        - initial mixing weights prior(c), only used when 'mu' is
%                  given [] (empty => uniform)
% 'method'       - initialisation method passed to mixgauss_init ['kmeans']
% 'cov_prior'    - passed to mixgauss_Mstep []; not used when nc == 1
% 'verbose'      - if nonzero, print the log-likelihood every iteration [0]
% 'prune_thresh' - after EM, drop components whose prior is below this
%                  value and renormalise the rest [0 = no pruning]
%
% OUTPUTS:
% mu(:,c), Sigma(:,:,c), prior(c) for each remaining component c

% Kevin Murphy, 13 May 2003

[max_iter, thresh, cov_type, mu, Sigma, method, ...
 cov_prior, verbose, prune_thresh, prior] = process_options(...
    varargin, 'max_iter', 10, 'thresh', 1e-2, 'cov_type', 'full', ...
    'mu', [], 'Sigma', [], 'method', 'kmeans', ...
    'cov_prior', [], 'verbose', 0, 'prune_thresh', 0, 'prior', []);

[ny, T] = size(Y);
Y = reshape(Y, ny, T); % flatten any trailing dimensions into the t index

if nc == 1
  % No latent variable, so there is a closed-form solution.
  % mean(Y, 2) (rather than mean(Y')') stays a column vector even when T == 1.
  mu = mean(Y, 2);
  Sigma = cov(Y');
  switch cov_type
   case 'diag'
    Sigma = diag(diag(Sigma));
   case 'spherical'
    % isotropic covariance with the same average per-dimension variance
    Sigma = mean(diag(Sigma)) * eye(ny);
  end
  prior = 1;
  return;
end

if isempty(mu)
  [mu, Sigma, prior] = mixgauss_init(nc, Y, cov_type, method);
else
  % The caller supplied the means; fill in whatever else is missing so the
  % first E step does not fail on an undefined prior or an empty Sigma.
  if isempty(Sigma)
    Sigma = repmat(diag(diag(cov(Y'))), [1 1 nc]);
  end
  if isempty(prior)
    prior = ones(nc, 1) / nc;
  end
end
prior = prior(:); % the E step below replicates prior as an nc x 1 column

previous_loglik = -inf;
num_iter = 1;
converged = 0;

while (num_iter <= max_iter) && ~converged
  % E step
  probY = mixgauss_prob(Y, mu, Sigma, prior); % probY(c,t), weighted by prior below
  % post(c,t) is the column-normalised joint; lik holds the normalising
  % constant(s) returned by normalize (presumably lik(t) = p(Y(:,t))).
  [post, lik] = normalize(probY .* repmat(prior, 1, T), 1);
  % Data log-likelihood = sum over points of the log marginal likelihood
  % (not the log of the summed likelihoods).
  loglik = sum(log(lik));

  % extract expected sufficient statistics
  w = sum(post, 2);            % w(c)    = sum_t post(c,t)
  WY = Y * post';              % WY(:,c) = sum_t post(c,t) Y(:,t)
  WYTY = post * sum(Y.^2, 1)'; % WYTY(c) = sum_t post(c,t) Y(:,t)' Y(:,t)
  WYY = zeros(ny, ny, nc);     % WYY(:,:,c) = sum_t post(c,t) Y(:,t) Y(:,t)'
  for c = 1:nc
    WYbig = Y .* repmat(post(c,:), ny, 1); % WYbig(:,t) = post(c,t) * Y(:,t)
    WYY(:,:,c) = WYbig * Y';
  end

  % M step
  prior = normalize(w);
  [mu, Sigma] = mixgauss_Mstep(w, WY, WYY, WYTY, 'cov_type', cov_type, 'cov_prior', cov_prior);

  if verbose, fprintf(1, 'iteration %d, loglik = %f\n', num_iter, loglik); end
  num_iter = num_iter + 1;
  converged = em_converged(loglik, previous_loglik, thresh);
  previous_loglik = loglik;
end

if prune_thresh > 0
  ndx = find(prior < prune_thresh);
  mu(:,ndx) = [];
  Sigma(:,:,ndx) = [];
  prior(ndx) = [];
  % the surviving mixing weights must sum to one again
  prior = normalize(prior);
end