comparison toolboxes/FullBNT-1.0.7/netlab3.3/demgtm2.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:e9a9cd732c1e
1 %DEMGTM2 Demonstrate GTM for visualisation.
2 %
3 % Description
4 % This script demonstrates the use of a GTM with a two-dimensional
5 % latent space to visualise data in a higher dimensional space. This is
6 % done through the use of the mean responsibility and magnification
7 % factors.
8 %
9 % See also
10 % DEMGTM1, GTM, GTMEM, GTMPOST
11 %
12
13 % Copyright (c) Ian T Nabney (1996-2001)
14
15
% Fix seeds for reproducible results. The legacy 'state' seeding syntax is
% kept on purpose: replacing it with rng() would change the random stream
% and therefore the sampled data set shown by this demo.
rand('state', 420);
randn('state', 420);

% Number of points sampled from the mixture. Terminated with a semicolon
% so the assignment is not echoed to the command window.
ndata = 300;
clc;
disp('This demonstration shows how a Generative Topographic Mapping')
disp('can be used to model and visualise high dimensional data. The')
disp('data is generated from a mixture of two spherical Gaussians in')
dstring = ['four dimensional space. ', num2str(ndata), ...
  ' data points are generated.'];
disp(dstring);
disp(' ');
disp('Press any key to continue.')
pause
% Create data: build a two-component spherical Gaussian mixture in
% data_dim dimensions and draw ndata samples from it.
latent_dim = 2;   % Dimensionality of the GTM latent space
data_dim = 4;     % Dimensionality of the observed data

mix = gmm(data_dim, 2, 'spherical');
% Equal mixing weights and the same spherical covariance parameter for
% both components; component 1 is centred on the all-ones vector and
% component 2 on the origin.
mix.priors = [0.5, 0.5];
mix.covars = [0.1, 0.1];
mix.centres = [ones(1, data_dim); zeros(1, data_dim)];

% labels identifies the generating component of each sample; it is used
% later to colour the two classes.
[data, labels] = gmmsamp(mix, ndata);
40
% Sizes of the latent grid and of the RBF mapping
latent_shape = [15 15];        % Number of latent points in each dimension
nlatent = prod(latent_shape);  % Number of latent points
num_rbf_centres = 16;          % Number of RBF centres in the mapping

clc;
% Describe the model about to be built. The counts are integers, so
% '%d' renders them exactly as num2str would.
dstring = sprintf('Next we generate and initialise the GTM. There are %d latent points', ...
  nlatent);
disp(dstring);
dstring = sprintf(['arranged in a square of %d points on a side. ', ...
  'There are %d centres in the'], latent_shape(1), num_rbf_centres);
disp(dstring);
disp('RBF model, which has Gaussian activation functions.')
disp(' ')
disp('Once the model is created, the latent data sample')
disp('and RBF centres are placed uniformly in the square [-1 1 -1 1].')
disp('The output weights of the RBF are computed to map the latent');
disp('space to the two dimensional PCA subspace of the data.');
disp(' ')
disp('Press any key to continue.');
pause;
62
% Create and initialise GTM model. The final gtm argument (0.1) is
% presumably the prior/regularisation on the RBF weights -- confirm in gtm.m.
net = gtm(latent_dim, nlatent, data_dim, num_rbf_centres, ...
  'gaussian', 0.1);

% The RBF centres are initialised on a regular square grid. Derive the
% grid's side length from num_rbf_centres instead of hard-coding [4 4],
% so the grid always matches the number of centres in the model.
rbf_side = round(sqrt(num_rbf_centres));
if rbf_side^2 ~= num_rbf_centres
  error('num_rbf_centres must be a perfect square for a regular RBF grid');
end

options = foptions;
options(1) = -1;   % Display flag; NOTE(review): negative value presumably suppresses warnings -- confirm
options(7) = 1;    % Set width factor of RBF
net = gtminit(net, options, data, 'regular', latent_shape, ...
  [rbf_side rbf_side]);

% Fresh options vector for EM training
options = foptions;
options(14) = 30;  % Number of EM iterations (reported to the user below)
options(1) = 1;    % Display flag; presumably prints progress during training -- confirm

clc;
dstring = ['We now train the model with ', num2str(options(14)), ...
  ' iterations of'];
disp(dstring)
disp('the EM algorithm for the GTM.')
disp(' ')
disp('Press any key to continue.')
pause;

[net, options] = gtmem(net, data, options);

