function [mu, Sigma, prior] = mixgauss_em(Y, nc, varargin)
% MIXGAUSS_EM Fit the parameters of a mixture of Gaussians using EM
% function [mu, Sigma, prior] = mixgauss_em(Y, nc, varargin)
%
% Y(:,t) is the t'th data point (Y is ny x T)
% nc is the number of clusters
%
% Optional name/value arguments [default]:
% 'max_iter'     - maximum number of EM iterations [10]
% 'thresh'       - convergence threshold passed to em_converged [1e-2]
% 'cov_type'     - covariance type passed to mixgauss_init/mixgauss_Mstep ['full']
% 'mu'           - initial means, mu(:,c) [] (if empty, mixgauss_init is used
%                  and any supplied 'Sigma'/'prior' are ignored)
% 'Sigma'        - initial covariances, Sigma(:,:,c) [] (used only with 'mu';
%                  if empty, the pooled data covariance is used for every c)
% 'prior'        - initial mixing weights, prior(c) [] (used only with 'mu';
%                  if empty, a uniform prior is used)
% 'method'       - initialisation method passed to mixgauss_init ['kmeans']
% 'cov_prior'    - passed through to mixgauss_Mstep []
% 'verbose'      - if nonzero, print the log-likelihood every iteration [0]
% 'prune_thresh' - after EM, drop components with prior < prune_thresh and
%                  renormalise the remaining weights [0 = no pruning]
%
% Outputs: mu(:,c), Sigma(:,:,c), prior(c) for the (possibly pruned) components.

% Kevin Murphy, 13 May 2003

[max_iter, thresh, cov_type, mu, Sigma, method, ...
 cov_prior, verbose, prune_thresh, prior] = process_options(...
    varargin, 'max_iter', 10, 'thresh', 1e-2, 'cov_type', 'full', ...
    'mu', [], 'Sigma', [], 'method', 'kmeans', ...
    'cov_prior', [], 'verbose', 0, 'prune_thresh', 0, 'prior', []);

[ny T] = size(Y);

if nc == 1
  % No latent variable, so there is a closed-form solution.
  % Note: cov normalises by T-1, not the ML value T.
  mu = mean(Y, 2);
  Sigma = cov(Y');
  if strcmp(cov_type, 'diag')
    Sigma = diag(diag(Sigma));
  end
  prior = 1;
  return;
end

if isempty(mu)
  [mu, Sigma, prior] = mixgauss_init(nc, Y, cov_type, method);
else
  % Caller supplied the means; fill in whatever else is missing so the
  % first E step has a defined Sigma and prior.
  if isempty(Sigma)
    S = cov(Y');
    if strcmp(cov_type, 'diag')
      S = diag(diag(S));
    end
    Sigma = repmat(S, [1 1 nc]);
  end
  if isempty(prior)
    prior = ones(nc, 1) / nc;
  end
  prior = prior(:); % column vector, required by repmat(prior, 1, T) below
end

% Ysq(t) = Y(:,t)' * Y(:,t); independent of the EM parameters, so computed once.
Ysq = sum(Y .* Y, 1);

previous_loglik = -inf;
num_iter = 1;
converged = 0;

while num_iter <= max_iter && ~converged
  % E step
  % probY(q,t) should be the unweighted class-conditional density p(Y(:,t) | q);
  % the prior is applied exactly once, in the next line.
  % NOTE(review): assumes mixgauss_prob defaults its optional mixing-matrix
  % argument to ones when omitted. Passing prior there (as before) weighted
  % each component by its prior twice. Verify against mixgauss_prob.
  probY = mixgauss_prob(Y, mu, Sigma); % probY(q,t)
  % Assumes normalize returns the column sums as its second output, so
  % lik(t) = sum_q prior(q) * probY(q,t).
  [post, lik] = normalize(probY .* repmat(prior, 1, T), 1); % post(q,t)
  % Data log-likelihood is the sum over points of the log mixture density
  % (previously log(sum(lik)), which is not the log-likelihood).
  loglik = sum(log(lik));

  % extract expected sufficient statistics
  w = sum(post, 2);          % w(c) = sum_t post(c,t)
  WYY = zeros(ny, ny, nc);   % WYY(:,:,c) = sum_t post(c,t) Y(:,t) Y(:,t)'
  WY = zeros(ny, nc);        % WY(:,c) = sum_t post(c,t) Y(:,t)
  WYTY = zeros(nc, 1);       % WYTY(c) = sum_t post(c,t) Y(:,t)' Y(:,t)
  for c = 1:nc
    weights = repmat(post(c,:), ny, 1); % weights(:,t) = post(c,t)
    WYbig = Y .* weights;               % WYbig(:,t) = post(c,t) * Y(:,t)
    WYY(:,:,c) = WYbig * Y';
    WY(:,c) = sum(WYbig, 2);
    % Same value as sum(diag(WYbig' * Y)) without building a T x T matrix.
    WYTY(c) = post(c,:) * Ysq';
  end

  % M step
  prior = normalize(w);
  [mu, Sigma] = mixgauss_Mstep(w, WY, WYY, WYTY, 'cov_type', cov_type, 'cov_prior', cov_prior);

  if verbose, fprintf(1, 'iteration %d, loglik = %f\n', num_iter, loglik); end
  num_iter = num_iter + 1;
  converged = em_converged(loglik, previous_loglik, thresh);
  previous_loglik = loglik;
end

if prune_thresh > 0
  ndx = find(prior < prune_thresh);
  mu(:,ndx) = [];
  Sigma(:,:,ndx) = [];
  prior(ndx) = [];
  % Renormalise so the surviving mixing weights still sum to one
  % (every survivor has prior >= prune_thresh > 0, so the sum is positive).
  if ~isempty(prior)
    prior = prior / sum(prior);
  end
end