%DEMGTM1 Demonstrate EM for GTM.
%
%	Description
%	This script demonstrates the use of the EM algorithm to fit a one-
%	dimensional GTM to a two-dimensional set of data using maximum
%	likelihood. The location and spread of the Gaussian kernels in the
%	data space are shown during training.
%
%	See also
%	DEMGTM2, GTM, GTMEM, GTMPOST
%

%	Copyright (c) Ian T Nabney (1996-2001)

% Demonstrates the GTM with a 2D target space and a 1D latent space.
%
%	This script generates a simple data set in 2 dimensions,
%	with an intrinsic dimensionality of 1, and trains a GTM
%	with a 1-dimensional latent variable to model this data
%	set, visually illustrating the training process.
%
% Synopsis:	demgtm1

% Build the 2D data set: the first column is an evenly spaced grid on
% [data_min, data_max]; the second is x + 1.25*sin(2*x) evaluated on it.
data_min = 0.15;
data_max = 3.05;
T = (data_min:0.05:data_max)';
T = [T, T + 1.25*sin(2*T)];

% Show the data as red circles. Both axes use the same limits: the grid
% range padded by 0.05 on each side.
fh1 = figure;
plot(T(:,1), T(:,2), 'ro');
axis([data_min-0.05, data_max+0.05, data_min-0.05, data_max+0.05]);

% Introduce the demonstration on a cleared command window and wait for a
% key press before moving on.
clc;
disp('This demonstration shows in detail how the EM algorithm works')
disp('for training a GTM with a one dimensional latent space.')
disp(' ')
fprintf('The figure shows data generated by feeding a 1D uniform distribution\n');
fprintf('(on the X-axis) through a non-linear function (y = x + 1.25*sin(2*x))\n');
fprintf('\nPress any key to continue ...\n\n');
pause;

% Outline of the unit circle, sampled at 20 angles running from 0 to 2*pi.
% The plotting code below scales and shifts these points to draw one
% circle per mixture component.
angles = (0:(2*pi)/(20-1):2*pi)';
unitC = [sin(angles), cos(angles)];

% Generate and plot (along with the data) an initial GTM model

clc;
num_latent_points = 20;
num_rbf_centres = 5;

% Create the GTM; per the text printed below it has a 1D latent space with
% num_latent_points points, mapped into the 2D data space by an RBF with
% num_rbf_centres Gaussian basis functions.
net = gtm(1, num_latent_points, 2, num_rbf_centres, 'gaussian');

% Initialise the model from the data.
% NOTE(review): the meaning of options(7) = 1 is defined by gtminit, which
% is not visible here -- confirm against its documentation.
options = zeros(1, 18);
options(7) = 1;
net = gtminit(net, options, T, 'regular', num_latent_points, ...
    num_rbf_centres);

% Gaussian mixture in data space corresponding to the current GTM.
mix = gtmfwd(net);

% Replot the figure
hold off;
plot(mix.centres(:,1), mix.centres(:,2), 'g');
hold on;

% Draw one filled circle per mixture component with a radius of two
% standard deviations. The number of outline points is taken from unitC
% itself: the original mixed a literal 20 with num_latent_points for the
% two columns, which only worked because both happened to equal 20.
num_circle_pts = size(unitC, 1);
kernel_radius = 2*sqrt(mix.covars(1));
for k = 1:num_latent_points
    outline = kernel_radius*unitC + ...
        repmat(mix.centres(k,1:2), num_circle_pts, 1);
    fill(outline(:,1), outline(:,2), [0.8 1 0.8]);
end

% Overlay the data, the component centres ('+') and the line joining the
% centres in latent-space order.
plot(T(:,1), T(:,2), 'ro');
plot(mix.centres(:,1), mix.centres(:,2), 'g+');
plot(mix.centres(:,1), mix.centres(:,2), 'g');
axis([data_min-0.05 data_max+0.05 data_min-0.05 data_max+0.05]);
% Set the title before flushing graphics so the rendered frame includes it
% (same order as in the training loop).
title('Initial configuration');
drawnow;

