function [mix, num_iter, ll] = gmmem_kpm(mix, x, varargin)
%GMMEM_KPM Like GMMEM, but with additional optional arguments
% function [mix, num_iter, ll] = gmmem_kpm(mix, x, varargin)
%
% Runs EM on a Gaussian mixture model, re-estimating priors, centres and
% covariances from the posteriors returned by gmmpost at each iteration.
%
% Input:
% mix - structure created by gmminit or gmmem_multi_restart
% x   - data matrix; each row is an example
%
% Output:
% mix - modified structure
% num_iter - iteration index at which the convergence test fired, or the
%            maximum number of iterations if it never fired
% ll - final log likelihood. On early (converged) exit this is
%      sum(log(prob + eps)) computed from the parameters in use at the start
%      of that iteration; otherwise it is sum(log(gmmprob(mix, x))).
%
% [ ... ] = gmmem_kpm(..., 'param1',val1, 'param2',val2, ...) allows you to
% specify optional parameter name/value pairs.
% Parameters are below [default value in brackets]
%
% 'max_iter' - maximum number of EM iterations [10]
%              (a value of 0 falls back to 100 iterations)
% 'll_thresh' - change in log-likelihood threshold for convergence [1e-2]
%               (values <= 0 disable the convergence test)
% 'verbose' - 1 means display output while running [1]
% 'prior_cov' - this will be added to each estimated covariance
%               to prevent singularities [1e-3*eye(d)]
% 'fn' - this function, if non-empty, will be called at every iteration
%        (e.g., to display the parameters as they evolve) [ [] ]
%        The fn is called as fn(mix, x, iter_num, fnargs{:}).
%        It is also called before the iteration starts as
%        fn(mix, x, -1, fnargs{:}), which can be used to initialize things.
% 'fnargs' - additional arguments to be passed to fn [ {} ]
%
% Modified by Kevin P Murphy, 29 Dec 2002

% Check that inputs are consistent
errstring = consist(mix, 'gmm', x);
if ~isempty(errstring)
  error(errstring);
end

[ndata, xdim] = size(x);

[max_iter, ll_thresh, verbose, prior_cov, fn, fnargs] = ...
    process_options(varargin, ...
        'max_iter', 10, 'll_thresh', 1e-2, 'verbose', 1, ...
        'prior_cov', 1e-3*eye(xdim), 'fn', [], 'fnargs', {});

% Pack the settings into a Netlab-style options vector; the remaining
% entries (notably options(5), the check_covars switch) keep whatever
% defaults foptions supplies.
options = foptions;
if verbose
  options(1) = 1;
else
  options(1) = -1;
end
options(14) = max_iter;
options(3) = ll_thresh;

% Sort out the options
if options(14)
  niters = options(14);
else
  niters = 100;
end

% Named display_level rather than "display" to avoid shadowing the builtin.
display_level = options(1);
test_convergence = (options(3) > 0.0);  % Test log likelihood for termination

check_covars = 0;
if options(5) >= 1
  if display_level >= 0
    disp('check_covars is on');
  end
  check_covars = 1;         % Ensure that covariances don't collapse
  MIN_COVAR = eps;          % Minimum singular value of covariance matrix
  init_covars = mix.covars; % Fallback values for collapsed covariances
end

if ~isempty(fn)
  feval(fn, mix, x, -1, fnargs{:});
end

% Main loop of algorithm
for n = 1:niters

  % Calculate posteriors based on old parameters
  [post, act] = gmmpost(mix, x);

  % Calculate error value if needed
  if (display_level || test_convergence)
    prob = act*(mix.priors)';
    % Error value is negative log likelihood of data
    e = - sum(log(prob + eps));
    if display_level > 0
      fprintf(1, 'Cycle %4d  Error %11.6f\n', n, e);
    end
    if test_convergence
      if (n > 1 && abs(e - eold) < options(3))
        % Converged: report the log likelihood just computed and stop
        % before re-estimating the parameters again.
        ll = -e;
        num_iter = n;
        return;
      else
        eold = e;
      end
    end
  end

  if ~isempty(fn)
    feval(fn, mix, x, n, fnargs{:});
  end

  % Adjust the new estimates for the parameters
  new_pr = sum(post, 1);
  new_c = post' * x;

  % Now move new estimates to old parameter vectors
  mix.priors = new_pr ./ ndata;

  mix.centres = new_c ./ (new_pr' * ones(1, mix.nin));

  switch mix.covar_type
    case 'spherical'
      n2 = dist2(x, mix.centres);
      v = zeros(1, mix.ncentres);
      for j = 1:mix.ncentres
        v(j) = (post(:,j)'*n2(:,j));
      end
      mix.covars = ((v./new_pr) + sum(diag(prior_cov)))./mix.nin;
      if check_covars
        % Ensure that no covariance is too small
        for j = 1:mix.ncentres
          if mix.covars(j) < MIN_COVAR
            mix.covars(j) = init_covars(j);
          end
        end
      end
    case 'diag'
      % Only the diagonal of the prior applies to diagonal covariances.
      % It is added once to the weighted sum of squares (mirroring the
      % 'full' case) instead of being added to the ndata-by-nin matrix
      % inside the sum, which is a size mismatch for the default
      % d-by-d prior_cov.
      prior_diag = diag(prior_cov)';
      for j = 1:mix.ncentres
        diffs = x - (ones(ndata, 1) * mix.centres(j,:));
        wts = (post(:,j)*ones(1, mix.nin));
        mix.covars(j,:) = (sum((diffs.*diffs).*wts, 1) + prior_diag)./new_pr(j);
      end
      if check_covars
        % Ensure that no covariance is too small
        for j = 1:mix.ncentres
          if min(mix.covars(j,:)) < MIN_COVAR
            mix.covars(j,:) = init_covars(j,:);
          end
        end
      end
    case 'full'
      for j = 1:mix.ncentres
        diffs = x - (ones(ndata, 1) * mix.centres(j,:));
        % Scale rows by sqrt(posterior) so diffs'*diffs is the weighted scatter
        diffs = diffs.*(sqrt(post(:,j))*ones(1, mix.nin));
        mix.covars(:,:,j) = (diffs'*diffs + prior_cov)/new_pr(j);
      end
      if check_covars
        % Ensure that no covariance is too small
        for j = 1:mix.ncentres
          if min(svd(mix.covars(:,:,j))) < MIN_COVAR
            mix.covars(:,:,j) = init_covars(:,:,j);
          end
        end
      end
    case 'ppca'
      for j = 1:mix.ncentres
        diffs = x - (ones(ndata, 1) * mix.centres(j,:));
        diffs = diffs.*(sqrt(post(:,j))*ones(1, mix.nin));
        [mix.covars(j), mix.U(:,:,j), mix.lambda(j,:)] = ...
            ppca((diffs'*diffs)/new_pr(j), mix.ppca_dim);
      end
      if check_covars
        % Ensure that no covariance is too small. Check every centre,
        % not just the last one left over in the loop index.
        for j = 1:mix.ncentres
          if mix.covars(j) < MIN_COVAR
            mix.covars(j) = init_covars(j);
          end
        end
      end
    otherwise
      error(['Unknown covariance type ', mix.covar_type]);
  end
end

% Maximum number of iterations reached without triggering the convergence test
ll = sum(log(gmmprob(mix, x)));
num_iter = n;