Mercurial > hg > camir-aes2014
comparison toolboxes/FullBNT-1.0.7/netlab3.3/demgtm1.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:e9a9cd732c1e |
---|---|
1 %DEMGTM1 Demonstrate EM for GTM. | |
2 % | |
3 % Description | |
4 % This script demonstrates the use of the EM algorithm to fit a one- | |
5 % dimensional GTM to a two-dimensional set of data using maximum | |
6 % likelihood. The location and spread of the Gaussian kernels in the | |
7 % data space are shown during training. | |
8 % | |
9 % See also | |
10 % DEMGTM2, GTM, GTMEM, GTMPOST | |
11 % | |
12 | |
13 % Copyright (c) Ian T Nabney (1996-2001) | |
14 | |
15 % Demonstrates the GTM with a 2D target space and a 1D latent space. | |
16 % | |
17 % This script generates a simple data set in 2 dimensions, | |
18 % with an intrinsic dimensionality of 1, and trains a GTM | |
19 % with a 1-dimensional latent variable to model this data | |
20 % set, visually illustrating the training process. | |
21 % | |
22 % Synopsis: demgtm1 | |
23 | |
% Generate and plot a 2D data set
%
% The X-axis is sampled on a regular grid with step 0.05 and then bent
% through y = x + 1.25*sin(2*x), giving 2D data whose intrinsic
% dimensionality is 1. Each row of T is one (x, y) data point.
data_min = 0.15;
data_max = 3.05;
T = (data_min:0.05:data_max).';
T(:, 2) = T(:, 1) + 1.25*sin(2*T(:, 1));

fh1 = figure;
plot(T(:, 1), T(:, 2), 'ro');
% Both axes share the same limits: the sample range padded by one grid step.
axis([data_min-0.05 data_max+0.05 data_min-0.05 data_max+0.05]);

clc;
disp('This demonstration shows in detail how the EM algorithm works')
disp('for training a GTM with a one dimensional latent space.')
disp(' ')
fprintf([...
  'The figure shows data generated by feeding a 1D uniform distribution\n', ...
  '(on the X-axis) through a non-linear function (y = x + 1.25*sin(2*x))\n', ...
  '\nPress any key to continue ...\n\n']);
pause;
42 | |
% Generate a unit circle outline, used to draw the Gaussian "spread"
% circles. It has 20 points from 0 to 2*pi inclusive, so the first and last
% points coincide and the polygon passed to fill() is closed. linspace fixes
% the point count exactly; a colon range with a fractional step would rely
% on floating-point rounding to land on the end point.
% Rows of unitC are [sin(t) cos(t)] points on the unit circle.
src = linspace(0, 2*pi, 20)';
unitC = [sin(src) cos(src)];
46 | |
% Generate and plot (along with the data) an initial GTM model

clc;
num_latent_points = 20;
num_rbf_centres = 5;

% 1-dimensional latent space, 2-dimensional data space, Gaussian RBF mapping.
net = gtm(1, num_latent_points, 2, num_rbf_centres, 'gaussian');

options = zeros(1, 18);
% NOTE(review): options(7) is consumed by gtminit; presumably it is the RBF
% basis-function width scale factor -- confirm against gtminit.
options(7) = 1;
net = gtminit(net, options, T, 'regular', num_latent_points, ...
  num_rbf_centres);

mix = gtmfwd(net);
% Replot the figure
hold off;
plot(mix.centres(:,1), mix.centres(:,2), 'g');
hold on;
% Draw a filled circle with a radius of 2 standard deviations around every
% mixture centre. The offset matrix needs one row per point of the circle
% outline (size(unitC, 1)). That count is independent of the number of
% latent points, which instead bounds the loop over mixture centres.
num_circle_points = size(unitC, 1);
for i = 1:size(mix.centres, 1)
  c = 2*unitC*sqrt(mix.covars(1)) + ...
    ones(num_circle_points, 1)*mix.centres(i, :);
  fill(c(:,1), c(:,2), [0.8 1 0.8]);
