% Compare BNT's cluster-weighted regression (cwr_em / cwr_predict) with
% Neil Gershenfeld's reference implementation from
% http://www.media.mit.edu/physics/publications/books/nmm/files/index.html
%
% cwm.m
% (c) Neil Gershenfeld 9/1/97
% 1D Cluster-Weighted Modeling example
%
% The script fits a unit step (y = 1 for x > 0, else 0) with nclusters
% Gaussian clusters. Both implementations run niterations of EM from the
% same initial parameters. The script asserts that the learned parameters
% and the predictions on a plotting grid agree, then plots the fit.
%
% Requires cwr_em, cwr_predict and approxeq on the MATLAB path.

clear all
figure;

seed = 0;
rand('state', seed);
randn('state', seed);

% Training data: a unit step sampled at the integers -10..10.
x = (-10:10)';
y = (x > 0);
npts = length(x);
plot(x, y, '+')
xlabel('x')
ylabel('y')

nclusters = 4;
nplot = 100;
xplot = 24*(1:nplot)'/nplot - 12;    % evaluation grid on (-12, 12]

% Initial parameters, shared by both implementations.
mux = 20*rand(1, nclusters) - 10;    % random cluster centres in [-10, 10)
muy = zeros(1, nclusters);
varx = ones(1, nclusters);
vary = ones(1, nclusters);
pc = 1/nclusters * ones(1, nclusters);
niterations = 5;
% Constant added to every variance estimate in the EM below.
% (Deliberately not called "eps", which would shadow the builtin.)
var_floor = 0.01;

% cwr_em takes 1x1xK variance / weight arrays.
I = ones(1, 1, nclusters);     % unit initial variances, same as varx / vary
O = zeros(1, 1, nclusters);    % regression weights, clamped (clamp_weights = 1)
X = x(:)';
Y = y(:)';

% NOTE(review): cov_priorX / cov_priorY are presumably cwr_em's counterpart
% of adding var_floor to the variances; the asserts below check that.
cwr = cwr_em(X, Y, nclusters, 'muX', mux, 'muY', muy, 'SigmaX', I, ...
             'cov_typeX', 'spherical', 'SigmaY', I, 'cov_typeY', 'spherical', ...
             'priorC', pc, 'weightsY', O, 'create_init_params', 0, ...
             'clamp_weights', 1, 'max_iter', niterations, ...
             'cov_priorX', var_floor*ones(1,1,nclusters), ...
             'cov_priorY', var_floor*ones(1,1,nclusters));

% Univariate Gaussian density of every entry of the column t under every
% cluster: t is n x 1, mu and v are 1 x K, and the result is n x K.
% bsxfun reproduces the original kron-replicated arithmetic element by
% element, so the results are bit-identical.
gauss = @(t, mu, v) bsxfun(@rdivide, ...
    exp(-bsxfun(@rdivide, bsxfun(@minus, t, mu).^2, 2*v)), sqrt(2*pi*v));


% Gershenfeld's EM code (rows index data points, columns index clusters).
for step = 1:niterations
  % Animate the weighted cluster densities p(x|c) p(c) on the plot grid.
  pplot = bsxfun(@times, gauss(xplot, mux, varx), pc);
  plot(xplot, pplot, 'k');
  pause(0);

  % E-step: joint p(x, y, c), then responsibilities p(c | x, y).
  p = bsxfun(@times, gauss(x, mux, varx) .* gauss(y, muy, vary), pc);
  pp = bsxfun(@rdivide, p, sum(p, 2));
  pc = sum(pp) / npts;

  % M-step. The variances use the means that were just updated.
  nc = npts * pc;    % effective number of points per cluster
  mux = sum(bsxfun(@times, x, pp)) ./ nc;
  varx = var_floor + sum(bsxfun(@minus, x, mux).^2 .* pp) ./ nc;
  muy = sum(bsxfun(@times, y, pp)) ./ nc;
  vary = var_floor + sum(bsxfun(@minus, y, muy).^2 .* pp) ./ nc;
end


% The learned parameters must agree with Gershenfeld's.
assert(approxeq(cwr.priorC', pc))
assert(approxeq(mux, cwr.muX))
assert(approxeq(varx, squeeze(cwr.SigmaX)'))
assert(approxeq(muy, cwr.muY))
assert(approxeq(vary, squeeze(cwr.SigmaY)'))


% Prediction on the plotting grid.
[cwr_mu, Sigma, post] = cwr_predict(cwr, xplot(:)');
cwr_yvar = squeeze(Sigma)';

% Gershenfeld's predictive mean and variance of y: a mixture over clusters
% weighted by pplot(t,c) = p(x_t | c) p(c).
pplot = bsxfun(@times, gauss(xplot, mux, varx), pc);
wsum = sum(pplot, 2);
yplot = sum(bsxfun(@times, muy, pplot), 2) ./ wsum;
% E[y^2] - E[y]^2: this is a variance, not a standard deviation.
yvarplot = sum(bsxfun(@times, muy.^2 + vary, pplot), 2) ./ wsum - yplot.^2;


% The predictions and cluster posteriors must agree as well.
assert(approxeq(yplot(:)', cwr_mu(:)'))
assert(approxeq(yvarplot, cwr_yvar))
assert(approxeq(bsxfun(@rdivide, pplot, wsum), post'))

% Final fit: mean +/- one predictive standard deviation, the data, and the
% rescaled cluster densities drawn along the bottom of the axes.
% (Every vary >= var_floor > 0, so yvarplot is positive and sqrt is real.)
ystdplot = sqrt(yvarplot);
plot(xplot, yplot, 'k');
hold on
plot(xplot, yplot + ystdplot, 'k--');
plot(xplot, yplot - ystdplot, 'k--');
plot(x, y, 'k+');
axis([-12 12 -1 1.1]);
plot(xplot, .8*pplot/max(pplot(:)) - 1, 'k')
hold off