% Verify that cwr_em gives the same results as the 1D example in Gershenfeld's cwm.m, at
% http://www.media.mit.edu/physics/publications/books/nmm/files/cwm.m
% Build a 1D step-function data set and an initial nclusters-component model,
% then run exactly one EM iteration of cwr_em from that initialization.  The
% hand-written E/M step that follows recomputes the same update for comparison.

seed = 0;
rand('state', seed);   % seed the legacy rand/randn generators so mux is reproducible
randn('state', seed);

x = (-10:10)';         % 21 inputs on an integer grid (column vector)
y = double(x > 0);     % step-function targets: 0 for x <= 0, 1 for x > 0
npts = length(x);
plot(x, y, '+')

nclusters = 4;
nplot = 100;
xplot = 24*(1:nplot)'/nplot - 12;   % evaluation grid on (-12, 12]; not used in the checks below

% Initial cluster parameters, one column per cluster.
mux = 20*rand(1,nclusters) - 10;    % input means, uniform on (-10, 10)
muy = zeros(1,nclusters);           % output means
varx = ones(1,nclusters);           % input variances
vary = ones(1,nclusters);           % output variances
pc = 1/nclusters * ones(1,nclusters);   % uniform cluster priors

% Initial covariances for cwr_em as 1 x 1 x nclusters arrays (1D data).  They
% are built from varx/vary rather than a fixed identity so that the model handed
% to cwr_em always matches the one used by the hand-written E step below.
SigmaX0 = reshape(varx, [1 1 nclusters]);
SigmaY0 = reshape(vary, [1 1 nclusters]);
weightsY0 = zeros(1, 1, nclusters);     % initial Y-on-X regression weights (all zero)

% cwr_em takes data as row vectors (one column per data point).
X = x(:)';
Y = y(:)';

% Do 1 iteration of EM.  weightsY starts at zero and 'clamp_weights' is set,
% presumably so the Y means do not depend on x; this matches the reference
% E step below, which models y with muy alone.  TODO confirm against cwr_em.

% Alternative call that additionally passes explicit zero covariance priors:
%cwr = cwr_em(X, Y, nclusters, 'muX', mux, 'muY', muy, 'SigmaX', SigmaX0, 'cov_typeX', 'spherical', 'SigmaY', SigmaY0, 'cov_typeY', 'spherical', 'priorC', pc, 'weightsY', weightsY0, 'init_params', 0, 'clamp_weights', 1, 'max_iter', 1, 'cov_priorX', zeros(1,1,nclusters), 'cov_priorY', zeros(1,1,nclusters));

cwr = cwr_em(X, Y, nclusters, ...
    'muX', mux, 'muY', muy, ...
    'SigmaX', SigmaX0, 'cov_typeX', 'spherical', ...
    'SigmaY', SigmaY0, 'cov_typeY', 'spherical', ...
    'priorC', pc, 'weightsY', weightsY0, ...
    'create_init_params', 0, 'clamp_weights', 1, 'max_iter', 1);
% Check this matches Gershenfeld's code: recompute one EM step by hand, in the
% same vectorized (kron-based) style as cwm.m.
% Matrices below are npts x nclusters: row t is a data point, column c a cluster.

% E step
% px(t,c) = prob(x(t) | c): Gaussian with mean mux(c), variance varx(c)
px = exp(-(kron(x,ones(1,nclusters)) ...
    - kron(ones(npts,1),mux)).^2 ...
    ./ (2*kron(ones(npts,1),varx))) ...
    ./ sqrt(2*pi*kron(ones(npts,1),varx));
% py(t,c) = prob(y(t) | c): Gaussian with mean muy(c), variance vary(c)
py = exp(-(kron(y,ones(1,nclusters)) ...
    - kron(ones(npts,1),muy)).^2 ...
    ./ (2*kron(ones(npts,1),vary))) ...
    ./ sqrt(2*pi*kron(ones(npts,1),vary));
% p(t,c) = prob(x(t), y(t), c); pp(t,c) = p(t,c) / sum_c p(t,c)  (responsibilities)
p = px .* py .* kron(ones(npts,1),pc);
pp = p ./ kron(sum(p,2),ones(1,nclusters));

% M step
% Constant added to every variance estimate.  Named var_floor instead of eps so
% this script does not shadow MATLAB's built-in eps in the workspace it runs in.
% NOTE(review): presumably matches cwr_em's default covariance prior -- confirm.
var_floor = 0.01;
pc2 = sum(pp)/npts;            % updated cluster priors (1 x nclusters)

% denom(c) = N*pc2(c) = w(c) = sum_t pp(t,c)
% since pc2(c) = sum_t pp(t,c) / N
denom = (npts*pc2);

mux2 = sum(kron(x,ones(1,nclusters)) .* pp) ./ denom;
varx2 = var_floor + sum((kron(x,ones(1,nclusters)) ...
    - kron(ones(npts,1),mux2)).^2 .* pp) ./ denom;
muy2 = sum(kron(y,ones(1,nclusters)) .* pp) ./ denom;
vary2 = var_floor + sum((kron(y,ones(1,nclusters)) ...
    - kron(ones(npts,1),muy2)).^2 .* pp) ./ denom;
% Compare the hand-computed one-step estimates with the model returned by cwr_em.
% The covariance fields are squeezed and conjugate-transposed into rows so they
% line up with varx2 / vary2 (assumes 1 x 1 x nclusters arrays -- TODO confirm).

% Input (X) side: cluster means, then variances.
cwr_mux = cwr.muX;
assert(approxeq(mux2, cwr_mux));
cwr_SigmaX = ctranspose(squeeze(cwr.SigmaX));
assert(approxeq(varx2, cwr_SigmaX));

% Output (Y) side: same two comparisons.
cwr_muy = cwr.muY;
assert(approxeq(muy2, cwr_muy));
cwr_SigmaY = ctranspose(squeeze(cwr.SigmaY));
assert(approxeq(vary2, cwr_SigmaY));