disp(' ')
fprintf([...
    'The figure shows the starting point for the GTM, before the training.\n', ...
    'A discrete latent variable distribution of %d points in 1 dimension \n', ...
    'is mapped to the 1st principal component of the target data by an RBF\n', ...
    'with %d basis functions. Each of the %d points defines the centre of\n', ...
    'a Gaussian in a Gaussian mixture, marked by the green ''+''-signs. The\n', ...
    'mixture components all have equal variance, illustrated by the filled\n', ...
    'circle around each ''+''-sign, the radii corresponding to 2 standard\n', ...
    'deviations. The ''+''-signs are connected with a line according to their\n', ...
    'corresponding ordering in latent space.\n\n', ...
    'Press any key to begin training ...\n\n'], num_latent_points, ...
    num_rbf_centres, num_latent_points);
pause;

figure(fh1);
%%%% Train the GTM and plot it (along with the data) as training proceeds %%%%
options = foptions;
options(1) = -1;  % Turn off all warning messages
% NOTE(review): options(14) = 1 presumably limits each gtmem call to a
% single EM iteration (the plot title counts one iteration per pass) --
% confirm against gtmem.
options(14) = 1;

% Number of points on the circle outline; derived from unitC rather than
% hard-coded so the code keeps working if the outline resolution changes.
num_circle_pts = size(unitC, 1);

for iter = 1:15
    [net, options] = gtmem(net, T, options);

    % Redraw the current mixture from scratch.
    hold off;
    mix = gtmfwd(net);
    plot(mix.centres(:,1), mix.centres(:,2), 'g');
    hold on;

    % One filled circle (radius: two standard deviations) per mixture
    % component; the count comes from the mixture, not a literal 20.
    kernel_radius = 2*sqrt(mix.covars(1));
    for k = 1:size(mix.centres, 1)
        outline = kernel_radius*unitC + ...
            repmat(mix.centres(k,1:2), num_circle_pts, 1);
        fill(outline(:,1), outline(:,2), [0.8 1.0 0.8]);
    end

    % Overlay the data, the component centres and the connecting line.
    plot(T(:,1), T(:,2), 'ro');
    plot(mix.centres(:,1), mix.centres(:,2), 'g+');
    plot(mix.centres(:,1), mix.centres(:,2), 'g');
    axis([0 3.5 0 3.5]);
    title(['After ', int2str(iter),' iterations of training.']);
    drawnow;

    % Stop for commentary after iterations 4 and 8; otherwise just hold
    % each frame for one second.
    if (iter == 4)
        fprintf([...
            'The GTM initially adapts relatively quickly - already after \n', ...
            '4 iterations of training, a rough fit is attained.\n\n', ...
            'Press any key to continue training ...\n\n']);
        pause;
        figure(fh1);
    elseif (iter == 8)
        fprintf([...
            'After another 4 iterations of training:  from now on further \n', ...
            'training only makes small changes to the mapping, which combined with \n', ...
            'decrements of the Gaussian mixture variance, optimize the fit in \n', ...
            'terms of likelihood.\n\n', ...
            'Press any key to continue training ...\n\n']);
        pause;
        figure(fh1);
    else
        pause(1);
    end
end

% Closing summary (typos in the original text corrected: "Is has" and
% "probabilty").
clc;
fprintf([...
    'After 15 iterations of training the GTM can be regarded as converged. \n', ...
    'It has been adapted to fit the target data distribution as well \n', ...
    'as possible, given prior smoothness constraints on the mapping. It \n', ...
    'captures the fact that the probability density is higher at the two \n', ...
    'bends of the curve, and lower towards its end points.\n\n']);
disp(' ');
disp('Press any key to exit.');
pause;

% The user may already have closed the figure window; closing a stale
% handle would raise an error and skip the cleanup below.
if ishandle(fh1)
    close(fh1);
end
clear all;
