annotate toolboxes/FullBNT-1.0.7/netlabKPM/gmmem2.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
rev   line source
function [mix, num_iter, ll] = gmmem_kpm(mix, x, varargin)
%GMMEM_KPM Like GMMEM, but with additional optional arguments
% function [mix, num_iter, ll] = gmmem_kpm(mix, x, varargin)
%
% Runs EM on a Gaussian mixture model: each iteration computes the
% posteriors with gmmpost and then re-estimates priors, centres and
% covariances from them.
%
% Input:
% mix - structure created by gmminit or gmmem_multi_restart
% x - data matrix; each row is an example
%
% Output:
% mix - modified structure
% num_iter - number of iterations performed (the iteration index at which
%            convergence was detected, or the iteration limit otherwise)
% ll - final log likelihood
%
% [ ... ] = gmmem_kpm(..., 'param1',val1, 'param2',val2, ...) allows you to
% specify optional parameter name/value pairs.
% Parameters are below [default value in brackets]
%
% 'max_iter' - maximum number of EM iterations [10]
% 'll_thresh' - change in log-likelihood threshold for convergence [1e-2]
% 'verbose' - 1 means display output while running [1]
% 'prior_cov' - this will be added to each estimated covariance
%               to prevent singularities [1e-3*eye(d)]
% 'fn' - this function, if non-empty, will be called at every iteration
%        (e.g., to display the parameters as they evolve) [ [] ]
%        The fn is called as fn(mix, x, iter_num, fnargs).
%        It is also called before the iteration starts as
%        fn(mix, x, -1, fnargs), which can be used to initialize things.
% 'fnargs' - additional arguments to be passed to fn [ {} ]
%
% Modified by Kevin P Murphy, 29 Dec 2002
%
% NOTE(review): the revision listing names this file gmmem2.m while the
% function is declared as gmmem_kpm; MATLAB dispatches on the file name,
% so confirm which name callers actually use.


% Check that inputs are consistent
errstring = consist(mix, 'gmm', x);
if ~isempty(errstring)
    error(errstring);
end

[ndata, xdim] = size(x);

[max_iter, ll_thresh, verbose, prior_cov, fn, fnargs] = ...
    process_options(varargin, ...
    'max_iter', 10, 'll_thresh', 1e-2, 'verbose', 1, ...
    'prior_cov', 1e-3*eye(xdim), 'fn', [], 'fnargs', {});

% Translate the name/value options into a Netlab-style options vector.
options = foptions;
if verbose, options(1) = 1; else options(1) = -1; end
options(14) = max_iter;
options(3) = ll_thresh;


% Sort out the options
if (options(14))
    niters = options(14);
else
    niters = 100;   % max_iter == 0 falls back to 100 iterations
end

display = options(1);
test = 0;
if options(3) > 0.0
    test = 1;   % Test log likelihood for termination
end

check_covars = 0;
if options(5) >= 1
    if display >= 0
        disp('check_covars is on');
    end
    check_covars = 1;    % Ensure that covariances don't collapse
    MIN_COVAR = eps;     % Minimum singular value of covariance matrix
    init_covars = mix.covars;
end

mix0 = mix; % save init values for debugging

if ~isempty(fn)
    feval(fn, mix, x, -1, fnargs{:});
end

% Main loop of algorithm
for n = 1:niters

    % Calculate posteriors based on old parameters
    [post, act] = gmmpost(mix, x);

    % Calculate error value if needed
    if (display || test)
        prob = act*(mix.priors)';
        % Error value is negative log likelihood of data
        e = - sum(log(prob + eps));
        if display > 0
            fprintf(1, 'Cycle %4d Error %11.6f\n', n, e);
        end
        if test
            % eold is only defined from the second pass on; && guarantees it
            % is not read when n == 1.
            if (n > 1 && abs(e - eold) < options(3))
                ll = -e;
                num_iter = n;
                return; %%%%%%%%%%%%%%%% Exit here if converged
            else
                eold = e;
            end
        end
    end

    if ~isempty(fn)
        feval(fn, mix, x, n, fnargs{:});
    end

    % Adjust the new estimates for the parameters
    new_pr = sum(post, 1);
    new_c = post' * x;

    % Now move new estimates to old parameter vectors
    mix.priors = new_pr ./ ndata;

    mix.centres = new_c ./ (new_pr' * ones(1, mix.nin));

    switch mix.covar_type
        case 'spherical'
            n2 = dist2(x, mix.centres);
            v = zeros(1, mix.ncentres);
            for j = 1:mix.ncentres
                v(j) = (post(:,j)'*n2(:,j));
            end
            mix.covars = ((v./new_pr) + sum(diag(prior_cov)))./mix.nin;
            if check_covars
                % Ensure that no covariance is too small
                for j = 1:mix.ncentres
                    if mix.covars(j) < MIN_COVAR
                        mix.covars(j) = init_covars(j);
                    end
                end
            end
        case 'diag'
            % Only the diagonal of the prior is relevant for diagonal
            % covariances. It is added to the weighted sum of squares (one
            % value per dimension) before normalising, mirroring the 'full'
            % case. Adding the whole prior_cov matrix to the ndata-by-nin
            % matrix inside the sum is a size mismatch.
            prior_diag = diag(prior_cov)';
            for j = 1:mix.ncentres
                diffs = x - (ones(ndata, 1) * mix.centres(j,:));
                wts = (post(:,j)*ones(1, mix.nin));
                mix.covars(j,:) = ...
                    (sum((diffs.*diffs).*wts, 1) + prior_diag)./new_pr(j);
            end
            if check_covars
                % Ensure that no covariance is too small
                for j = 1:mix.ncentres
                    if min(mix.covars(j,:)) < MIN_COVAR
                        mix.covars(j,:) = init_covars(j,:);
                    end
                end
            end
        case 'full'
            for j = 1:mix.ncentres
                diffs = x - (ones(ndata, 1) * mix.centres(j,:));
                % Scale rows by sqrt(posterior) so diffs'*diffs is the
                % posterior-weighted scatter matrix.
                diffs = diffs.*(sqrt(post(:,j))*ones(1, mix.nin));
                mix.covars(:,:,j) = (diffs'*diffs + prior_cov)/new_pr(j);
            end
            if check_covars
                % Ensure that no covariance is too small
                for j = 1:mix.ncentres
                    if min(svd(mix.covars(:,:,j))) < MIN_COVAR
                        mix.covars(:,:,j) = init_covars(:,:,j);
                    end
                end
            end
        case 'ppca'
            for j = 1:mix.ncentres
                diffs = x - (ones(ndata, 1) * mix.centres(j,:));
                diffs = diffs.*(sqrt(post(:,j))*ones(1, mix.nin));
                [mix.covars(j), mix.U(:,:,j), mix.lambda(j,:)] = ...
                    ppca((diffs'*diffs)/new_pr(j), mix.ppca_dim);
            end
            if check_covars
                % Ensure that no covariance is too small. Every component
                % is checked, not just the last one visited by the loop
                % above.
                for j = 1:mix.ncentres
                    if mix.covars(j) < MIN_COVAR
                        mix.covars(j) = init_covars(j);
                    end
                end
            end
        otherwise
            error(['Unknown covariance type ', mix.covar_type]);
    end
end

% Iteration limit reached without triggering the convergence test.
ll = sum(log(gmmprob(mix, x)));
num_iter = n;

%if (display >= 0)
% disp('Warning: Maximum number of iterations has been exceeded');
%end
