Mercurial > hg > camir-aes2014
diff toolboxes/FullBNT-1.0.7/netlabKPM/gmmem2.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/toolboxes/FullBNT-1.0.7/netlabKPM/gmmem2.m Tue Feb 10 15:05:51 2015 +0000 @@ -0,0 +1,186 @@ +function [mix, num_iter, ll] = gmmem_kpm(mix, x, varargin) +%GMMEM_KPM Like GMMEM, but with additional optional arguments +% function [mix, num_iter, ll] = gmmem_kpm(mix, x, varargin) +% +% Input: +% mix - structure created by gmminit or gmmem_multi_restart +% data - each row is an example +% +% Output: +% mix - modified structure +% num_iter - number of iterations needed to reach convergence +% ll - final log likelihood +% +% [ ... ] = gmmem_kpm(..., 'param1',val1, 'param2',val2, ...) allows you to +% specify optional parameter name/value pairs. +% Parameters are below [default value in brackets] +% +% 'max_iter' - maximum number of EM iterations [10] +% 'll_thresh' - change in log-likelihood threshold for convergence [1e-2] +% 'verbose' - 1 means display output while running [0] +% 'prior_cov' - this will be added to each estimated covariance +% to prevent singularities [1e-3*eye(d)] +% 'fn' - this function, if non-empty, will be called at every iteration +% (e.g., to display the parameters as they evolve) [ [] ] +% The fn is called as fn(mix, x, iter_num, fnargs). +% It is also called before the iteration starts as +% fn(mix, x, -1, fnargs), which can be used to initialize things. +% 'fnargs' - additional arguments to be passed to fn [ {} ] +% +% Modified by Kevin P Murphy, 29 Dec 2002 + + +% Check that inputs are consistent +errstring = consist(mix, 'gmm', x); +if ~isempty(errstring) + error(errstring); +end + +[ndata, xdim] = size(x); + +[max_iter, ll_thresh, verbose, prior_cov, fn, fnargs] = ... + process_options(varargin, ... + 'max_iter', 10, 'll_thresh', 1e-2, 'verbose', 1, ... + 'prior_cov', 1e-3*eye(xdim), 'fn', [], 'fnargs', {}); + +options = foptions; +if verbose, options(1)=1; else options(1)=-1; end +options(14) = max_iter; +options(3) = ll_thresh; + + +% Sort out the options +if (options(14)) + niters = options(14); +else + niters = 100; +end + +display = options(1); +test = 0; +if options(3) > 0.0 + test = 1; % Test log likelihood for termination +end + +check_covars = 0; +if options(5) >= 1 + if display >= 0 + disp('check_covars is on'); + end + check_covars = 1; % Ensure that covariances don't collapse + MIN_COVAR = eps; % Minimum singular value of covariance matrix + init_covars = mix.covars; +end + +mix0 = mix; % save init values for debugging + +if ~isempty(fn) + feval(fn, mix, x, -1, fnargs{:}); +end + +% Main loop of algorithm +for n = 1:niters + + % Calculate posteriors based on old parameters + [post, act] = gmmpost(mix, x); + + % Calculate error value if needed + if (display | test) + prob = act*(mix.priors)'; + % Error value is negative log likelihood of data + e = - sum(log(prob + eps)); + if display > 0 + fprintf(1, 'Cycle %4d Error %11.6f\n', n, e); + end + if test + if (n > 1 & abs(e - eold) < options(3)) + options(8) = e; + ll = -e; + num_iter = n; + return; %%%%%%%%%%%%%%%% Exit here if converged + else + eold = e; + end + end + end + + if ~isempty(fn) + feval(fn, mix, x, n, fnargs{:}); + end + + % Adjust the new estimates for the parameters + new_pr = sum(post, 1); + new_c = post' * x; + + % Now move new estimates to old parameter vectors + mix.priors = new_pr ./ ndata; + + mix.centres = new_c ./ (new_pr' * ones(1, mix.nin)); + + switch mix.covar_type + case 'spherical' + n2 = dist2(x, mix.centres); + for j = 1:mix.ncentres + v(j) = (post(:,j)'*n2(:,j)); + end + mix.covars = ((v./new_pr) + sum(diag(prior_cov)))./mix.nin; + if check_covars + % Ensure that no covariance is too small + for j = 1:mix.ncentres + if mix.covars(j) < MIN_COVAR + mix.covars(j) = init_covars(j); + end + end + end + case 'diag' + for j = 1:mix.ncentres + diffs = x - (ones(ndata, 1) * mix.centres(j,:)); + wts = (post(:,j)*ones(1, mix.nin)); + mix.covars(j,:) = sum((diffs.*diffs).*wts + prior_cov, 1)./new_pr(j); + end + if check_covars + % Ensure that no covariance is too small + for j = 1:mix.ncentres + if min(mix.covars(j,:)) < MIN_COVAR + mix.covars(j,:) = init_covars(j,:); + end + end + end + case 'full' + for j = 1:mix.ncentres + diffs = x - (ones(ndata, 1) * mix.centres(j,:)); + diffs = diffs.*(sqrt(post(:,j))*ones(1, mix.nin)); + mix.covars(:,:,j) = (diffs'*diffs + prior_cov)/new_pr(j); + end + if check_covars + % Ensure that no covariance is too small + for j = 1:mix.ncentres + if min(svd(mix.covars(:,:,j))) < MIN_COVAR + mix.covars(:,:,j) = init_covars(:,:,j); + end + end + end + case 'ppca' + for j = 1:mix.ncentres + diffs = x - (ones(ndata, 1) * mix.centres(j,:)); + diffs = diffs.*(sqrt(post(:,j))*ones(1, mix.nin)); + [mix.covars(j), mix.U(:,:,j), mix.lambda(j,:)] = ... + ppca((diffs'*diffs)/new_pr(j), mix.ppca_dim); + end + if check_covars + if mix.covars(j) < MIN_COVAR + mix.covars(j) = init_covars(j); + end + end + otherwise + error(['Unknown covariance type ', mix.covar_type]); + end +end + +ll = sum(log(gmmprob(mix, x))); +num_iter = n; + +%if (display >= 0) +% disp('Warning: Maximum number of iterations has been exceeded'); +%end +