comparison toolboxes/FullBNT-1.0.7/netlab3.3/demgtm1.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
%DEMGTM1 Demonstrate EM for GTM.
%
% Description
% This script demonstrates the use of the EM algorithm to fit a one-
% dimensional GTM to a two-dimensional set of data using maximum
% likelihood. The location and spread of the Gaussian kernels in the
% data space are shown during training.
%
% See also
% DEMGTM2, GTM, GTMEM, GTMPOST
%

% Copyright (c) Ian T Nabney (1996-2001)

% Demonstrates the GTM with a 2D target space and a 1D latent space.
%
% This script generates a simple data set in 2 dimensions,
% with an intrinsic dimensionality of 1, and trains a GTM
% with a 1-dimensional latent variable to model this data
% set, visually illustrating the training process.
%
% Synopsis: gtm_demo

% Generate and plot a 2D data set

data_min = 0.15;
data_max = 3.05;
T = [data_min:0.05:data_max]';
T = [T (T + 1.25*sin(2*T))];
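% T has one column per data dimension: column 1 is the underlying 1D
% coordinate and column 2 is a non-linear function of it, so the points
% lie on a one-dimensional curve embedded in two dimensions.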
fh1 = figure;
plot(T(:,1), T(:,2), 'ro');
axis([data_min-0.05 data_max+0.05 data_min-0.05 data_max+0.05]);
clc;
disp('This demonstration shows in detail how the EM algorithm works')
disp('for training a GTM with a one-dimensional latent space.')
disp(' ')
fprintf([...
  'The figure shows data generated by feeding a 1D uniform distribution\n', ...
  '(on the X-axis) through a non-linear function (y = x + 1.25*sin(2*x)).\n', ...
  '\nPress any key to continue ...\n\n']);
pause;

% Generate a unit circle figure, to be used for plotting
src = [0:(2*pi)/(20-1):2*pi]';
unitC = [sin(src) cos(src)];
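% unitC holds 20 points on the unit circle; scaled by the mixture standard
% deviation below, it is used to draw the coverage disc of each component.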

% Generate and plot (along with the data) an initial GTM model

clc;
num_latent_points = 20;
num_rbf_centres = 5;

net = gtm(1, num_latent_points, 2, num_rbf_centres, 'gaussian');
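% The arguments specify a 1D latent space sampled at num_latent_points
% points, a 2D data space, and an RBF mapping with num_rbf_centres
% Gaussian basis functions.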

options = zeros(1, 18);
options(7) = 1;
net = gtminit(net, options, T, 'regular', num_latent_points, ...
    num_rbf_centres);
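% 'regular' places the latent sample and the RBF centres on regular grids
% and initialises the mapping from the principal components of T.
% options(7) here is taken to be the relative width given to the RBF basis
% functions during this initialisation.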

mix = gtmfwd(net);
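% gtmfwd maps the latent sample through the RBF and returns the induced
% Gaussian mixture in data space (centres plus a shared spherical variance).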
% Replot the figure
hold off;
plot(mix.centres(:,1), mix.centres(:,2), 'g');
hold on;
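% Around each mixture centre, draw a filled circle of radius two standard
% deviations to show the (equal) spread of the components.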
for i=1:num_latent_points
  c = 2*unitC*sqrt(mix.covars(1)) + [ones(20,1)*mix.centres(i,1) ...
      ones(20,1)*mix.centres(i,2)];
  fill(c(:,1), c(:,2), [0.8 1 0.8]);
end
plot(T(:,1), T(:,2), 'ro');
plot(mix.centres(:,1), mix.centres(:,2), 'g+');
plot(mix.centres(:,1), mix.centres(:,2), 'g');
axis([data_min-0.05 data_max+0.05 data_min-0.05 data_max+0.05]);
drawnow;
title('Initial configuration');
disp(' ')
fprintf([...
  'The figure shows the starting point for the GTM, before training.\n', ...
  'A discrete latent variable distribution of %d points in 1 dimension\n', ...
  'is mapped to the 1st principal component of the target data by an RBF\n', ...
  'with %d basis functions. Each of the %d points defines the centre of\n', ...
  'a Gaussian in a Gaussian mixture, marked by the green ''+''-signs. The\n', ...
  'mixture components all have equal variance, illustrated by the filled\n', ...
  'circle around each ''+''-sign, the radii corresponding to 2 standard\n', ...
  'deviations. The ''+''-signs are connected with a line according to their\n', ...
  'ordering in latent space.\n\n', ...
  'Press any key to begin training ...\n\n'], num_latent_points, ...
  num_rbf_centres, num_latent_points);
pause;

figure(fh1);
%%%% Train the GTM and plot it (along with the data) as training proceeds %%%%
options = foptions;
options(1) = -1;   % Turn off all warning messages
options(14) = 1;
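% options(14) is the number of EM cycles per call to gtmem, so each pass
% through the loop below runs one EM iteration and then redraws the mixture.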
for j = 1:15
  [net, options] = gtmem(net, T, options);
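  % gtmem runs EM for the GTM: the E-step computes the responsibilities of
  % the latent points for each data point, and the M-step updates the RBF
  % weights and the shared variance of the mixture.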
  hold off;
  mix = gtmfwd(net);
  plot(mix.centres(:,1), mix.centres(:,2), 'g');
  hold on;
  for i=1:20
    c = 2*unitC*sqrt(mix.covars(1)) + [ones(20,1)*mix.centres(i,1) ...
        ones(20,1)*mix.centres(i,2)];
    fill(c(:,1), c(:,2), [0.8 1.0 0.8]);
  end
  plot(T(:,1), T(:,2), 'ro');
  plot(mix.centres(:,1), mix.centres(:,2), 'g+');
  plot(mix.centres(:,1), mix.centres(:,2), 'g');
  axis([0 3.5 0 3.5]);
  title(['After ', int2str(j), ' iterations of training.']);
  drawnow;
  if (j == 4)
    fprintf([...
      'The GTM initially adapts relatively quickly - already after\n', ...
      '4 iterations of training, a rough fit is attained.\n\n', ...
      'Press any key to continue training ...\n\n']);
    pause;
    figure(fh1);
  elseif (j == 8)
    fprintf([...
      'After another 4 iterations of training: from now on, further\n', ...
      'training only makes small changes to the mapping. These, combined with\n', ...
      'decreases in the Gaussian mixture variance, optimize the fit in\n', ...
      'terms of likelihood.\n\n', ...
      'Press any key to continue training ...\n\n']);
    pause;
    figure(fh1);
  else
    pause(1);
  end
end

clc;
fprintf([...
  'After 15 iterations of training, the GTM can be regarded as converged.\n', ...
  'It has been adapted to fit the target data distribution as well\n', ...
  'as possible, given prior smoothness constraints on the mapping. It\n', ...
  'captures the fact that the probability density is higher at the two\n', ...
  'bends of the curve, and lower towards its end points.\n\n']);
disp(' ');
disp('Press any key to exit.');
pause;

close(fh1);
clear all;