% Compare my cwr_em / cwr_predict code with Gershenfeld's cwm.m from
% http://www.media.mit.edu/physics/publications/books/nmm/files/index.html
%
% cwm.m
% (c) Neil Gershenfeld 9/1/97
% 1D Cluster-Weighted Modeling example
%
% This script fits a 1D cluster-weighted model to a unit step twice, both runs
% starting from the same initial parameters:
%   1. with cwr_em (defined outside this file), and
%   2. with Gershenfeld's hand-written EM loop below.
% It asserts that the fitted parameters and the predictions agree (via
% approxeq, also defined outside this file), then plots the fit.
%
clear all
figure;
seed = 0;
rand('state', seed);
randn('state', seed);

% Training data: a unit step sampled at the integers -10..10.
% y is stored as double (not logical) so all later arithmetic is plain double.
x = (-10:10)';
y = double(x > 0);
npts = length(x);
plot(x, y, '+')
xlabel('x')
ylabel('y')

nclusters = 4;
nplot = 100;
xplot = 24*(1:nplot)'/nplot - 12;   % evaluation grid on (-12, 12]

% Initial parameters, shared by both implementations.
mux = 20*rand(1,nclusters) - 10;    % random initial input means in (-10, 10)
muy = zeros(1,nclusters);
varx = ones(1,nclusters);
vary = ones(1,nclusters);
pc = 1/nclusters * ones(1,nclusters);
niterations = 5;
% Variance floor added at every M-step. Named cov_prior rather than "eps"
% so that it does not shadow MATLAB's built-in eps function.
cov_prior = 0.01;

% cluster_pdf(t, mu, v) returns a numel(t)-by-numel(mu) matrix whose (i,c)
% entry is the 1D Gaussian density N(t(i); mu(c), v(c)).
% t must be a column vector; mu and v must be row vectors.
cluster_pdf = @(t, mu, v) ...
    exp(-(repmat(t, 1, numel(mu)) - repmat(mu, numel(t), 1)).^2 ...
        ./ (2*repmat(v, numel(t), 1))) ...
    ./ sqrt(2*pi*repmat(v, numel(t), 1));


% --- cwr_em run, started from the same parameters as the reference loop ---
I = repmat(eye(1,1), [1 1 nclusters]);
O = repmat(zeros(1,1), [1 1 nclusters]);
X = x(:)';
Y = y(:)';

% weightsY is all zeros and clamped. Presumably this reduces each cluster's
% output model to the constant muY, as in the reference loop -- confirm
% against cwr_em.
cwr = cwr_em(X, Y, nclusters, 'muX', mux, 'muY', muy, 'SigmaX', I, ...
             'cov_typeX', 'spherical', 'SigmaY', I, 'cov_typeY', 'spherical', ...
             'priorC', pc, 'weightsY', O, 'create_init_params', 0, ...
             'clamp_weights', 1, 'max_iter', niterations, ...
             'cov_priorX', cov_prior*ones(1,1,nclusters), ...
             'cov_priorY', cov_prior*ones(1,1,nclusters));


% --- Gershenfeld's EM code ---
for step = 1:niterations
  % Display the current weighted input densities p(x|c) p(c).
  pplot = cluster_pdf(xplot, mux, varx) .* repmat(pc, nplot, 1);
  plot(xplot, pplot, 'k');
  drawnow;   % flush graphics so each iteration is actually drawn

  % E-step: joint weight of each (point, cluster) pair, then the posterior
  % responsibilities, normalised across clusters for each point.
  p = cluster_pdf(x, mux, varx) .* cluster_pdf(y, muy, vary) ...
      .* repmat(pc, npts, 1);
  pp = p ./ repmat(sum(p, 2), 1, nclusters);

  % M-step. pc is updated first, because npts*pc (the effective number of
  % points per cluster) normalises every update below. The means are updated
  % before their variances, which use the new means.
  pc = sum(pp) / npts;
  mux = sum(repmat(x, 1, nclusters) .* pp) ./ (npts*pc);
  varx = cov_prior + sum((repmat(x, 1, nclusters) ...
                          - repmat(mux, npts, 1)).^2 .* pp) ./ (npts*pc);
  muy = sum(repmat(y, 1, nclusters) .* pp) ./ (npts*pc);
  vary = cov_prior + sum((repmat(y, 1, nclusters) ...
                          - repmat(muy, npts, 1)).^2 .* pp) ./ (npts*pc);
end


% Check equal: cwr_em must reproduce the reference parameters.
cwr_pc = cwr.priorC';
assert(approxeq(cwr_pc, pc))
cwr_mux = cwr.muX;
assert(approxeq(mux, cwr_mux))
cwr_SigmaX = squeeze(cwr.SigmaX)';
assert(approxeq(varx, cwr_SigmaX))
cwr_muy = cwr.muY;
assert(approxeq(muy, cwr_muy))
cwr_SigmaY = squeeze(cwr.SigmaY)';
assert(approxeq(vary, cwr_SigmaY))


% Prediction
X = xplot(:)';
[cwr_mu, Sigma, post] = cwr_predict(cwr, X);
cwr_yvar = squeeze(Sigma)';

% pplot(t,c) = p(xplot(t)|c) p(c) under the reference parameters.
pplot = cluster_pdf(xplot, mux, varx) .* repmat(pc, nplot, 1);
yplot = sum(repmat(muy, nplot, 1) .* pplot, 2) ./ sum(pplot, 2);
% Predictive variance E[y^2|x] - E[y|x]^2 of the cluster mixture.
% This is a variance, not a standard deviation.
yvarplot = sum(repmat(muy.^2 + vary, nplot, 1) .* pplot, 2) ...
           ./ sum(pplot, 2) - yplot.^2;


% Check equal. Both sides are flattened to rows so that the comparison does
% not depend on the orientation cwr_predict returns.
assert(approxeq(yplot(:)', cwr_mu(:)'))
assert(approxeq(yvarplot(:)', cwr_yvar(:)'))
assert(approxeq(pplot ./ repmat(sum(pplot,2), 1, nclusters), post'))

% Plot the mean prediction with a +/- one standard deviation band. The band
% is the square root of the variance; max(.,0) guards against tiny negative
% round-off. Also plot the data, and the scaled cluster densities along the
% bottom of the axes.
ystdplot = sqrt(max(yvarplot, 0));
plot(xplot, yplot, 'k');
hold on
plot(xplot, yplot + ystdplot, 'k--');
plot(xplot, yplot - ystdplot, 'k--');
plot(x, y, 'k+');
axis([-12 12 -1 1.1]);
plot(xplot, .8*pplot/max(max(pplot)) - 1, 'k')
hold off