function [mu, Sigma, prior] = mixgauss_em(Y, nc, varargin)
% MIXGAUSS_EM Fit the parameters of a mixture of Gaussians using EM
% function [mu, Sigma, prior] = mixgauss_em(data, nc, varargin)
%
% Inputs:
%   data(:, t) is the t'th data point (data is ny-by-T)
%   nc is the number of clusters
%
% Optional name/value arguments [default]:
%   'max_iter'     - maximum number of EM iterations [10]
%   'thresh'       - convergence threshold handed to em_converged [1e-2]
%   'cov_type'     - covariance type handed to mixgauss_init / mixgauss_Mstep;
%                    'diag' is also honoured in the nc==1 case ['full']
%   'mu'           - initial means, mu(:,c); if empty, mixgauss_init is used []
%   'Sigma'        - initial covariances, Sigma(:,:,c) []
%   'method'       - initialisation method handed to mixgauss_init ['kmeans']
%   'cov_prior'    - handed unchanged to mixgauss_Mstep []
%   'verbose'      - if nonzero, print the log-likelihood at each iteration [0]
%   'prune_thresh' - if > 0, components whose prior is below this value are
%                    removed after EM finishes [0]
%
% Outputs:
%   mu(:,c)      - mean of component c
%   Sigma(:,:,c) - covariance of component c
%   prior(c)     - mixing weight of component c

% Kevin Murphy, 13 May 2003

[max_iter, thresh, cov_type, mu, Sigma, method, ...
 cov_prior, verbose, prune_thresh] = process_options(...
    varargin, 'max_iter', 10, 'thresh', 1e-2, 'cov_type', 'full', ...
    'mu', [], 'Sigma', [], 'method', 'kmeans', ...
    'cov_prior', [], 'verbose', 0, 'prune_thresh', 0);

[ny, T] = size(Y);

if nc==1
  % No latent variable, so there is a closed-form solution
  % Average along dimension 2 explicitly: mean(Y')' collapses to a scalar
  % when T==1, because Y' is then a row vector.
  mu = mean(Y, 2);
  if T == 1
    % cov would treat the single 1-by-ny observation as a vector of ny
    % samples; the covariance of one observation is zero.
    Sigma = zeros(ny);
  else
    Sigma = cov(Y');
  end
  if strcmp(cov_type, 'diag')
    Sigma = diag(diag(Sigma));
  end
  prior = 1;
  return;
end

if isempty(mu)
  [mu, Sigma, prior] = mixgauss_init(nc, Y, cov_type, method);
else
  % The caller supplied starting means. The mixing weights cannot be passed
  % in, so start from uniform weights (previously prior was left undefined
  % on this path and the first E step failed).
  prior = ones(nc, 1) / nc;
  if isempty(Sigma)
    % No starting covariances either: give every component the pooled
    % data covariance.
    Sigma0 = cov(Y');
    if strcmp(cov_type, 'diag')
      Sigma0 = diag(diag(Sigma0));
    end
    Sigma = repmat(Sigma0, [1 1 nc]);
  end
end

previous_loglik = -inf;
num_iter = 1;
converged = 0;

%if verbose, fprintf('starting em\n'); end

while (num_iter <= max_iter) && ~converged
  % E step
  % NOTE(review): prior is passed to mixgauss_prob and also multiplied in on
  % the next line - confirm mixgauss_prob does not already weight by it.
  probY = mixgauss_prob(Y, mu, Sigma, prior); % probY(q,t)
  [post, lik] = normalize(probY .* repmat(prior, 1, T), 1); % post(q,t)
  % Data log-likelihood for independent points is the sum of the per-point
  % log-likelihoods (assumes lik(t) is the normaliser of column t).
  % The floor at realmin keeps one underflowed point from giving -Inf.
  loglik = sum(log(max(lik, realmin)));

  % extract expected sufficient statistics
  w = sum(post,2);  % w(c) = sum_t post(c,t)
  WYY = zeros(ny, ny, nc);  % WYY(:,:,c) = sum_t post(c,t) Y(:,t) Y(:,t)'
  WY = zeros(ny, nc);  % WY(:,c) = sum_t post(c,t) Y(:,t)
  WYTY = zeros(nc,1); % WYTY(c) = sum_t post(c,t) Y(:,t)' Y(:,t)
  for c=1:nc
    weights = repmat(post(c,:), ny, 1); % weights(:,t) = post(c,t)
    WYbig = Y .* weights; % WYbig(:,t) = post(c,t) * Y(:,t)
    WYY(:,:,c) = WYbig * Y';
    WY(:,c) = sum(WYbig, 2);
    % trace(WYbig' * Y) without forming the T-by-T product
    WYTY(c) = sum(WYbig(:) .* Y(:));
  end

  % M step
  prior = normalize(w);
  [mu, Sigma] = mixgauss_Mstep(w, WY, WYY, WYTY, 'cov_type', cov_type, 'cov_prior', cov_prior);

  if verbose, fprintf(1, 'iteration %d, loglik = %f\n', num_iter, loglik); end
  num_iter = num_iter + 1;
  converged = em_converged(loglik, previous_loglik, thresh);
  previous_loglik = loglik;

end

if prune_thresh > 0
  % Drop components with negligible mixing weight.
  % NOTE(review): the remaining priors are not renormalised here, so they
  % may sum to less than 1 after pruning.
  ndx = find(prior < prune_thresh);
  mu(:,ndx) = [];
  Sigma(:,:,ndx) = [];
  prior(ndx) = [];
end