wolffd@0
|
1 %DEMGTM2 Demonstrate GTM for visualisation.
|
wolffd@0
|
2 %
|
wolffd@0
|
3 % Description
|
wolffd@0
|
4 % This script demonstrates the use of a GTM with a two-dimensional
|
wolffd@0
|
5 % latent space to visualise data in a higher dimensional space. This is
|
wolffd@0
|
6 % done through the use of the mean responsibility and magnification
|
wolffd@0
|
7 % factors.
|
wolffd@0
|
8 %
|
wolffd@0
|
9 % See also
|
wolffd@0
|
10 % DEMGTM1, GTM, GTMEM, GTMPOST
|
wolffd@0
|
11 %
|
wolffd@0
|
12
|
wolffd@0
|
13 % Copyright (c) Ian T Nabney (1996-2001)
|
wolffd@0
|
14
|
wolffd@0
|
15
|
% Fix seeds for reproducible results.
% The legacy 'state' syntax is kept unchanged so that the uniform and
% normal random streams produced by this script stay exactly as before.
rand('state', 420);
randn('state', 420);

% Number of data points to generate. The statement is terminated with a
% semicolon so that the value is not echoed to the command window.
ndata = 300;
% Introductory message: clear the command window, describe the
% demonstration and the data set, then wait for a key press.
clc;
intro_text = { ...
    'This demonstration shows how a Generative Topographic Mapping', ...
    'can be used to model and visualise high dimensional data. The', ...
    'data is generated from a mixture of two spherical Gaussians in', ...
    ['four dimensional space. ', num2str(ndata), ' data points are generated.'], ...
    ' ', ...
    'Press any key to continue.'};
for text_idx = 1:length(intro_text)
    disp(intro_text{text_idx});
end
pause
% Create data
% The data come from a two-component spherical Gaussian mixture in
% data_dim dimensions with equal priors and covars of 0.1.
data_dim = 4;
latent_dim = 2;
mix = gmm(data_dim, 2, 'spherical');
% Build the centres from data_dim (first row all ones, second row all
% zeros) so that their width always agrees with the dimension given to
% gmm above instead of being hard-coded to four columns.
mix.centres = [ones(1, data_dim); zeros(1, data_dim)];
mix.priors = [0.5 0.5];
mix.covars = [0.1 0.1];

% Sample ndata points; labels is used further down to tell the two
% classes apart when plotting.
[data, labels] = gmmsamp(mix, ndata);

latent_shape = [15 15];  % Number of latent points in each dimension
nlatent = prod(latent_shape);  % Number of latent points
num_rbf_centres = 16;
wolffd@0
|
44
|
% Message describing how the GTM is generated and initialised; the
% numbers quoted are taken from nlatent, latent_shape and num_rbf_centres.
clc;
setup_text = { ...
    ['Next we generate and initialise the GTM. There are ', ...
        num2str(nlatent), ' latent points'], ...
    ['arranged in a square of ', num2str(latent_shape(1)), ...
        ' points on a side. There are ', num2str(num_rbf_centres), ...
        ' centres in the'], ...
    'RBF model, which has Gaussian activation functions.', ...
    ' ', ...
    'Once the model is created, the latent data sample', ...
    'and RBF centres are placed uniformly in the square [-1 1 -1 1].', ...
    'The output weights of the RBF are computed to map the latent', ...
    'space to the two dimensional PCA subspace of the data.', ...
    ' ', ...
    'Press any key to continue.'};
for text_idx = 1:length(setup_text)
    disp(setup_text{text_idx});
end
pause;
wolffd@0
|
62
|
% Create and initialise GTM model
net = gtm(latent_dim, nlatent, data_dim, num_rbf_centres, 'gaussian', 0.1);

% Options vector used only for the initialisation step:
% element 1 is set to -1 and element 7 (width factor of RBF) to 1.
init_options = foptions;
init_options([1 7]) = [-1 1];
net = gtminit(net, init_options, data, 'regular', latent_shape, [4 4]);

% Fresh options vector for training: element 1 is set to 1 and
% element 14 to 30.
options = foptions;
options([1 14]) = [1 30];
wolffd@0
|
75
|
% Announce the training run (the iteration count shown is options(14)),
% wait for a key press, then train the GTM with the EM algorithm.
clc;
train_text = { ...
    ['We now train the model with ', num2str(options(14)), ' iterations of'], ...
    'the EM algorithm for the GTM.', ...
    ' ', ...
    'Press any key to continue.'};
for text_idx = 1:length(train_text)
    disp(train_text{text_idx});
end
pause;

[net, options] = gtmem(net, data, options);

% Pause again once training has finished.
disp(' ')
disp('Press any key to continue.')
pause;
wolffd@0
|
90
|
% Message explaining the mean/mode plot that is drawn next.
clc;
vis_text = { ...
    'We now visualise the data by plotting, for each data point,', ...
    'the posterior mean and mode (in latent space). These give', ...
    'a summary of the entire posterior distribution in latent space.', ...
    'The corresponding values are joined by a line to aid the', ...
    'interpretation.', ...
    ' ', ...
    'Press any key to continue.'};
