Mercurial > hg > camir-aes2014
comparison toolboxes/FullBNT-1.0.7/KPMstats/cwr_test.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:e9a9cd732c1e |
---|---|
% CWR_TEST  Check cwr_em against Gershenfeld's 1D cluster-weighted modeling example.
% Verify that my code gives the same results as the 1D example at
% http://www.media.mit.edu/physics/publications/books/nmm/files/cwm.m
%
% Runs one EM iteration of cwr_em from a fixed initialisation (random
% cluster centres in x, zero means in y, unit variances, uniform cluster
% priors, zero 'weightsY' held fixed via 'clamp_weights'), then repeats the
% same E and M step by hand and asserts that the updated means and
% variances agree.
%
% NOTE: this is a script, so all variables below are left in the caller's
% workspace. The variance floor is therefore NOT named "eps", which would
% shadow the built-in EPS function for anything run afterwards.

seed = 0;
rand('state', seed);
randn('state', seed);
x = (-10:10)';
y = double(x > 0);
npts = length(x);
plot(x,y,'+')

nclusters = 4;
nplot = 100;
xplot = 24*(1:nplot)'/nplot - 12;

% Initial parameters, shared by cwr_em and the hand-coded reference below.
mux = 20*rand(1,nclusters) - 10;
muy = zeros(1,nclusters);
varx = ones(1,nclusters);
vary = ones(1,nclusters);
pc = 1/nclusters * ones(1,nclusters);

I = repmat(eye(1,1), [1 1 nclusters]);   % 1x1 unit covariance for each cluster
O = repmat(zeros(1,1), [1 1 nclusters]); % 1x1 zero 'weightsY' for each cluster
X = x(:)';
Y = y(:)';

% Do 1 iteration of EM

%cwr = cwr_em(X, Y, nclusters, 'muX', mux, 'muY', muy, 'SigmaX', I, 'cov_typeX', 'spherical', 'SigmaY', I, 'cov_typeY', 'spherical', 'priorC', pc, 'weightsY', O, 'init_params', 0, 'clamp_weights', 1, 'max_iter', 1, 'cov_priorX', zeros(1,1,nclusters), 'cov_priorY', zeros(1,1,nclusters));

cwr = cwr_em(X, Y, nclusters, 'muX', mux, 'muY', muy, ...
             'SigmaX', I, 'cov_typeX', 'spherical', ...
             'SigmaY', I, 'cov_typeY', 'spherical', ...
             'priorC', pc, 'weightsY', O, ...
             'create_init_params', 0, 'clamp_weights', 1, 'max_iter', 1);


% Check this matches Gershenfeld's code

% E step
% px(t,c) = prob(x(t) | c), py(t,c) = prob(y(t) | c): 1D Gaussian densities.
px = exp(-(kron(x,ones(1,nclusters)) ...
    - kron(ones(npts,1),mux)).^2 ...
    ./ (2*kron(ones(npts,1),varx))) ...
    ./ sqrt(2*pi*kron(ones(npts,1),varx));
py = exp(-(kron(y,ones(1,nclusters)) ...
    - kron(ones(npts,1),muy)).^2 ...
    ./ (2*kron(ones(npts,1),vary))) ...
    ./ sqrt(2*pi*kron(ones(npts,1),vary));
p = px .* py .* kron(ones(npts,1),pc);
% pp(t,c) = posterior responsibility of cluster c for point t (rows sum to 1).
pp = p ./ kron(sum(p,2),ones(1,nclusters));

% M step
% Floor added to every updated variance. Presumably this has to match the
% covariance prior cwr_em applies by default -- TODO confirm (the commented
% call above passes explicit zero cov_prior arguments instead).
var_floor = 0.01;
pc2 = sum(pp)/npts;

% denom(c) = N*pc2(c) = w(c) = sum_t pp(t,c), the soft count of cluster c,
% since pc2(c) = sum_t pp(t,c) / N.
denom = (npts*pc2);

mux2 = sum(kron(x,ones(1,nclusters)) .* pp) ...
    ./ denom;
varx2 = var_floor + sum((kron(x,ones(1,nclusters)) ...
    - kron(ones(npts,1),mux2)).^2 .* pp) ...
    ./ denom;
muy2 = sum(kron(y,ones(1,nclusters)) .* pp) ...
    ./ denom;
vary2 = var_floor + sum((kron(y,ones(1,nclusters)) ...
    - kron(ones(npts,1),muy2)).^2 .* pp) ...
    ./ denom;

% Compare the hand-computed update with what cwr_em produced.
cwr_mux = cwr.muX;
assert(approxeq(mux2, cwr_mux))
cwr_SigmaX = squeeze(cwr.SigmaX)';   % 1x1xK -> 1xK row, same shape as varx2
assert(approxeq(varx2, cwr_SigmaX))

cwr_muy = cwr.muY;
assert(approxeq(muy2, cwr_muy))
cwr_SigmaY = squeeze(cwr.SigmaY)';   % 1x1xK -> 1xK row, same shape as vary2
assert(approxeq(vary2, cwr_SigmaY))
80 |