end
% Redraw the data and centres on top of the filled circles.
plot(T(:,1), T(:,2), 'ro');
plot(mix.centres(:,1), mix.centres(:,2), 'g+');
plot(mix.centres(:,1), mix.centres(:,2), 'g');
axis([data_min-0.05 data_max+0.05 data_min-0.05 data_max+0.05]);
% Set the title before flushing graphics so both appear together.
title('Initial configuration');
drawnow;
disp(' ')
fprintf([...
  'The figure shows the starting point for the GTM, before the training.\n', ...
  'A discrete latent variable distribution of %d points in 1 dimension \n', ...
  'is mapped to the 1st principal component of the target data by an RBF\n', ...
  'with %d basis functions. Each of the %d points defines the centre of\n', ...
  'a Gaussian in a Gaussian mixture, marked by the green ''+''-signs. The\n', ...
  'mixture components all have equal variance, illustrated by the filled\n', ...
  'circle around each ''+''-sign, the radii corresponding to 2 standard\n', ...
  'deviations. The ''+''-signs are connected with a line according to their\n', ...
  'corresponding ordering in latent space.\n\n', ...
  'Press any key to begin training ...\n\n'], num_latent_points, ...
  num_rbf_centres, num_latent_points);
pause;
90 | |
figure(fh1);
%%%% Train the GTM and plot it (along with the data) as training proceeds %%%%
options = foptions;
options(1) = -1; % Turn off all warning messages
% NOTE(review): presumably options(14) is the number of EM iterations per
% gtmem call (netlab foptions convention), so each pass of the loop below is
% one EM step, matching the "After j iterations" titles -- confirm.
options(14) = 1;
% One offset row per point of the circle outline, not per latent point.
num_circle_points = size(unitC, 1);
for j = 1:15
  [net, options] = gtmem(net, T, options);
  hold off;
  mix = gtmfwd(net);
  plot(mix.centres(:,1), mix.centres(:,2), 'g');
  hold on;
  % One filled 2-standard-deviation circle per mixture centre. The loop
  % bound follows the model rather than a hard-coded 20.
  for i = 1:size(mix.centres, 1)
    c = 2*unitC*sqrt(mix.covars(1)) + ...
      ones(num_circle_points, 1)*mix.centres(i, :);
    fill(c(:,1), c(:,2), [0.8 1.0 0.8]);
  end
  plot(T(:,1), T(:,2), 'ro');
  plot(mix.centres(:,1), mix.centres(:,2), 'g+');
  plot(mix.centres(:,1), mix.centres(:,2), 'g');
  % Use the same view as the data and initial-configuration plots so the
  % axes do not jump between frames.
  axis([data_min-0.05 data_max+0.05 data_min-0.05 data_max+0.05]);
  title(['After ', int2str(j),' iterations of training.']);
  drawnow;
  if (j == 4)
    fprintf([...
      'The GTM initially adapts relatively quickly - already after \n', ...
      '4 iterations of training, a rough fit is attained.\n\n', ...
      'Press any key to continue training ...\n\n']);
    pause;
    figure(fh1);
  elseif (j == 8)
    fprintf([...
      'After another 4 iterations of training: from now on further \n', ...
      'training only makes small changes to the mapping, which combined with \n', ...
      'decrements of the Gaussian mixture variance, optimize the fit in \n', ...
      'terms of likelihood.\n\n', ...
      'Press any key to continue training ...\n\n']);
    pause;
    figure(fh1);
  else
    % Short pause so each intermediate frame is visible.
    pause(1);
  end
end
133 | |
clc;
fprintf([...
  'After 15 iterations of training the GTM can be regarded as converged. \n', ...
  'It has been adapted to fit the target data distribution as well \n', ...
  'as possible, given prior smoothness constraints on the mapping. It \n', ...
  'captures the fact that the probability density is higher at the two \n', ...
  'bends of the curve, and lower towards its end points.\n\n']);
disp(' ');
disp('Press any key to exit.');
pause;

% The user may already have closed the demo figure during one of the
% pauses; only close it if the handle is still valid.
if ishandle(fh1)
  close(fh1);
end
% NOTE: "clear all" removes every variable in the workspace, including any
% the user had defined before running the demo, not just the demo's own.
clear all;
147 |