wolffd@0
|
function [mu, Sigma] = mixgauss_Mstep(w, Y, YY, YTY, varargin)
% MIXGAUSS_MSTEP Compute MLEs for mixture of Gaussians given expected sufficient statistics
% function [mu, Sigma] = mixgauss_Mstep(w, Y, YY, YTY, varargin)
%
% We assume P(Y|Q=i) = N(Y; mu_i, Sigma_i)
% and w(i,t) = p(Q(t)=i|y(t)) = posterior responsibility
% See www.ai.mit.edu/~murphyk/Papers/learncg.pdf.
%
% INPUTS:
% w(i) = sum_t w(i,t) = responsibilities for each mixture component
%   If there is only one mixture component (i.e., Q does not exist),
%   then w(i) = N = nsamples, and
%   all references to i can be replaced by 1.
%   w may be a row or a column vector.
% YY(:,:,i) = sum_t w(i,t) y(:,t) y(:,t)' = weighted outer product
%   (any array with d*d*Q elements is accepted; it is reshaped to d x d x Q)
% Y(:,i) = sum_t w(i,t) y(:,t) = weighted observations (d x Q)
% YTY(i) = sum_t w(i,t) y(:,t)' y(:,t) = weighted inner product
%   You only need to pass in YTY if Sigma is to be estimated as spherical.
%
% Optional parameters may be passed as 'param_name', param_value pairs.
% Parameter names are shown below; default values in [] - if none, argument is mandatory.
%
% 'cov_type' - 'full', 'diag' or 'spherical' ['full']
% 'tied_cov' - 1 (Sigma) or 0 (Sigma_i) [0]
% 'clamped_cov' - pass in clamped value, or [] if unclamped [ [] ]
% 'clamped_mean' - pass in clamped value, or [] if unclamped [ [] ]
% 'cov_prior' - Lambda_i, added to the estimated Sigma(:,:,i)
%               [repmat(0.01*eye(d), [1 1 Q])]
%
% OUTPUTS:
% mu(:,i)      - mean of component i (d x Q), or clamped_mean if given.
% Sigma(:,:,i) - covariance of component i plus cov_prior(:,:,i) (d x d x Q).
%   If covariance is tied, the single pooled d x d estimate is replicated Q times.
%   Diagonal and spherical covariances are represented in full size.
%   If clamped_cov is given it is returned as-is (cov_prior is NOT added).
%
% All covariance estimates use the weighted scatter about mu,
%   S_i = sum_t w(i,t) (y_t - mu_i)(y_t - mu_i)' = YY_i - Y_i mu_i' - mu_i Y_i' + w(i) mu_i mu_i',
% which reduces to YY_i - w(i) mu_i mu_i' when mu_i = Y_i / w(i), and stays
% correct when the mean is clamped.

% 'other' is not used below; it is kept so the process_options call is unchanged.
[cov_type, tied_cov, clamped_cov, clamped_mean, cov_prior, other] = ...
    process_options(varargin,...
    'cov_type', 'full', 'tied_cov', 0, 'clamped_cov', [], 'clamped_mean', [], ...
    'cov_prior', []);

[Ysz, Q] = size(Y);
w = w(:)';   % force a 1 x Q row so elementwise products below have a fixed shape
N = sum(w);  % total weight (pooled normalizer for tied covariances)
if isempty(cov_prior)
  cov_prior = repmat(0.01*eye(Ysz,Ysz), [1 1 Q]);
end
YY = reshape(YY, [Ysz Ysz Q]);

% Keep the true weights (w0) for multiplication, but divide by a zero-safe copy.
% Setting zero weights to one is valid because w(i)=0 => Y(:,i)=0, YY(:,:,i)=0,
% YTY(i)=0, so the corresponding scatter is exactly zero.
w0 = w;
w = w + (w==0);
Nsafe = N + (N==0);  % same convention for the pooled normalizer (empty mixture)

if ~isempty(clamped_mean)
  mu = clamped_mean;
else
  % eqn 6: mu_i = Y(:,i) / w(i)
  mu = zeros(Ysz, Q);
  for i=1:Q
    mu(:,i) = Y(:,i) / w(i);
  end
end

if ~isempty(clamped_cov)
  Sigma = clamped_cov;
  return;
end

if ~tied_cov
  Sigma = zeros(Ysz,Ysz,Q);
  for i=1:Q
    m = mu(:,i);
    if cov_type(1) == 's'
      % eqn 17: s2_i = trace(S_i) / (d * w(i))
      ss = YTY(i) - 2*(m'*Y(:,i)) + w0(i)*(m'*m);
      Sigma(:,:,i) = (ss / (Ysz*w(i))) * eye(Ysz);
    else
      % eqn 12: Sigma_i = S_i / w(i)
      SS = (YY(:,:,i) - Y(:,i)*m' - m*Y(:,i)' + w0(i)*(m*m')) / w(i);
      if cov_type(1) == 'd'
        SS = diag(diag(SS));
      end
      Sigma(:,:,i) = SS;
    end
  end
else % tied cov: pool the scatter of every component and normalize by N
  if cov_type(1) == 's'
    % eqn 19: s2 = sum_i trace(S_i) / (N * d)
    ss = 0;
    for i=1:Q
      m = mu(:,i);
      ss = ss + YTY(i) - 2*(m'*Y(:,i)) + w0(i)*(m'*m);
    end
    Sigma = (ss / (Nsafe*Ysz)) * eye(Ysz);
  else
    % eqn 15: Sigma = sum_i S_i / N
    SS = zeros(Ysz, Ysz);
    for i=1:Q
      m = mu(:,i);
      SS = SS + YY(:,:,i) - Y(:,i)*m' - m*Y(:,i)' + w0(i)*(m*m');
    end
    SS = SS / Nsafe;
    if cov_type(1) == 'd'
      Sigma = diag(diag(SS));
    else
      Sigma = SS;
    end
  end
  % Represent the shared covariance once per component, like the untied case.
  Sigma = repmat(Sigma, [1 1 Q]);
end

Sigma = Sigma + cov_prior;
|