diff toolboxes/FullBNT-1.0.7/KPMstats/cwr_test.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/toolboxes/FullBNT-1.0.7/KPMstats/cwr_test.m	Tue Feb 10 15:05:51 2015 +0000
@@ -0,0 +1,80 @@
+% Verify that my code gives the same results as the 1D example at
+% http://www.media.mit.edu/physics/publications/books/nmm/files/cwm.m
+
+seed = 0;
+rand('state', seed);
+randn('state', seed);
+x = (-10:10)';
+y = double(x > 0);
+npts = length(x);
+plot(x,y,'+')
+
+nclusters = 4;
+nplot = 100;
+xplot = 24*(1:nplot)'/nplot - 12;
+
+mux = 20*rand(1,nclusters) - 10;
+muy = zeros(1,nclusters);
+varx = ones(1,nclusters);
+vary = ones(1,nclusters);
+pc = 1/nclusters * ones(1,nclusters);
+
+
+I = repmat(eye(1,1), [1 1 nclusters]);
+O = repmat(zeros(1,1), [1 1 nclusters]);
+X = x(:)';
+Y = y(:)';
+
+% Do 1 iteration of EM
+
+%cwr = cwr_em(X, Y, nclusters, 'muX', mux, 'muY', muy,  'SigmaX', I, 'cov_typeX', 'spherical', 'SigmaY', I, 'cov_typeY', 'spherical', 'priorC', pc, 'weightsY', O,  'init_params', 0, 'clamp_weights', 1, 'max_iter', 1, 'cov_priorX', zeros(1,1,nclusters), 'cov_priorY', zeros(1,1,nclusters));
+
+cwr = cwr_em(X, Y, nclusters, 'muX', mux, 'muY', muy,  'SigmaX', I, 'cov_typeX', 'spherical', 'SigmaY', I, 'cov_typeY', 'spherical', 'priorC', pc, 'weightsY', O,  'create_init_params', 0, 'clamp_weights', 1, 'max_iter', 1);
+
+
+% Check this matches Gershenfeld's code
+
+% E step
+% px(t,c) = prob(x(t) | c)
+px = exp(-(kron(x,ones(1,nclusters)) ...
+	   - kron(ones(npts,1),mux)).^2 ...
+	 ./ (2*kron(ones(npts,1),varx))) ...
+     ./ sqrt(2*pi*kron(ones(npts,1),varx));
+py = exp(-(kron(y,ones(1,nclusters)) ...
+	   - kron(ones(npts,1),muy)).^2 ...
+	 ./ (2*kron(ones(npts,1),vary))) ...
+     ./ sqrt(2*pi*kron(ones(npts,1),vary));
+p = px .* py .* kron(ones(npts,1),pc);
+pp = p ./ kron(sum(p,2),ones(1,nclusters));
+
+% M step
+eps = 0.01;
+pc2 = sum(pp)/npts;
+
+mux2 = sum(kron(x,ones(1,nclusters)) .* pp) ...
+      ./ (npts*pc2);
+varx2 = eps + sum((kron(x,ones(1,nclusters)) ...
+		  - kron(ones(npts,1),mux2)).^2 .* pp) ...
+       ./ (npts*pc2);
+muy2 = sum(kron(y,ones(1,nclusters)) .* pp) ...
+      ./ (npts*pc2);
+vary2 = eps + sum((kron(y,ones(1,nclusters)) ...
+		  - kron(ones(npts,1),muy2)).^2 .* pp) ...
+       ./ (npts*pc2);
+
+
+denom = (npts*pc2);
+% denom(c) = N*pc(c) = w(c) = sum_t pp(c,t)
+% since pc(c) = sum_t pp(c,t) / N
+
+cwr_mux = cwr.muX;
+assert(approxeq(mux2, cwr_mux))
+cwr_SigmaX = squeeze(cwr.SigmaX)';
+assert(approxeq(varx2, cwr_SigmaX))
+
+cwr_muy = cwr.muY;
+assert(approxeq(muy2, cwr_muy))
+cwr_SigmaY = squeeze(cwr.SigmaY)';
+assert(approxeq(vary2, cwr_SigmaY))
+
+