comparison toolboxes/FullBNT-1.0.7/KPMstats/cwr_test.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:e9a9cd732c1e
% Regression test for cwr_em: one EM iteration on a 1-D step function should
% reproduce the 1-D cluster-weighted modeling example at
% http://www.media.mit.edu/physics/publications/books/nmm/files/cwm.m

% Fix both generators so the random cluster initialisation is reproducible.
seed = 0;
rand('state', seed);
randn('state', seed);

% Training data: integers -10..10 as a column, target is a unit step (1 for x > 0).
x = (-10:10).';
y = double(x > 0);
npts = numel(x);
plot(x,y,'+')

nclusters = 4;
% Plotting grid over (-12, 12]; nplot/xplot are not used further in this file.
nplot = 100;
xplot = 24*(1:nplot)'/nplot - 12;

% Initial cluster parameters: input means drawn uniformly from (-10, 10),
% zero output means, unit variances and equal mixing weights.
mux = 20*rand(1,nclusters) - 10;
muy = zeros(1,nclusters);
varx = ones(1,nclusters);
vary = ones(1,nclusters);
pc = ones(1,nclusters) / nclusters;

% Same parameters in the layout handed to cwr_em: unit 1x1 covariance per
% cluster (I), zero output weights (O), and the data as row vectors.
I = ones(1, 1, nclusters);
O = zeros(1, 1, nclusters);
X = x.';
Y = y.';

% Do 1 iteration of EM.
% Disabled alternative call that also passes explicit zero covariance priors:
%cwr = cwr_em(X, Y, nclusters, 'muX', mux, 'muY', muy, 'SigmaX', I, 'cov_typeX', 'spherical', 'SigmaY', I, 'cov_typeY', 'spherical', 'priorC', pc, 'weightsY', O, 'init_params', 0, 'clamp_weights', 1, 'max_iter', 1, 'cov_priorX', zeros(1,1,nclusters), 'cov_priorY', zeros(1,1,nclusters));
% NOTE(review): 'create_init_params', 0 and 'clamp_weights', 1 presumably make
% cwr_em start from exactly these values and keep weightsY fixed — confirm in cwr_em.
cwr = cwr_em(X, Y, nclusters, ...
             'muX', mux, 'muY', muy, ...
             'SigmaX', I, 'cov_typeX', 'spherical', ...
             'SigmaY', I, 'cov_typeY', 'spherical', ...
             'priorC', pc, 'weightsY', O, ...
             'create_init_params', 0, 'clamp_weights', 1, 'max_iter', 1);
33
34
% Check this matches Gershenfeld's code: redo the same EM iteration by hand
% and compare the result with the parameters cwr_em returned.

% E step
% px(t,c) = prob(x(t) | c): Gaussian density, mean mux(c), variance varx(c)
px = exp(-(kron(x,ones(1,nclusters)) ...
    - kron(ones(npts,1),mux)).^2 ...
    ./ (2*kron(ones(npts,1),varx))) ...
    ./ sqrt(2*pi*kron(ones(npts,1),varx));
% py(t,c) = prob(y(t) | c): Gaussian density, mean muy(c), variance vary(c)
py = exp(-(kron(y,ones(1,nclusters)) ...
    - kron(ones(npts,1),muy)).^2 ...
    ./ (2*kron(ones(npts,1),vary))) ...
    ./ sqrt(2*pi*kron(ones(npts,1),vary));
% Weight each point under each cluster by the prior, then normalise every row
% so pp(t,c) is the posterior responsibility of cluster c for point t.
p = px .* py .* kron(ones(npts,1),pc);
pp = p ./ kron(sum(p,2),ones(1,nclusters));

% M step
% Regulariser added to every re-estimated variance. Deliberately not named
% "eps": assigning to eps in a script shadows MATLAB's built-in eps function
% for the rest of the workspace.
% NOTE(review): the varx2/vary2 checks below only pass if cwr_em's default
% covariance prior adds this same 0.01 (the active cwr_em call passes no
% cov_priorX/cov_priorY) — confirm against cwr_em.
cov_reg = 0.01;
pc2 = sum(pp)/npts;

% denom(c) = N*pc2(c) = sum_t pp(t,c), the total responsibility of cluster c,
% since pc2(c) = sum_t pp(t,c) / N. It normalises every weighted mean below.
denom = npts*pc2;

mux2 = sum(kron(x,ones(1,nclusters)) .* pp) ./ denom;
varx2 = cov_reg + sum((kron(x,ones(1,nclusters)) ...
    - kron(ones(npts,1),mux2)).^2 .* pp) ./ denom;
muy2 = sum(kron(y,ones(1,nclusters)) .* pp) ./ denom;
vary2 = cov_reg + sum((kron(y,ones(1,nclusters)) ...
    - kron(ones(npts,1),muy2)).^2 .* pp) ./ denom;

% Compare with cwr_em. Its covariances are squeezed and transposed into
% 1 x nclusters rows to line up with varx2/vary2 (presumably they come back
% 1x1xnclusters, like the SigmaX/SigmaY passed in — confirm).
cwr_mux = cwr.muX;
assert(approxeq(mux2, cwr_mux))
cwr_SigmaX = squeeze(cwr.SigmaX)';
assert(approxeq(varx2, cwr_SigmaX))

cwr_muy = cwr.muY;
assert(approxeq(muy2, cwr_muy))
cwr_SigmaY = squeeze(cwr.SigmaY)';
assert(approxeq(vary2, cwr_SigmaY))
79
80