function [mix, num_iter, ll] = gmmem_kpm(mix, x, varargin)
%GMMEM_KPM Like GMMEM, but with additional optional arguments
% function [mix, num_iter, ll] = gmmem_kpm(mix, x, varargin)
%
% Runs EM for a Gaussian mixture model, starting from the parameters in mix.
%
% Input:
% mix - structure created by gmminit or gmmem_multi_restart
% x - data matrix; each row is an example
%
% Output:
% mix - modified structure
% num_iter - number of iterations run (the iteration at which convergence
%            was detected, or the last iteration if the limit was reached)
% ll - log likelihood of x under the returned mix (on convergence it is
%      computed as sum(log(p + eps)), otherwise as sum(log(gmmprob(mix, x))))
%
% [ ... ] = gmmem_kpm(..., 'param1',val1, 'param2',val2, ...) allows you to
% specify optional parameter name/value pairs.
% Parameters are below [default value in brackets]
%
% 'max_iter' - maximum number of EM iterations; 0 means 100 [10]
% 'll_thresh' - change in log-likelihood threshold for convergence;
%               a value <= 0 disables the convergence test [1e-2]
% 'verbose' - 1 means display output while running [1]
% 'prior_cov' - this will be added to each estimated covariance
%               to prevent singularities [1e-3*eye(d)]
%               (for 'diag' covariances only its diagonal is used;
%               the 'ppca' update does not use it)
% 'fn' - this function, if non-empty, will be called at every iteration
%        (e.g., to display the parameters as they evolve) [ [] ]
%        The fn is called as fn(mix, x, iter_num, fnargs{:}).
%        It is also called before the iteration starts as
%        fn(mix, x, -1, fnargs{:}), which can be used to initialize things.
% 'fnargs' - cell array of additional arguments to be passed to fn [ {} ]
%
% Modified by Kevin P Murphy, 29 Dec 2002

% Check that inputs are consistent
errstring = consist(mix, 'gmm', x);
if ~isempty(errstring)
  error(errstring);
end

[ndata, xdim] = size(x);

[max_iter, ll_thresh, verbose, prior_cov, fn, fnargs] = ...
    process_options(varargin, ...
                    'max_iter', 10, 'll_thresh', 1e-2, 'verbose', 1, ...
                    'prior_cov', 1e-3*eye(xdim), 'fn', [], 'fnargs', {});

% Translate the name/value options into a Netlab-style options vector.
options = foptions;
if verbose, options(1) = 1; else options(1) = -1; end
options(14) = max_iter;
options(3) = ll_thresh;

% Sort out the options
if options(14)
  niters = options(14);
else
  niters = 100;
end

disp_level = options(1);      % > 0 prints the error at every cycle
test = options(3) > 0.0;      % Test log likelihood for termination

check_covars = 0;
if options(5) >= 1
  if disp_level >= 0
    disp('check_covars is on');
  end
  check_covars = 1;           % Ensure that covariances don't collapse
  MIN_COVAR = eps;            % Minimum singular value of covariance matrix
  init_covars = mix.covars;
end

% Prior terms added to the covariance estimates; they do not change across
% iterations, so compute them once.
prior_trace = sum(diag(prior_cov));          % used by 'spherical'
if isequal(size(prior_cov), [xdim xdim])
  prior_diag = diag(prior_cov)';             % used by 'diag': 1 x xdim row
else
  prior_diag = prior_cov;                    % scalar (or row) added as given
end

if ~isempty(fn)
  feval(fn, mix, x, -1, fnargs{:});
end

% Explicit counter: after an empty for-loop MATLAB leaves the loop variable
% empty, so n cannot be relied on when niters < 1.
num_iter = 0;
eold = Inf;

% Main loop of algorithm
for n = 1:niters
  num_iter = n;

  % E step: calculate posteriors based on old parameters
  [post, act] = gmmpost(mix, x);

  % Calculate error value if needed
  if disp_level || test
    prob = act*(mix.priors)';
    % Error value is negative log likelihood of data
    e = - sum(log(prob + eps));
    if disp_level > 0
      fprintf(1, 'Cycle %4d Error %11.6f\n', n, e);
    end
    if test
      if n > 1 && abs(e - eold) < options(3)
        % Converged: mix has not been updated yet this iteration, so ll
        % refers to the parameters being returned.
        ll = -e;
        return; %%%%%%%%%%%%%%%% Exit here if converged
      end
      eold = e;
    end
  end

  if ~isempty(fn)
    feval(fn, mix, x, n, fnargs{:});
  end

  % M step: re-estimate mixing weights and centres
  new_pr = sum(post, 1);
  new_c = post' * x;
  mix.priors = new_pr ./ ndata;
  mix.centres = new_c ./ (new_pr' * ones(1, mix.nin));

  % M step: re-estimate covariances
  switch mix.covar_type
    case 'spherical'
      n2 = dist2(x, mix.centres);
      % Responsibility-weighted squared distances, one entry per centre
      v = sum(post .* n2, 1);
      mix.covars = ((v./new_pr) + prior_trace)./mix.nin;
      if check_covars
        % Ensure that no covariance is too small
        for j = 1:mix.ncentres
          if mix.covars(j) < MIN_COVAR
            mix.covars(j) = init_covars(j);
          end
        end
      end
    case 'diag'
      for j = 1:mix.ncentres
        diffs = x - (ones(ndata, 1) * mix.centres(j,:));
        wts = (post(:,j)*ones(1, mix.nin));
        % The prior diagonal is added once per centre, as in the
        % 'spherical' and 'full' cases.
        mix.covars(j,:) = (sum((diffs.*diffs).*wts, 1) + prior_diag)./new_pr(j);
      end
      if check_covars
        % Ensure that no covariance is too small
        for j = 1:mix.ncentres
          if min(mix.covars(j,:)) < MIN_COVAR
            mix.covars(j,:) = init_covars(j,:);
          end
        end
      end
    case 'full'
      for j = 1:mix.ncentres
        diffs = x - (ones(ndata, 1) * mix.centres(j,:));
        diffs = diffs.*(sqrt(post(:,j))*ones(1, mix.nin));
        mix.covars(:,:,j) = (diffs'*diffs + prior_cov)/new_pr(j);
      end
      if check_covars
        % Ensure that no covariance is too small
        for j = 1:mix.ncentres
          if min(svd(mix.covars(:,:,j))) < MIN_COVAR
            mix.covars(:,:,j) = init_covars(:,:,j);
          end
        end
      end
    case 'ppca'
      for j = 1:mix.ncentres
        diffs = x - (ones(ndata, 1) * mix.centres(j,:));
        diffs = diffs.*(sqrt(post(:,j))*ones(1, mix.nin));
        [mix.covars(j), mix.U(:,:,j), mix.lambda(j,:)] = ...
            ppca((diffs'*diffs)/new_pr(j), mix.ppca_dim);
      end
      if check_covars
        % Ensure that no covariance is too small (check every centre,
        % not just the last one visited by the loop above)
        for j = 1:mix.ncentres
          if mix.covars(j) < MIN_COVAR
            mix.covars(j) = init_covars(j);
          end
        end
      end
    otherwise
      error(['Unknown covariance type ', mix.covar_type]);
  end
end

% Iteration limit reached without convergence: report the likelihood of the
% final parameters.
ll = sum(log(gmmprob(mix, x)));

%if (disp_level >= 0)
%  disp('Warning: Maximum number of iterations has been exceeded');
%end