Mercurial > hg > camir-aes2014
comparison toolboxes/FullBNT-1.0.7/KPMstats/cwr_demo.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:e9a9cd732c1e |
---|---|
% CWR_DEMO  Check cwr_em / cwr_predict against Gershenfeld's reference CWM code.
%
% Fits a 1D Cluster-Weighted Regression model to a step function twice:
% once with this toolbox (cwr_em, then cwr_predict) and once with
% Neil Gershenfeld's original EM loop. It asserts that both give the same
% cluster priors, means, variances, predictive means, predictive variances
% and cluster posteriors, and then plots the fit.
%
% Reference code: cwm.m, (c) Neil Gershenfeld 9/1/97,
% "1D Cluster-Weighted Modeling example", from
% http://www.media.mit.edu/physics/publications/books/nmm/files/index.html
%
% Requires cwr_em, cwr_predict and approxeq on the MATLAB path.

clear  % 'clear all' would also flush compiled functions and breakpoints
figure;
seed = 0;
rand('state', seed);
randn('state', seed);

% Training data: a unit step sampled at the integers -10..10.
x = (-10:10)';
y = (x > 0);
npts = length(x);
plot(x,y,'+')
xlabel('x')
ylabel('y')

nclusters = 4;
nplot = 100;
xplot = 24*(1:nplot)'/nplot - 12;   % evaluation grid on (-12, 12]

% Shared initial parameters: one scalar Gaussian per cluster in x and in y.
mux = 20*rand(1,nclusters) - 10;
muy = zeros(1,nclusters);
varx = ones(1,nclusters);
vary = ones(1,nclusters);
pc = 1/nclusters * ones(1,nclusters);
niterations = 5;
% Variance floor added at every M step in both implementations.
% Named cov_reg rather than 'eps' so MATLAB's built-in eps is not shadowed.
cov_reg = 0.01;

% The same initial values, in the 1 x 1 x nclusters layout passed to cwr_em.
I = ones(1, 1, nclusters);
O = zeros(1, 1, nclusters);
X = x(:)';
Y = y(:)';

% Toolbox EM. The regression weights start at zero with clamp_weights set.
% Presumably this keeps them fixed, so each cluster's output is just muY,
% as in the reference model below. TODO confirm against cwr_em.
cwr = cwr_em(X, Y, nclusters, 'muX', mux, 'muY', muy, 'SigmaX', I, ...
             'cov_typeX', 'spherical', 'SigmaY', I, 'cov_typeY', 'spherical', ...
             'priorC', pc, 'weightsY', O, 'create_init_params', 0, ...
             'clamp_weights', 1, 'max_iter', niterations, ...
             'cov_priorX', cov_reg*ones(1,1,nclusters), ...
             'cov_priorY', cov_reg*ones(1,1,nclusters));


% Gershenfeld's EM code (reference implementation).
for step = 1:niterations
  % Weighted cluster densities pc(c) * N(xplot; mux(c), varx(c)) on the grid.
  % They are redrawn every iteration to show the clusters moving.
  pplot = exp(-(kron(xplot,ones(1,nclusters)) ...
     - kron(ones(nplot,1),mux)).^2 ...
     ./ (2*kron(ones(nplot,1),varx))) ...
     ./ sqrt(2*pi*kron(ones(nplot,1),varx)) ...
     .* kron(ones(nplot,1),pc);
  plot(xplot,pplot,'k');
  pause(0);   % flush the figure

  % E step: joint likelihood of each (x, y) point under each cluster,
  % then normalise each row to get the cluster responsibilities pp.
  px = exp(-(kron(x,ones(1,nclusters)) ...
     - kron(ones(npts,1),mux)).^2 ...
     ./ (2*kron(ones(npts,1),varx))) ...
     ./ sqrt(2*pi*kron(ones(npts,1),varx));
  py = exp(-(kron(y,ones(1,nclusters)) ...
     - kron(ones(npts,1),muy)).^2 ...
     ./ (2*kron(ones(npts,1),vary))) ...
     ./ sqrt(2*pi*kron(ones(npts,1),vary));
  p = px .* py .* kron(ones(npts,1),pc);
  pp = p ./ kron(sum(p,2),ones(1,nclusters));

  % M step. The order matters: varx uses the freshly updated mux,
  % and vary uses the freshly updated muy.
  pc = sum(pp)/npts;
  mux = sum(kron(x,ones(1,nclusters)) .* pp) ...
     ./ (npts*pc);
  varx = cov_reg + sum((kron(x,ones(1,nclusters)) ...
     - kron(ones(npts,1),mux)).^2 .* pp) ...
     ./ (npts*pc);
  muy = sum(kron(y,ones(1,nclusters)) .* pp) ...
     ./ (npts*pc);
  vary = cov_reg + sum((kron(y,ones(1,nclusters)) ...
     - kron(ones(npts,1),muy)).^2 .* pp) ...
     ./ (npts*pc);
end


% Check that both implementations learned the same parameters.
cwr_pc = cwr.priorC';
assert(approxeq(cwr_pc, pc))
cwr_mux = cwr.muX;
assert(approxeq(mux, cwr_mux))
cwr_SigmaX = squeeze(cwr.SigmaX)';
assert(approxeq(varx, cwr_SigmaX))
cwr_muy = cwr.muY;
assert(approxeq(muy, cwr_muy))
cwr_SigmaY = squeeze(cwr.SigmaY)';
assert(approxeq(vary, cwr_SigmaY))


% Prediction on the grid xplot.
X = xplot(:)';
[cwr_mu, Sigma, post] = cwr_predict(cwr, X);
cwr_yvar = squeeze(Sigma)';   % predictive variance from the toolbox

% Reference prediction. pplot(t,c) = pc(c) * N(xplot(t); mux(c), varx(c)).
pplot = exp(-(kron(xplot,ones(1,nclusters)) ...
   - kron(ones(nplot,1),mux)).^2 ...
   ./ (2*kron(ones(nplot,1),varx))) ...
   ./ sqrt(2*pi*kron(ones(nplot,1),varx)) ...
   .* kron(ones(nplot,1),pc);
% Predictive mean: cluster means weighted by the normalised pplot rows.
yplot = sum(kron(ones(nplot,1),muy) .* pplot,2) ...
   ./ sum(pplot,2);
% Predictive variance E[y^2|x] - E[y|x]^2. This is a variance, not a
% standard deviation.
yvarplot = sum(kron(ones(nplot,1),(muy.^2+vary)) .* pplot,2) ...
   ./ sum(pplot,2) - yplot.^2;


% Check that both implementations predict the same thing.
assert(approxeq(yplot(:)', cwr_mu(:)'))
assert(approxeq(yvarplot, cwr_yvar))
assert(approxeq(pplot ./ repmat(sum(pplot,2), 1, nclusters), post'))

% Plot the mean +/- one standard deviation. yvarplot is a variance, so take
% its square root; adding the variance itself would mis-scale the bands.
ystdplot = sqrt(yvarplot);
plot(xplot,yplot,'k');
hold on
plot(xplot,yplot+ystdplot,'k--');
plot(xplot,yplot-ystdplot,'k--');
plot(x,y,'k+');
axis([-12 12 -1 1.1]);
% Weighted cluster densities rescaled into [-1, -0.2], below the data.
plot(xplot,.8*pplot/max(max(pplot))-1,'k')
hold off