Mercurial > hg > camir-aes2014
comparison toolboxes/FullBNT-1.0.7/netlab3.3/demgtm2.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:e9a9cd732c1e |
---|---|
1 %DEMGTM2 Demonstrate GTM for visualisation. | |
2 % | |
3 % Description | |
4 % This script demonstrates the use of a GTM with a two-dimensional | |
5 % latent space to visualise data in a higher dimensional space. This is | |
6 % done through the use of the mean responsibility and magnification | |
7 % factors. | |
8 % | |
9 % See also | |
10 % DEMGTM1, GTM, GTMEM, GTMPOST | |
11 % | |
12 | |
13 % Copyright (c) Ian T Nabney (1996-2001) | |
14 | |
15 | |
% Fix seeds for reproducible results
rand('state', 420);
randn('state', 420);

% Number of data points to sample.  Terminated with a semicolon so the
% assignment is not echoed to the command window (or a diary log).
ndata = 300;
clc;
disp('This demonstration shows how a Generative Topographic Mapping')
disp('can be used to model and visualise high dimensional data. The')
disp('data is generated from a mixture of two spherical Gaussians in')
dstring = ['four dimensional space. ', num2str(ndata), ...
      ' data points are generated.'];
disp(dstring);
disp(' ');
disp('Press any key to continue.')
pause
% Create data: a two-component spherical Gaussian mixture in data_dim
% dimensions, equal priors, centred at [1 1 1 1] and at the origin, with
% the covars parameter of both components set to 0.1.
data_dim = 4;
latent_dim = 2;
mix = gmm(data_dim, 2, 'spherical');
mix.centres = [1 1 1 1; 0 0 0 0];
mix.priors = [0.5 0.5];
mix.covars = [0.1 0.1];

% labels gives, for each sampled row of data, the index of the mixture
% component it came from; it is used below as the class label.
[data, labels] = gmmsamp(mix, ndata);

% Latent space: a regular square lattice of latent points, mapped into
% data space by an RBF network with num_rbf_centres basis functions.
latent_shape = [15 15];          % latent points along each latent axis
nlatent = prod(latent_shape);    % total number of latent points
num_rbf_centres = 16;

clc;
disp(['Next we generate and initialise the GTM. There are ', ...
      num2str(nlatent), ' latent points']);
disp(['arranged in a square of ', num2str(latent_shape(1)), ...
      ' points on a side. There are ', num2str(num_rbf_centres), ...
      ' centres in the']);
model_text = { ...
  'RBF model, which has Gaussian activation functions.', ...
  ' ', ...
  'Once the model is created, the latent data sample', ...
  'and RBF centres are placed uniformly in the square [-1 1 -1 1].', ...
  'The output weights of the RBF are computed to map the latent', ...
  'space to the two dimensional PCA subspace of the data.', ...
  ' ', ...
  'Press any key to continue.'};
for k = 1:length(model_text)
  disp(model_text{k});
end
pause;

% Build the GTM with Gaussian RBF activations.  The trailing 0.1 is passed
% straight through to gtm (presumably the weight prior coefficient --
% confirm against gtm.m).
net = gtm(latent_dim, nlatent, data_dim, num_rbf_centres, 'gaussian', 0.1);

% Initialise on regular grids: latent_shape for the latent sample and
% [4 4] for the RBF centres (4*4 matches num_rbf_centres -- confirm in
% gtminit.m).
init_options = foptions;
init_options(1) = -1;   % negative display flag: run quietly (netlab convention)
init_options(7) = 1;    % width factor of the RBF basis functions
net = gtminit(net, init_options, data, 'regular', latent_shape, [4 4]);

% Training options for the EM run.
options = foptions;
options(1) = 1;         % display progress during training
options(14) = 30;       % number of EM iterations

clc;
disp(['We now train the model with ', num2str(options(14)), ...
      ' iterations of'])
disp('the EM algorithm for the GTM.')
disp(' ')
disp('Press any key to continue.')
pause;

[net, options] = gtmem(net, data, options);