for text_idx = 1:length(vis_text)
    disp(vis_text{text_idx});
end
pause;
% Plot posterior means and modes
% For every data point, obtain the mean and the mode of its posterior
% distribution in latent space.
means = gtmlmean(net, data);
modes = gtmlmode(net, data);
PointSize = 12;

% Row selectors for the two classes (label 1 versus any larger label).
in_class1 = (labels == 1);
in_class2 = (labels > 1);

fh1 = figure;
hold on;
title('Visualisation in latent space')

% Posterior means are drawn as dots.
ClassSymbol1 = 'r.';
ClassSymbol2 = 'b.';
plot(means(in_class1, 1), means(in_class1, 2), ...
    ClassSymbol1, 'MarkerSize', PointSize)
plot(means(in_class2, 1), means(in_class2, 2), ...
    ClassSymbol2, 'MarkerSize', PointSize)

% Posterior modes are drawn as circles.
ClassSymbol1 = 'ro';
ClassSymbol2 = 'bo';
plot(modes(in_class1, 1), modes(in_class1, 2), ClassSymbol1)
plot(modes(in_class2, 1), modes(in_class2, 2), ClassSymbol2)

% Join up means and modes.
% Each column of the two 2-by-N matrices holds the end points of one
% segment, so a single plot call draws one green line per data point.
plot([means(:, 1), modes(:, 1)]', [means(:, 2), modes(:, 2)]', 'g-')

% Place legend outside data plot.
% The 'Location' option replaces the obsolete numeric position code -1.
legend('Mean (class 1)', 'Mean (class 2)', 'Mode (class 1)', ...
    'Mode (class 2)', 'Location', 'NorthEastOutside');
wolffd@0
|
128
|
% Display posterior for a data point.
% Choose an interesting one: the point whose posterior mean and mode are
% furthest apart (largest squared distance in latent space).
sq_mean_mode_dist = sum((means - modes).^2, 2);
[max_sq_dist, point] = max(sq_mean_mode_dist);
resp = gtmpost(net, data(point, :));

point_text = { ...
    ' ', ...
    'For more detailed information, the full posterior distribution', ...
    '(or responsibility) can be plotted in latent space for a', ...
    'single data point. This point has been chosen as the one', ...
    'with the largest distance between mean and mode.', ...
    ' ', ...
    'Press any key to continue.'};
for text_idx = 1:length(point_text)
    disp(point_text{text_idx});
end
pause;
wolffd@0
|
143
|
% Arrange the responsibilities of the chosen point on the latent grid so
% that they can be shown as an image. (The unused XL/YL coordinate grids
% that used to be computed here have been dropped.)
R = reshape(resp, fliplr(latent_shape));

fh2 = figure;
imagesc(net.X(:, 1), net.X(:,2), R);
hold on;
tstr = ['Responsibility for point ', num2str(point)];
title(tstr);
% Draw the image with the y axis increasing upwards.
set(gca,'YDir','normal')
colormap(hot);
colorbar
disp(' ');
disp('Press any key to continue.')
pause
wolffd@0
|
159
|
% Final figure: posterior means drawn on top of the magnification factors.
clc
disp('Finally, we visualise the data with the posterior means in')
disp('latent space as before, but superimpose the magnification')
disp('factors to highlight the separation between clusters.')
disp(' ')
% The quantity plotted is the magnification factor, so the message uses
% that term (it previously read "magnitude factors").
disp('Note the large magnification factors down the centre of the')
disp('graph, showing that the manifold is stretched more in')
disp('this region than within each of the two clusters.')
ClassSymbol1 = 'g.';
ClassSymbol2 = 'b.';

fh3 = figure;
% Magnification factors evaluated at the latent sample points.
mags = gtmmag(net, net.X);
% Reshape into grid form
Mags = reshape(mags, fliplr(latent_shape));
imagesc(net.X(:, 1), net.X(:,2), Mags);
hold on;  % Else the magnification plot disappears
title('Dataset visualisation with magnification factors')
set(gca,'YDir','normal')
colormap(hot);
colorbar
plot(means(labels==1,1), means(labels==1,2), ...
    ClassSymbol1, 'MarkerSize', PointSize)
plot(means(labels>1,1), means(labels>1,2), ...
    ClassSymbol2, 'MarkerSize', PointSize)
wolffd@0
|
186
|
disp(' ')
disp('Press any key to exit.')
pause

% Close the three demo figures. A window may already have been closed by
% hand during the demo; calling close on such a handle raises an error
% and would skip the rest of the clean-up, so only handles that are still
% valid are closed.
for demo_fig = [fh1 fh2 fh3]
    if ishandle(demo_fig)
        close(demo_fig);
    end
end
clear all;