Mercurial > hg > camir-aes2014
comparison toolboxes/FullBNT-1.0.7/netlabKPM/gmmem2.m @ 0:e9a9cd732c1e tip
first hg version after svn
| author | wolffd |
|---|---|
| date | Tue, 10 Feb 2015 15:05:51 +0000 |
| parents | |
| children |
comparison
equal
deleted
inserted
replaced
| -1:000000000000 | 0:e9a9cd732c1e |
|---|---|
function [mix, num_iter, ll] = gmmem_kpm(mix, x, varargin)
%GMMEM_KPM Like GMMEM, but with additional optional arguments
% function [mix, num_iter, ll] = gmmem_kpm(mix, x, varargin)
%
% NOTE(review): this code lives in gmmem2.m; MATLAB resolves a function file
% by its file name, so callers reach it as gmmem2 regardless of the name above.
%
% Input:
% mix - structure created by gmminit or gmmem_multi_restart
% x   - data matrix; each row is an example
%
% Output:
% mix      - modified structure
% num_iter - loop index at exit (the iteration at which convergence was
%            detected, or the maximum number of iterations)
% ll       - final log likelihood
%
% [ ... ] = gmmem_kpm(..., 'param1',val1, 'param2',val2, ...) allows you to
% specify optional parameter name/value pairs.
% Parameters are below [default value in brackets]
%
% 'max_iter'  - maximum number of EM iterations; 0 means 100 [10]
% 'll_thresh' - change in log-likelihood threshold for convergence;
%               a value <= 0 disables the convergence test [1e-2]
% 'verbose'   - 1 means display output while running [1]
% 'prior_cov' - this will be added to each estimated covariance
%               to prevent singularities [1e-3*eye(d)]
%               'full' adds the d x d matrix; 'diag' adds its diagonal
%               (a scalar or length-d vector is also accepted there);
%               'spherical' adds its trace; 'ppca' ignores it.
% 'fn'        - this function, if non-empty, will be called at every iteration
%               (e.g., to display the parameters as they evolve) [ [] ]
%               The fn is called as fn(mix, x, iter_num, fnargs{:}).
%               It is also called before the iteration starts as
%               fn(mix, x, -1, fnargs{:}), which can be used to initialize things.
% 'fnargs'    - cell array of additional arguments to be passed to fn [ {} ]
%
% Modified by Kevin P Murphy, 29 Dec 2002


% Check that inputs are consistent
errstring = consist(mix, 'gmm', x);
if ~isempty(errstring)
  error(errstring);
end

[ndata, xdim] = size(x);

[max_iter, ll_thresh, verbose, prior_cov, fn, fnargs] = ...
    process_options(varargin, ...
    'max_iter', 10, 'll_thresh', 1e-2, 'verbose', 1, ...
    'prior_cov', 1e-3*eye(xdim), 'fn', [], 'fnargs', {});

% Pack the settings into a Netlab-style options vector, as GMMEM does.
options = foptions;
if verbose, options(1) = 1; else options(1) = -1; end
options(14) = max_iter;
options(3) = ll_thresh;


% Sort out the options
if options(14)
  niters = options(14);
else
  niters = 100;
end

display = options(1);
test = 0;
if options(3) > 0.0
  test = 1; % Test log likelihood for termination
end

% NOTE(review): options(5) is never set in this function, so this branch is
% only active if foptions itself returns options(5) >= 1 — confirm intent.
check_covars = 0;
if options(5) >= 1
  if display >= 0
    disp('check_covars is on');
  end
  check_covars = 1;   % Ensure that covariances don't collapse
  MIN_COVAR = eps;    % Minimum singular value of covariance matrix
  init_covars = mix.covars;
end

if ~isempty(fn)
  feval(fn, mix, x, -1, fnargs{:});
end