disp(' ')
disp('Press any key to continue.')
pause;
90
clc;
disp('We now visualise the data by plotting, for each data point,');
disp('the posterior mean and mode (in latent space). These give');
disp('a summary of the entire posterior distribution in latent space.')
disp('The corresponding values are joined by a line to aid the')
disp('interpretation.')
disp(' ')
disp('Press any key to continue.');
pause;
% Posterior means and modes of every data point in latent space
means = gtmlmean(net, data);
modes = gtmlmode(net, data);
PointSize = 12;
ClassSymbol1 = 'r.';
ClassSymbol2 = 'b.';
fh1 = figure;
hold on;
title('Visualisation in latent space')
% Means as dots, coloured by generating component
plot(means(labels==1,1), means(labels==1,2), ...
  ClassSymbol1, 'MarkerSize', PointSize)
plot(means(labels>1,1), means(labels>1,2), ...
  ClassSymbol2, 'MarkerSize', PointSize)

% Modes as circles, coloured by generating component
ClassSymbol1 = 'ro';
ClassSymbol2 = 'bo';
plot(modes(labels==1,1), modes(labels==1,2), ClassSymbol1)
plot(modes(labels>1,1), modes(labels>1,2), ClassSymbol2)

% Join each mean to its mode. With 2-by-N coordinate matrices, plot draws
% one segment per column in a single call, instead of N separate calls.
plot([means(:,1) modes(:,1)]', [means(:,2) modes(:,2)]', 'g-')

% Place legend outside the data plot. The numeric position argument (-1)
% is no longer accepted by current MATLAB; the 'Location' option is the
% supported equivalent. The four labels go to the first four plotted
% objects (the mean and mode markers), not to the joining segments.
legend('Mean (class 1)', 'Mean (class 2)', 'Mode (class 1)', ...
  'Mode (class 2)', 'Location', 'NorthEastOutside');
128
% Display posterior for a data point.
% Choose an interesting one with a large distance between mean and mode:
% point is the row of data with the largest squared mean-mode distance.
% (distance itself is not used; it is kept as a named output rather than
% ~ for compatibility with older MATLAB releases.)
[distance, point] = max(sum((means-modes).^2, 2));
resp = gtmpost(net, data(point, :));

disp(' ')
disp('For more detailed information, the full posterior distribution')
disp('(or responsibility) can be plotted in latent space for a')
disp('single data point. This point has been chosen as the one')
disp('with the largest distance between mean and mode.')
disp(' ')
disp('Press any key to continue.');
pause;

% Arrange the responsibilities on the latent grid. NOTE(review): this
% assumes the latent points in net.X are ordered column-major over a
% [rows cols] = fliplr(latent_shape) grid, as set up by gtminit -- confirm.
R = reshape(resp, fliplr(latent_shape));

fh2 = figure;
% imagesc only uses the first and last entries of the x and y vectors,
% so the image spans the extent of the latent grid.
imagesc(net.X(:, 1), net.X(:,2), R);
hold on;
tstr = ['Responsibility for point ', num2str(point)];
title(tstr);
set(gca,'YDir','normal')
colormap(hot);
colorbar
disp(' ');
disp('Press any key to continue.')
pause
159
clc
disp('Finally, we visualise the data with the posterior means in')
disp('latent space as before, but superimpose the magnification')
disp('factors to highlight the separation between clusters.')
disp(' ')
disp('Note the large magnitude factors down the centre of the')
disp('graph, showing that the manifold is stretched more in')
disp('this region than within each of the two clusters.')
ClassSymbol1 = 'g.';
ClassSymbol2 = 'b.';

fh3 = figure;
% Magnification factors evaluated at every latent point
mags = gtmmag(net, net.X);
% Reshape into grid form, laid out like the responsibility image
Mags = reshape(mags, fliplr(latent_shape));
imagesc(net.X(:, 1), net.X(:,2), Mags);
title('Dataset visualisation with magnification factors')
set(gca,'YDir','normal')
colormap(hot);
colorbar
% Hold once, right before the overlays; otherwise plotting the means
% would replace the magnification image.
hold on;
plot(means(labels==1,1), means(labels==1,2), ...
  ClassSymbol1, 'MarkerSize', PointSize)
plot(means(labels>1,1), means(labels>1,2), ...
  ClassSymbol2, 'MarkerSize', PointSize)

disp(' ')
disp('Press any key to exit.')
pause

% Close the three demo figures and clear the workspace
close(fh1);
close(fh2);
close(fh3);
clear all;