function [mix, options, errlog] = gmmem(mix, x, options)
%GMMEM	EM algorithm for Gaussian mixture model.
%
%	Description
%	[MIX, OPTIONS, ERRLOG] = GMMEM(MIX, X, OPTIONS) uses the Expectation
%	Maximization algorithm of Dempster et al. to estimate the parameters
%	of a Gaussian mixture model defined by a data structure MIX. The
%	matrix X represents the data whose likelihood is maximized, with
%	each row corresponding to a vector.  The optional parameters have
%	the following interpretations.
%
%	OPTIONS(1) is set to 1 to display the error value on every cycle.
%	If OPTIONS(1) is set to 0, then only warning messages are displayed.
%	If OPTIONS(1) is -1, then nothing is displayed.
%
%	OPTIONS(3) is a measure of the absolute precision required of the
%	error function at the solution.  If the change in log likelihood
%	between two steps of the EM algorithm is less than this value, then
%	the function terminates.
%
%	OPTIONS(5) is set to 1 if a covariance matrix is reset to its
%	original value when it becomes too small (less than MIN_COVAR, which
%	has the value eps).  For 'spherical' and 'ppca' models the variance
%	of each centre is tested; for 'diag' models the smallest diagonal
%	entry; for 'full' models the smallest singular value.  With the
%	default value of 0 no action is taken.
%
%	OPTIONS(14) is the maximum number of iterations; default 100.
%
%	The optional return value OPTIONS contains the final error value
%	(i.e. the negative data log likelihood) in OPTIONS(8).
%
%	If a third output ERRLOG is requested, it holds the error value
%	computed at the start of each cycle.  When the OPTIONS(3) test stops
%	the algorithm early, the entries for cycles that were not run are
%	left at zero.
%
%	See also
%	GMM, GMMINIT
%

%	Copyright (c) Ian T Nabney (1996-2001)

% Check that inputs are consistent
errstring = consist(mix, 'gmm', x);
if ~isempty(errstring)
  error(errstring);
end

ndata = size(x, 1);

% Sort out the options
if options(14)
  niters = options(14);
else
  niters = 100;
end

% Named 'verbosity' rather than 'display' to avoid shadowing the builtin.
verbosity = options(1);
store = (nargout > 2);
if store
  % Store the error values to return them; unrun cycles stay zero.
  errlog = zeros(1, niters);
end
% Test log likelihood for termination only when a tolerance is given.
test = (options(3) > 0.0);

check_covars = (options(5) >= 1);
if check_covars
  if verbosity >= 0
    disp('check_covars is on');
  end
  % Ensure that covariances don't collapse
  MIN_COVAR = eps;	% Minimum singular value of covariance matrix
  init_covars = mix.covars;
end

% Main loop of algorithm
for n = 1:niters

  % E-step: calculate posteriors based on old parameters
  [post, act] = gmmpost(mix, x);

  % The error value is only consumed for logging, printing or the
  % convergence test, so skip computing it otherwise.
  if verbosity > 0 || store || test
    prob = act*(mix.priors)';
    % Error value is negative log likelihood of data
    e = - sum(log(prob));
    if store
      errlog(n) = e;
    end
    if verbosity > 0
      fprintf(1, 'Cycle %4d Error %11.6f\n', n, e);
    end
    if test
      % eold is only read once it has been set on a previous cycle
      if n > 1 && abs(e - eold) < options(3)
        % MIX has not yet been updated this cycle, so E is its error.
        options(8) = e;
        return;
      else
        eold = e;
      end
    end
  end

  % M-step: adjust the new estimates for the parameters
  new_pr = sum(post, 1);
  new_c = post' * x;

  % Now move new estimates to old parameter vectors
  mix.priors = new_pr ./ ndata;

  mix.centres = new_c ./ (new_pr' * ones(1, mix.nin));

  switch mix.covar_type
  case 'spherical'
    n2 = dist2(x, mix.centres);
    % Posterior-weighted squared distance to each centre
    v = sum(post .* n2, 1);
    mix.covars = ((v./new_pr))./mix.nin;
    if check_covars
      % Ensure that no covariance is too small
      for j = 1:mix.ncentres
        if mix.covars(j) < MIN_COVAR
          mix.covars(j) = init_covars(j);
        end
      end
    end
  case 'diag'
    for j = 1:mix.ncentres
      diffs = x - (ones(ndata, 1) * mix.centres(j,:));
      mix.covars(j,:) = sum((diffs.*diffs).*(post(:,j)*ones(1, ...
        mix.nin)), 1)./new_pr(j);
    end
    if check_covars
      % Ensure that no covariance is too small
      for j = 1:mix.ncentres
        if min(mix.covars(j,:)) < MIN_COVAR
          mix.covars(j,:) = init_covars(j,:);
        end
      end
    end
  case 'full'
    for j = 1:mix.ncentres
      diffs = x - (ones(ndata, 1) * mix.centres(j,:));
      diffs = diffs.*(sqrt(post(:,j))*ones(1, mix.nin));
      mix.covars(:,:,j) = (diffs'*diffs)/new_pr(j);
    end
    if check_covars
      % Ensure that no covariance is too small
      for j = 1:mix.ncentres
        if min(svd(mix.covars(:,:,j))) < MIN_COVAR
          mix.covars(:,:,j) = init_covars(:,:,j);
        end
      end
    end
  case 'ppca'
    for j = 1:mix.ncentres
      diffs = x - (ones(ndata, 1) * mix.centres(j,:));
      diffs = diffs.*(sqrt(post(:,j))*ones(1, mix.nin));
      [tempcovars, tempU, templambda] = ...
        ppca((diffs'*diffs)/new_pr(j), mix.ppca_dim);
      if length(templambda) ~= mix.ppca_dim
        error('Unable to extract enough components');
      else
        mix.covars(j) = tempcovars;
        mix.U(:, :, j) = tempU;
        mix.lambda(j, :) = templambda;
      end
    end
    if check_covars
      % Ensure that no covariance is too small.  This must loop over
      % every centre; previously only the last centre was checked.
      for j = 1:mix.ncentres
        if mix.covars(j) < MIN_COVAR
          mix.covars(j) = init_covars(j);
        end
      end
    end
  otherwise
    error(['Unknown covariance type ', mix.covar_type]);
  end
end

% Reached only when the iteration limit is hit without convergence.
options(8) = -sum(log(gmmprob(mix, x)));
if verbosity >= 0
  disp(maxitmess);
end