disp(' ')
disp('Press any key to continue.')
pause;

clc;
disp('We now visualise the data by plotting, for each data point,');
disp('the posterior mean and mode (in latent space). These give');
disp('a summary of the entire posterior distribution in latent space.')
disp('The corresponding values are joined by a line to aid the')
disp('interpretation.')
disp(' ')
disp('Press any key to continue.');
pause;
% Posterior means and modes of every data point in latent space
means = gtmlmean(net, data);
modes = gtmlmode(net, data);

% Class membership masks, computed once and reused for every plot call
class1 = (labels == 1);
class2 = (labels > 1);

PointSize = 12;
fh1 = figure;
hold on;
title('Visualisation in latent space')
% Means are dots and modes are circles; red = class 1, blue = class 2.
% The creation order of these four plots must match the legend below.
plot(means(class1,1), means(class1,2), 'r.', 'MarkerSize', PointSize)
plot(means(class2,1), means(class2,2), 'b.', 'MarkerSize', PointSize)
plot(modes(class1,1), modes(class1,2), 'ro')
plot(modes(class2,1), modes(class2,2), 'bo')

% Join up means and modes.  With 2-by-ndata coordinate matrices, plot
% draws one two-point segment per column, i.e. one per data point.
plot([means(:,1)'; modes(:,1)'], [means(:,2)'; modes(:,2)'], 'g-')

% Place legend outside data plot.  The legacy numeric position -1 is
% obsolete; 'NorthEastOutside' is its named equivalent.
legend('Mean (class 1)', 'Mean (class 2)', 'Mode (class 1)', ...
  'Mode (class 2)', 'Location', 'NorthEastOutside');

% Display posterior for a data point
% Choose an interesting one: the point with the largest squared distance
% between its posterior mean and mode.
[distance, point] = max(sum((means-modes).^2, 2));
resp = gtmpost(net, data(point, :));

disp(' ')
disp('For more detailed information, the full posterior distribution')
disp('(or responsibility) can be plotted in latent space for a')
disp('single data point. This point has been chosen as the one')
disp('with the largest distance between mean and mode.')
disp(' ')
disp('Press any key to continue.');
pause;

% Lay the responsibilities out on the latent grid for display
R = reshape(resp, fliplr(latent_shape));

fh2 = figure;
% imagesc uses only the first and last elements of each coordinate
% vector to set the extent of the image.
imagesc(net.X(:, 1), net.X(:,2), R);
hold on;
tstr = ['Responsibility for point ', num2str(point)];
title(tstr);
set(gca,'YDir','normal')   % increasing latent y drawn upwards
colormap(hot);
colorbar
disp(' ');
disp('Press any key to continue.')
pause

clc
closing_text = { ...
  'Finally, we visualise the data with the posterior means in', ...
  'latent space as before, but superimpose the magnification', ...
  'factors to highlight the separation between clusters.', ...
  ' ', ...
  'Note the large magnitude factors down the centre of the', ...
  'graph, showing that the manifold is stretched more in', ...
  'this region than within each of the two clusters.'};
for k = 1:length(closing_text)
  disp(closing_text{k})
end
in_class1 = (labels == 1);
in_class2 = (labels > 1);

fh3 = figure;
% Magnification factor at every latent point, arranged on the latent grid
mag_grid = reshape(gtmmag(net, net.X), fliplr(latent_shape));
imagesc(net.X(:, 1), net.X(:,2), mag_grid);
hold on
title('Dataset visualisation with magnification factors')
set(gca,'YDir','normal')
colormap(hot);
colorbar
hold on; % Else the magnification plot disappears
% Overlay the posterior means: green = class 1, blue = class 2
plot(means(in_class1,1), means(in_class1,2), 'g.', 'MarkerSize', PointSize)
plot(means(in_class2,1), means(in_class2,2), 'b.', 'MarkerSize', PointSize)

disp(' ')
disp('Press any key to exit.')
pause

% Tidy up: close the three demo figures (in creation order), then clear
% the workspace.
close([fh1, fh2, fh3]);
clear all;