view toolboxes/FullBNT-1.0.7/KPMstats/cwr_demo.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
line wrap: on
line source
% CWR_DEMO  Compare BNT's cwr_em (cluster-weighted regression) against
% Neil Gershenfeld's reference 1D cluster-weighted modeling code, cwm.m:
% http://www.media.mit.edu/physics/publications/books/nmm/files/index.html
%
% cwm.m: (c) Neil Gershenfeld  9/1/97
% 1D Cluster-Weighted Modeling example
%
% Both implementations start from identical parameters and run the same
% number of EM iterations; the script then asserts that they agree.

% 'clear' resets the workspace; 'clear all' would additionally flush
% compiled functions and breakpoints, which is not needed here.
clear
figure;

% Seed the legacy generators so the random initial centres are reproducible.
seed = 0;
rand('state', seed);
randn('state', seed);

% Training data: a unit step at x = 0, sampled on the integers -10..10.
% y is cast to double so the downstream arithmetic (kron, cwr_em) works on
% numeric rather than logical data.
x = (-10:10)';
y = double(x > 0);
npts = length(x);
plot(x,y,'+')
xlabel('x')
ylabel('y')
nclusters = 4;

% Grid of nplot points on (-12, 12] used for plotting and prediction.
nplot = 100;
xplot = 24*(1:nplot)'/nplot - 12;

% Initial parameters shared by both implementations: centres drawn
% uniformly from (-10, 10), zero output means, unit variances, equal priors.
mux = 20*rand(1,nclusters) - 10;
muy = zeros(1,nclusters);
varx = ones(1,nclusters);
vary = ones(1,nclusters);
pc = 1/nclusters * ones(1,nclusters);
niterations = 5;
% Variance floor added at each M step.  NOTE: this name shadows the
% built-in eps for the rest of the script; it is kept because later
% sections of the script refer to it.
eps = 0.01;

% The same initialisation in cwr_em's format: 1x1xnclusters arrays of unit
% covariances (I) and zero regression weights (O).
I = repmat(eye(1,1), [1 1 nclusters]);
O = repmat(zeros(1,1), [1 1 nclusters]);
X = x(:)';
Y = y(:)';

% Fit with BNT.  create_init_params=0 presumably makes cwr_em use the
% supplied initial values, and clamp_weights=1 presumably keeps the zero
% regression weights fixed (so each cluster predicts a constant, as in
% cwm.m) -- confirm against cwr_em.  eps is passed as the covariance prior
% to mirror the variance floor used in the reference code.
cwr = cwr_em(X, Y, nclusters, 'muX', mux, 'muY', muy,  'SigmaX', I, ...
	     'cov_typeX', 'spherical', 'SigmaY', I, 'cov_typeY', 'spherical', ...
	     'priorC', pc, 'weightsY', O,  'create_init_params', 0, ...
	     'clamp_weights', 1, 'max_iter', niterations, ...
	     'cov_priorX', eps*ones(1,1,nclusters), ...
	     'cov_priorY', eps*ones(1,1,nclusters));


% Gershenfeld's EM code (reference implementation, copied from cwm.m).
% Each pass plots the current prior-weighted input densities, performs an
% E step (cluster responsibilities), then an M step (parameter updates).
% Broadcasting is done with kron: kron(v,ones(1,nclusters)) repeats a
% column across clusters, and kron(ones(n,1),r) repeats a row over n
% points, so the arrays below are npts-by-nclusters (nplot-by-nclusters
% for the plotting grid).
for step = 1:niterations
    % Prior-weighted Gaussian input density of each cluster on the plot
    % grid, using the parameters from the previous iteration.
    pplot = exp(-(kron(xplot,ones(1,nclusters)) ...
		  - kron(ones(nplot,1),mux)).^2 ...
		./ (2*kron(ones(nplot,1),varx))) ...
	    ./ sqrt(2*pi*kron(ones(nplot,1),varx)) ...
	    .* kron(ones(nplot,1),pc);
    plot(xplot,pplot,'k');
    % Presumably here to let the figure refresh between iterations.
    pause(0);
    % E step: Gaussian likelihood of each point's x and y under each cluster.
    px = exp(-(kron(x,ones(1,nclusters)) ...
	       - kron(ones(npts,1),mux)).^2 ...
	     ./ (2*kron(ones(npts,1),varx))) ...
	 ./ sqrt(2*pi*kron(ones(npts,1),varx));
    py = exp(-(kron(y,ones(1,nclusters)) ...
	       - kron(ones(npts,1),muy)).^2 ...
	     ./ (2*kron(ones(npts,1),vary))) ...
	 ./ sqrt(2*pi*kron(ones(npts,1),vary));
    % Joint weight of point and cluster, then row-normalised responsibilities.
    p = px .* py .* kron(ones(npts,1),pc);
    pp = p ./ kron(sum(p,2),ones(1,nclusters));
    % M step.  pc is updated first, so npts*pc below equals sum(pp), the
    % total responsibility of each cluster.
    pc = sum(pp)/npts;
    % Fitted output at the training points; it is not used later in this script.
    yfit = sum(kron(ones(npts,1),muy) .* p,2) ...
	   ./ sum(p,2);
    % Statement order matters: varx uses the freshly updated mux, and vary
    % the freshly updated muy.  eps is added as a variance floor.
    mux = sum(kron(x,ones(1,nclusters)) .* pp) ...
	  ./ (npts*pc);
    varx = eps + sum((kron(x,ones(1,nclusters)) ...
		      - kron(ones(npts,1),mux)).^2 .* pp) ...
	   ./ (npts*pc);
    muy = sum(kron(y,ones(1,nclusters)) .* pp) ...
	  ./ (npts*pc);
    vary = eps + sum((kron(y,ones(1,nclusters)) ...
		      - kron(ones(npts,1),muy)).^2 .* pp) ...
	   ./ (npts*pc);
