function [mix, num_iter, ll] = gmmem_kpm(mix, x, varargin)
%GMMEM_KPM Like GMMEM, but with additional optional arguments
% function [mix, num_iter, ll] = gmmem_kpm(mix, x, varargin)
%
% Runs EM on a Gaussian mixture model, starting from the parameters
% already stored in MIX.
%
% Input:
% mix - structure created by gmminit or gmmem_multi_restart
% x   - data matrix; each row is an example
%
% Output:
% mix      - modified structure
% num_iter - index of the EM cycle at which the loop stopped (either because
%            the change in log likelihood fell below ll_thresh, or because
%            the maximum number of iterations was reached)
% ll       - final log likelihood
%
% [ ... ] = gmmem_kpm(..., 'param1',val1, 'param2',val2, ...) allows you to
% specify optional parameter name/value pairs.
% Parameters are below [default value in brackets]
%
% 'max_iter'  - maximum number of EM iterations; 0 means 100 [10]
% 'll_thresh' - change in log-likelihood threshold for convergence;
%               a value <= 0 disables the convergence test [1e-2]
% 'verbose'   - 1 means display output while running [1]
% 'prior_cov' - this will be added to each estimated covariance
%               to prevent singularities [1e-3*eye(d)]
%               (it is not used by the 'ppca' covariance type)
% 'fn'        - this function, if non-empty, will be called at every iteration
%               (e.g., to display the parameters as they evolve) [ [] ]
%               The fn is called as fn(mix, x, iter_num, fnargs{:}).
%               It is also called before the iteration starts as
%               fn(mix, x, -1, fnargs{:}), which can be used to initialize things.
% 'fnargs'    - cell array of additional arguments to be passed to fn [ {} ]
%
% Modified by Kevin P Murphy, 29 Dec 2002

% Check that inputs are consistent
errstring = consist(mix, 'gmm', x);
if ~isempty(errstring)
  error(errstring);
end

[ndata, xdim] = size(x);

[max_iter, ll_thresh, verbose, prior_cov, fn, fnargs] = ...
    process_options(varargin, ...
                    'max_iter', 10, 'll_thresh', 1e-2, 'verbose', 1, ...
                    'prior_cov', 1e-3*eye(xdim), 'fn', [], 'fnargs', {});

% Netlab-style options vector; only entries 1, 3, 5 and 14 are read below.
options = foptions;
if verbose, options(1) = 1; else options(1) = -1; end
options(14) = max_iter;
options(3) = ll_thresh;

% Sort out the options
if (options(14))
  niters = options(14);
else
  niters = 100;
end

% Named disp_level rather than display to avoid shadowing the builtin.
disp_level = options(1);
test = 0;
if options(3) > 0.0
  test = 1;   % Test log likelihood for termination
end

check_covars = 0;
if options(5) >= 1
  if disp_level >= 0
    disp('check_covars is on');
  end
  check_covars = 1;       % Ensure that covariances don't collapse
  MIN_COVAR = eps;        % Minimum singular value of covariance matrix
  init_covars = mix.covars;
end

% Diagonal of prior_cov as a 1 x xdim row, for the 'diag' update.
% It is added once per centre (as in the 'full' case), not once per
% data point.
if numel(prior_cov) == 1
  prior_diag = prior_cov * ones(1, xdim);
else
  prior_diag = diag(prior_cov)';
end

if ~isempty(fn)
  feval(fn, mix, x, -1, fnargs{:});
end

% Main loop of algorithm
for n = 1:niters

  % Calculate posteriors based on old parameters
  [post, act] = gmmpost(mix, x);

  % Calculate error value if needed
  if disp_level || test
    prob = act*(mix.priors)';
    % Error value is negative log likelihood of data
    e = - sum(log(prob + eps));
    if disp_level > 0
      fprintf(1, 'Cycle %4d Error %11.6f\n', n, e);
    end
    if test
      % eold only exists from the second cycle on, hence the short-circuit.
      if n > 1 && abs(e - eold) < options(3)
        ll = -e;
        num_iter = n;
        return;   %%%%%%%%%%%%%%%% Exit here if converged
      else
        eold = e;
      end
    end
  end

  if ~isempty(fn)
    feval(fn, mix, x, n, fnargs{:});
  end

  % Adjust the new estimates for the parameters
  new_pr = sum(post, 1);
  new_c = post' * x;

  % Now move new estimates to old parameter vectors
  mix.priors = new_pr ./ ndata;

  mix.centres = new_c ./ (new_pr' * ones(1, mix.nin));

  switch mix.covar_type
    case 'spherical'
      n2 = dist2(x, mix.centres);
      v = zeros(1, mix.ncentres);
      for j = 1:mix.ncentres
        v(j) = (post(:,j)'*n2(:,j));
      end
      mix.covars = ((v./new_pr) + sum(diag(prior_cov)))./mix.nin;
      if check_covars
        % Ensure that no covariance is too small
        for j = 1:mix.ncentres
          if mix.covars(j) < MIN_COVAR
            mix.covars(j) = init_covars(j);
          end
        end
      end

    case 'diag'
      for j = 1:mix.ncentres
        diffs = x - (ones(ndata, 1) * mix.centres(j,:));
        wts = (post(:,j)*ones(1, mix.nin));
        % Weighted scatter summed over data points, plus the prior
        % diagonal once (previously prior_cov was added to every row,
        % which fails for a matrix prior_cov when ndata ~= xdim).
        mix.covars(j,:) = (sum((diffs.*diffs).*wts, 1) + prior_diag)./new_pr(j);
      end
      if check_covars
        % Ensure that no covariance is too small
        for j = 1:mix.ncentres
          if min(mix.covars(j,:)) < MIN_COVAR
            mix.covars(j,:) = init_covars(j,:);
          end
        end
      end

    case 'full'
      for j = 1:mix.ncentres
        diffs = x - (ones(ndata, 1) * mix.centres(j,:));
        diffs = diffs.*(sqrt(post(:,j))*ones(1, mix.nin));
        mix.covars(:,:,j) = (diffs'*diffs + prior_cov)/new_pr(j);
      end
      if check_covars
        % Ensure that no covariance is too small
        for j = 1:mix.ncentres
          if min(svd(mix.covars(:,:,j))) < MIN_COVAR
            mix.covars(:,:,j) = init_covars(:,:,j);
          end
        end
      end

    case 'ppca'
      for j = 1:mix.ncentres
        diffs = x - (ones(ndata, 1) * mix.centres(j,:));
        diffs = diffs.*(sqrt(post(:,j))*ones(1, mix.nin));
        [mix.covars(j), mix.U(:,:,j), mix.lambda(j,:)] = ...
            ppca((diffs'*diffs)/new_pr(j), mix.ppca_dim);
      end
      if check_covars
        % Ensure that no covariance is too small -- check every centre,
        % not just the last one left over from the loop above.
        for j = 1:mix.ncentres
          if mix.covars(j) < MIN_COVAR
            mix.covars(j) = init_covars(j);
          end
        end
      end

    otherwise
      error(['Unknown covariance type ', mix.covar_type]);
  end
end

ll = sum(log(gmmprob(mix, x)));
num_iter = n;

%if (disp_level >= 0)
%  disp('Warning: Maximum number of iterations has been exceeded');
%end