% Main loop of algorithm
for n = 1:niters

  % Calculate posteriors based on old parameters
  [post, act] = gmmpost(mix, x);

  % Calculate error value if needed
  if display || test
    prob = act*(mix.priors)';
    % Error value is negative log likelihood of data
    e = - sum(log(prob + eps));
    if display > 0
      fprintf(1, 'Cycle %4d Error %11.6f\n', n, e);
    end
    if test
      % eold only exists from the second pass on; && short-circuits at n == 1.
      if n > 1 && abs(e - eold) < options(3)
        % mix has not been updated in this pass, so ll describes it.
        ll = -e;
        num_iter = n;
        return; %%%%%%%%%%%%%%%% Exit here if converged
      else
        eold = e;
      end
    end
  end

  if ~isempty(fn)
    feval(fn, mix, x, n, fnargs{:});
  end

  % Adjust the new estimates for the parameters
  new_pr = sum(post, 1);
  new_c = post' * x;

  % Now move new estimates to old parameter vectors
  mix.priors = new_pr ./ ndata;

  mix.centres = new_c ./ (new_pr' * ones(1, mix.nin));

  switch mix.covar_type
    case 'spherical'
      n2 = dist2(x, mix.centres);
      v = zeros(1, mix.ncentres);
      for j = 1:mix.ncentres
        v(j) = (post(:,j)'*n2(:,j));
      end
      % NOTE(review): unlike 'diag'/'full', the prior term here is not
      % divided by new_pr; kept unchanged to preserve existing results.
      mix.covars = ((v./new_pr) + sum(diag(prior_cov)))./mix.nin;
      if check_covars
        % Ensure that no covariance is too small
        for j = 1:mix.ncentres
          if mix.covars(j) < MIN_COVAR
            mix.covars(j) = init_covars(j);
          end
        end
      end
    case 'diag'
      % Per-dimension prior as a 1 x d row matching mix.covars(j,:).
      % A d x d matrix contributes its diagonal; a scalar or vector is used
      % directly.
      if isvector(prior_cov)
        prior_diag = prior_cov(:)';
      else
        prior_diag = diag(prior_cov)';
      end
      for j = 1:mix.ncentres
        diffs = x - (ones(ndata, 1) * mix.centres(j,:));
        wts = (post(:,j)*ones(1, mix.nin));
        % Add the prior once (not once per data point), which equals the
        % diagonal of the 'full' update below.
        mix.covars(j,:) = (sum((diffs.*diffs).*wts, 1) + prior_diag)./new_pr(j);
      end
      if check_covars
        % Ensure that no covariance is too small
        for j = 1:mix.ncentres
          if min(mix.covars(j,:)) < MIN_COVAR
            mix.covars(j,:) = init_covars(j,:);
          end
        end
      end
    case 'full'
      for j = 1:mix.ncentres
        diffs = x - (ones(ndata, 1) * mix.centres(j,:));
        diffs = diffs.*(sqrt(post(:,j))*ones(1, mix.nin));
        mix.covars(:,:,j) = (diffs'*diffs + prior_cov)/new_pr(j);
      end
      if check_covars
        % Ensure that no covariance is too small
        for j = 1:mix.ncentres
          if min(svd(mix.covars(:,:,j))) < MIN_COVAR
            mix.covars(:,:,j) = init_covars(:,:,j);
          end
        end
      end
    case 'ppca'
      for j = 1:mix.ncentres
        diffs = x - (ones(ndata, 1) * mix.centres(j,:));
        diffs = diffs.*(sqrt(post(:,j))*ones(1, mix.nin));
        [mix.covars(j), mix.U(:,:,j), mix.lambda(j,:)] = ...
            ppca((diffs'*diffs)/new_pr(j), mix.ppca_dim);
      end
      if check_covars
        % Ensure that no covariance is too small -- check every component,
        % not just the last value of the loop index.
        for j = 1:mix.ncentres
          if mix.covars(j) < MIN_COVAR
            mix.covars(j) = init_covars(j);
          end
        end
      end
    otherwise
      error(['Unknown covariance type ', mix.covar_type]);
  end
end

% Not converged within niters: report the likelihood of the final model.
ll = sum(log(gmmprob(mix, x)));
num_iter = n;

%if (display >= 0)
%  disp('Warning: Maximum number of iterations has been exceeded');
%end
