function [mix, num_iter, ll] = gmmem_kpm(mix, x, varargin)
%GMMEM_KPM Like GMMEM, but with additional optional arguments
% function [mix, num_iter, ll] = gmmem_kpm(mix, x, varargin)
%
% Input:
% mix - structure created by gmminit or gmmem_multi_restart
% x   - data matrix; each row is an example
%
% Output:
% mix - modified structure
% num_iter - number of iterations needed to reach convergence
% ll - final log likelihood
%
% [ ... ] = gmmem_kpm(..., 'param1',val1, 'param2',val2, ...) allows you to
% specify optional parameter name/value pairs.
% Parameters are below [default value in brackets]
%
% 'max_iter' - maximum number of EM iterations [10]
% 'll_thresh' - change in log-likelihood threshold for convergence [1e-2]
% 'verbose' - 1 means display output while running [1]
% 'prior_cov' - this will be added to each estimated covariance
%               to prevent singularities  [1e-3*eye(d)]
% 'fn'        - this function, if non-empty, will be called at every iteration
%               (e.g., to display the parameters as they evolve) [ [] ]
%               The fn is called as fn(mix, x, iter_num, fnargs).
%               It is also called before the iteration starts as
%               fn(mix, x, -1, fnargs), which can be used to initialize things.
% 'fnargs'    - additional arguments to be passed to fn [ {} ]
%
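% Example (illustrative sketch; assumes NETLAB's gmm and gmminit are available
% to build and initialise the mixture):
%   mix = gmm(size(data, 2), 3, 'full');   % 3-component, full-covariance GMM
%   mix = gmminit(mix, data, foptions);    % k-means initialisation of centres
%   [mix, niter, ll] = gmmem_kpm(mix, data, 'max_iter', 50, 'll_thresh', 1e-3);
%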
% Modified by Kevin P Murphy, 29 Dec 2002


% Check that inputs are consistent
errstring = consist(mix, 'gmm', x);
if ~isempty(errstring)
  error(errstring);
end

[ndata, xdim] = size(x);

[max_iter, ll_thresh, verbose, prior_cov, fn, fnargs] = ...
    process_options(varargin, ...
	'max_iter', 10, 'll_thresh', 1e-2, 'verbose', 1, ...
	'prior_cov', 1e-3*eye(xdim), 'fn', [], 'fnargs', {});
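% process_options (from KPMtools in FullBNT) parses the name/value pairs and
% substitutes the bracketed defaults for any option that is not supplied.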

options = foptions;
if verbose, options(1)=1; else options(1)=-1; end
options(14) = max_iter;
options(3) = ll_thresh;
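% (NETLAB foptions convention, as used here: options(1) is the display flag,
% options(3) the log-likelihood tolerance, options(14) the iteration limit.
% options(5), which would enable the covariance reset below, is left at its
% foptions default of 0.)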


% Sort out the options
if (options(14))
  niters = options(14);
else
  niters = 100;
end

display = options(1);
test = 0;
if options(3) > 0.0
  test = 1;	% Test log likelihood for termination
end

check_covars = 0;
if options(5) >= 1
  if display >= 0
    disp('check_covars is on');
  end
  check_covars = 1;	% Ensure that covariances don't collapse
  MIN_COVAR = eps;	% Minimum singular value of covariance matrix
  init_covars = mix.covars;
end

mix0 = mix; % save init values for debugging

if ~isempty(fn)
  feval(fn, mix, x, -1, fnargs{:});
end

% Main loop of algorithm
for n = 1:niters
  
  % Calculate posteriors based on old parameters
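  % (post(i,j) is the responsibility p(j | x_i); act(i,j) is the component
  % density p(x_i | j), reused below to evaluate the log likelihood.)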
  [post, act] = gmmpost(mix, x);
  
  % Calculate error value if needed
  if (display || test)
    prob = act*(mix.priors)';
    % Error value is negative log likelihood of data
    e = - sum(log(prob + eps));
    if display > 0
      fprintf(1, 'Cycle %4d  Error %11.6f\n', n, e);
    end
    if test
      if (n > 1 && abs(e - eold) < options(3))
        options(8) = e;
        ll = -e;
        num_iter = n;
        return; %%%%%%%%%%%%%%%% Exit here if converged
      else
        eold = e;
      end
    end
  end

  if ~isempty(fn)
    feval(fn, mix, x, n, fnargs{:});
  end

  % Adjust the new estimates for the parameters
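  % new_pr(j) is the effective number of points assigned to component j
  % (the summed responsibilities); new_c(j,:) is the responsibility-weighted
  % sum of the data, used for the mean update.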
  new_pr = sum(post, 1);
  new_c = post' * x;
  
  % Now move new estimates to old parameter vectors
  mix.priors = new_pr ./ ndata;
  
  mix.centres = new_c ./ (new_pr' * ones(1, mix.nin));
  
  switch mix.covar_type
  case 'spherical'
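    % dist2 gives the squared Euclidean distance from every data point to
    % every centre, so v(j) is the responsibility-weighted sum of squared
    % distances for component j.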
    n2 = dist2(x, mix.centres);
    for j = 1:mix.ncentres
      v(j) = (post(:,j)'*n2(:,j));
    end
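    % Spherical variance: weighted mean squared distance plus the trace of
    % the prior covariance, averaged over the d = mix.nin dimensions.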
    mix.covars = ((v./new_pr) + sum(diag(prior_cov)))./mix.nin;
    if check_covars
      % Ensure that no covariance is too small
      for j = 1:mix.ncentres
        if mix.covars(j) < MIN_COVAR
          mix.covars(j) = init_covars(j);
        end
      end
    end
  case 'diag'
    for j = 1:mix.ncentres
      diffs = x - (ones(ndata, 1) * mix.centres(j,:));
      wts = (post(:,j)*ones(1, mix.nin));
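      % Per-dimension weighted scatter plus the diagonal of the prior
      % covariance, normalised by the effective count (mirrors the 'full' case)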
      mix.covars(j,:) = (sum((diffs.*diffs).*wts, 1) + diag(prior_cov)') ./ new_pr(j);
    end
    if check_covars
      % Ensure that no covariance is too small
      for j = 1:mix.ncentres
        if min(mix.covars(j,:)) < MIN_COVAR
          mix.covars(j,:) = init_covars(j,:);
        end
      end
    end
  case 'full'
    for j = 1:mix.ncentres
      diffs = x - (ones(ndata, 1) * mix.centres(j,:));
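      % Scaling the rows by sqrt(post(:,j)) makes diffs'*diffs the
      % responsibility-weighted scatter matrix for component j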
      diffs = diffs.*(sqrt(post(:,j))*ones(1, mix.nin));
      mix.covars(:,:,j) = (diffs'*diffs + prior_cov)/new_pr(j);
    end
    if check_covars
      % Ensure that no covariance is too small
      for j = 1:mix.ncentres
        if min(svd(mix.covars(:,:,j))) < MIN_COVAR
          mix.covars(:,:,j) = init_covars(:,:,j);
        end
      end
    end
  case 'ppca'
    for j = 1:mix.ncentres
      diffs = x - (ones(ndata, 1) * mix.centres(j,:));
      diffs = diffs.*(sqrt(post(:,j))*ones(1, mix.nin));
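      % As in the 'full' case, the sqrt weighting turns diffs'*diffs into the
      % weighted scatter; NETLAB's ppca returns the off-subspace variance,
      % the principal subspace basis U and the retained eigenvalues lambda.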
      [mix.covars(j), mix.U(:,:,j), mix.lambda(j,:)] = ...
        ppca((diffs'*diffs)/new_pr(j), mix.ppca_dim);
    end
    if check_covars
      % Ensure that no covariance is too small
      for j = 1:mix.ncentres
        if mix.covars(j) < MIN_COVAR
          mix.covars(j) = init_covars(j);
        end
      end
    end
  otherwise
    error(['Unknown covariance type ', mix.covar_type]);
  end
end

ll = sum(log(gmmprob(mix, x)));
num_iter = n;

%if (display >= 0)
%  disp('Warning: Maximum number of iterations has been exceeded');
%end