end


% Check that cwr_em reproduced the reference EM parameters.  The cwr fields
% are transposed / squeezed into the reference's 1-by-nclusters row layout
% (presumably priorC is a column and SigmaX/SigmaY are 1x1xnclusters, like
% the initial values passed to cwr_em -- confirm).  Each assertion names the
% parameter it checks, so a failure identifies what diverged.
cwr_pc = cwr.priorC';
assert(approxeq(cwr_pc, pc), 'cluster priors (priorC) differ from reference')
cwr_mux = cwr.muX;
assert(approxeq(mux, cwr_mux), 'input means (muX) differ from reference')
cwr_SigmaX = squeeze(cwr.SigmaX)';
assert(approxeq(varx, cwr_SigmaX), 'input variances (SigmaX) differ from reference')
cwr_muy = cwr.muY;
assert(approxeq(muy, cwr_muy), 'output means (muY) differ from reference')
cwr_SigmaY = squeeze(cwr.SigmaY)';
assert(approxeq(vary, cwr_SigmaY), 'output variances (SigmaY) differ from reference')


% Prediction on the plotting grid with BNT.  The checks below compare
% cwr_mu with the reference conditional mean, Sigma with the reference
% conditional variance, and post with the normalised cluster densities.
X = xplot(:)';
[cwr_mu, Sigma, post] = cwr_predict(cwr, X);
% Sigma holds variances (not standard deviations), hence the name.
cwr_yvar = squeeze(Sigma)';

% Reference prediction following cwm.m.
% pplot(t,c): prior-weighted Gaussian input density of cluster c at grid
% point t, using the final fitted parameters.
pplot = exp(-(kron(xplot,ones(1,nclusters)) ...
   - kron(ones(nplot,1),mux)).^2 ...
   ./ (2*kron(ones(nplot,1),varx))) ...
   ./ sqrt(2*pi*kron(ones(nplot,1),varx)) ...
   .* kron(ones(nplot,1),pc);
% Conditional mean E[y|x]: density-weighted average of the cluster outputs.
yplot = sum(kron(ones(nplot,1),muy) .* pplot,2) ...
   ./ sum(pplot,2);
% Conditional variance E[y^2|x] - E[y|x]^2.  Despite its name, ystdplot is
% a variance; the name is kept because the plotting code reads it.
ystdplot = sum(kron(ones(nplot,1),(muy.^2+vary)) .* pplot,2) ...
   ./ sum(pplot,2) - yplot.^2;


% Check the BNT prediction against the reference.
assert(approxeq(yplot(:)', cwr_mu(:)'), 'predicted mean differs from reference')
assert(approxeq(ystdplot, cwr_yvar), 'predicted variance differs from reference')
assert(approxeq(pplot ./ repmat(sum(pplot,2), 1, nclusters),post'), ...
       'cluster posteriors differ from reference')

% Final figure: conditional mean with +/- one standard deviation bands, the
% training data, and the cluster input densities drawn along the bottom.
% ystdplot holds the conditional variance, so take its square root for the
% bands (max guards against tiny negative round-off).
ystd = sqrt(max(ystdplot, 0));
plot(xplot,yplot,'k');
hold on
plot(xplot,yplot+ystd,'k--');
plot(xplot,yplot-ystd,'k--');
plot(x,y,'k+');
axis([-12 12 -1 1.1]);
% Rescale the densities to peak at 0.8 and shift them to sit on y = -1.
plot(xplot,.8*pplot/max(pplot(:))-1,'k')
hold off