%DEMMLP2 Demonstrate simple classification using a multi-layer perceptron
%
%	Description
%	The problem consists of input data in two dimensions drawn from a
%	mixture of three Gaussians: two of which are assigned to a single
%	class. An MLP with logistic outputs trained with a quasi-Newton
%	optimisation algorithm is compared with the optimal Bayesian decision
%	rule.
%
%	See also
%	MLP, MLPFWD, NETERR, QUASINEW
%

%	Copyright (c) Ian T Nabney (1996-2001)

% Set up some figure parameters
AxisShift = 0.05;      % horizontal nudge applied to the left subplot of figure 1
ClassSymbol1 = 'r.';   % plot symbol for class 1 (red)
ClassSymbol2 = 'y.';   % plot symbol for class 2 (yellow)
PointSize = 12;
titleSize = 10;

% Fix the seeds so the demo is reproducible from run to run
rand('state', 423);
randn('state', 423);

clc
disp('This demonstration shows how an MLP with logistic outputs')
disp('and cross entropy error function can be trained to model the')
disp('posterior class probabilities in a classification problem.')
disp('The results are compared with the optimal Bayes rule classifier,')
disp('which can be computed exactly as we know the form of the generating')
disp('distribution.')
disp(' ')
disp('Press any key to continue.')
pause

fh1 = figure;
set(fh1, 'Name', 'True Data Distribution');
whitebg(fh1, 'k');

%
% Generate the data
%
n=200;

% Set up mixture model: 2d data with three centres
% Class 1 is first centre, class 2 from the other two
mix = gmm(2, 3, 'full');
mix.priors = [0.5 0.25 0.25];
mix.centres = [0 -0.1; 1 1; 1 -1];
mix.covars(:,:,1) = [0.625 -0.2165; -0.2165 0.875];
mix.covars(:,:,2) = [0.2241 -0.1368; -0.1368 0.9759];
mix.covars(:,:,3) = [0.2375 0.1516; 0.1516 0.4125];

% Sample n points; label(i) in {1,2,3} records the generating component
[data, label] = gmmsamp(mix, n);
%
% Calculate some useful axis limits
%
x0 = min(data(:,1));
x1 = max(data(:,1));
y0 = min(data(:,2));
y1 = max(data(:,2));
dx = x1-x0;
dy = y1-y0;
expand = 5/100;        % Add on 5 percent each way
x0 = x0 - dx*expand;
x1 = x1 + dx*expand;
y0 = y0 - dy*expand;
y1 = y1 + dy*expand;
resolution = 100;
step = dx/resolution;  % NOTE: the same step is used for both axes
xrange = [x0:step:x1];
yrange = [y0:step:y1];
%
% Generate the grid
%
[X Y]=meshgrid([x0:step:x1],[y0:step:y1]);
%
% Calculate the class conditional densities, the unconditional densities and
% the posterior probabilities
%
px_j = gmmactiv(mix, [X(:) Y(:)]);          % p(x|j) for each component j
px = reshape(px_j*(mix.priors)',size(X));   % unconditional density p(x)
post = gmmpost(mix, [X(:) Y(:)]);           % component posteriors p(j|x)
p1_x = reshape(post(:, 1), size(X));                % p(class red|x)
p2_x = reshape(post(:, 2) + post(:, 3), size(X));   % p(class yellow|x)

%
% Generate some pretty pictures !!
%
colormap(hot)
colorbar
subplot(1,2,1)
hold on
plot(data((label==1),1),data(label==1,2),ClassSymbol1, 'MarkerSize', PointSize)
plot(data((label>1),1),data(label>1,2),ClassSymbol2, 'MarkerSize', PointSize)
% Optimal Bayes decision boundary: the p(red|x) = 0.5 contour
contour(xrange,yrange,p1_x,[0.5 0.5],'w-');
axis([x0 x1 y0 y1])
set(gca,'Box','On')
title('The Sampled Data');
% Shift the axes left and widen them to make room for the colorbar
rect=get(gca,'Position');
rect(1)=rect(1)-AxisShift;
rect(3)=rect(3)+AxisShift;
set(gca,'Position',rect)
hold off

subplot(1,2,2)
imagesc(X(:),Y(:),px);
hold on
[cB, hB] = contour(xrange,yrange,p1_x,[0.5 0.5],'w:');
set(hB,'LineWidth', 2);
axis([x0 x1 y0 y1])
set(gca,'YDir','normal')    % imagesc flips the y axis by default; undo it
title('Probability Density p(x)')
hold off

drawnow;
clc;
disp('The first figure shows the data sampled from a mixture of three')
disp('Gaussians, the first of which (whose centre is near the origin) is')
disp('labelled red and the other two are labelled yellow. The second plot')
disp('shows the unconditional density of the data with the optimal Bayesian')
disp('decision boundary superimposed.')
disp(' ')
disp('Press any key to continue.')
pause
fh2 = figure;
set(fh2, 'Name', 'Class-conditional Densities and Posterior Probabilities');
whitebg(fh2, 'w');

% p(x|red): class 1 is the first mixture component alone
subplot(2,2,1)
p1=reshape(px_j(:,1),size(X));
imagesc(X(:),Y(:),p1);
colormap hot
colorbar
axis(axis)
set(gca,'YDir','normal')
hold on
plot(mix.centres(:,1),mix.centres(:,2),'b+','MarkerSize',8,'LineWidth',2)
title('Density p(x|red)')
hold off

% p(x|yellow): class 2 pools components 2 and 3
subplot(2,2,2)
p2=reshape((px_j(:,2)+px_j(:,3)),size(X));
imagesc(X(:),Y(:),p2);
colorbar
set(gca,'YDir','normal')
hold on
plot(mix.centres(:,1),mix.centres(:,2),'b+','MarkerSize',8,'LineWidth',2)
title('Density p(x|yellow)')
hold off

% Posterior probability of the red class
subplot(2,2,3)
imagesc(X(:),Y(:),p1_x);
set(gca,'YDir','normal')
colorbar
title('Posterior Probability p(red|x)')
hold on
plot(mix.centres(:,1),mix.centres(:,2),'b+','MarkerSize',8,'LineWidth',2)
hold off

% Posterior probability of the yellow class
subplot(2,2,4)
imagesc(X(:),Y(:),p2_x);
set(gca,'YDir','normal')
colorbar
title('Posterior Probability p(yellow|x)')
hold on
plot(mix.centres(:,1),mix.centres(:,2),'b+','MarkerSize',8,'LineWidth',2)
hold off
% Now set up and train the MLP
nhidden=6;
nout=1;              % single logistic output models p(red|x)
alpha = 0.2;         % Weight decay
ncycles = 60;        % Number of training cycles.
% Set up MLP network
net = mlp(2, nhidden, nout, 'logistic', alpha);
options = zeros(1,18);
options(1) = 1;      % Print out error values
options(14) = ncycles;

mlpstring = ['We now set up an MLP with ', num2str(nhidden), ...
  ' hidden units, logistic output and cross'];
trainstring = ['entropy error function, and train it for ', ...
  num2str(ncycles), ' cycles using the'];
wdstring = ['quasi-Newton optimisation algorithm with weight decay of ', ...
  num2str(alpha), '.'];

% Force out the figure before training the MLP
drawnow;
disp(' ')
disp('The second figure shows the class conditional densities and posterior')
disp('probabilities for each class. The blue crosses mark the centres of')
disp('the three Gaussians.')
disp(' ')
disp(mlpstring)
disp(trainstring)
disp(wdstring)
disp(' ')
disp('Press any key to continue.')
pause
% Convert targets to 0-1 encoding: 1 for the red class, 0 otherwise
target=[label==1];

% Train using quasi-Newton.
[net] = netopt(net, options, data, target, 'quasinew');
y = mlpfwd(net, data);           % network output on the training data
yg = mlpfwd(net, [X(:) Y(:)]);   % network output over the whole grid
yg = reshape(yg(:,1),size(X));

fh3 = figure;
set(fh3, 'Name', 'Network Output');
whitebg(fh3, 'k')
subplot(1, 2, 1)
hold on
plot(data((label==1),1),data(label==1,2),'r.', 'MarkerSize', PointSize)
plot(data((label>1),1),data(label>1,2),'y.', 'MarkerSize', PointSize)
% Bayesian decision boundary
[cB, hB] = contour(xrange,yrange,p1_x,[0.5 0.5],'b-');
% Network decision boundary: where the network output crosses 0.5
[cN, hN] = contour(xrange,yrange,yg,[0.5 0.5],'r-');
set(hB, 'LineWidth', 2);
set(hN, 'LineWidth', 2);
Chandles = [hB(1) hN(1)];
legend(Chandles, 'Bayes', ...
  'Network', 3);

axis([x0 x1 y0 y1])
set(gca,'Box','on','XTick',[],'YTick',[])

title('Training Data','FontSize',titleSize);
hold off

subplot(1, 2, 2)
imagesc(X(:),Y(:),yg);
colormap hot
colorbar
axis(axis)
set(gca,'YDir','normal','XTick',[],'YTick',[])
title('Network Output','FontSize',titleSize)

clc
disp('This figure shows the training data with the decision boundary')
disp('produced by the trained network and the network''s prediction of')
disp('the posterior probability of the red class.')
disp(' ')
disp('Press any key to continue.')
pause
%
% Now generate and classify a test data set
%
[testdata testlabel] = gmmsamp(mix, n);
% 1-of-2 target encoding: column 1 is red, column 2 is yellow
testlab=[testlabel==1 testlabel>1];

% This is the Bayesian classification
tpx_j = gmmpost(mix, testdata);
Bpost = [tpx_j(:,1), tpx_j(:,2)+tpx_j(:,3)];
[Bcon Brate]=confmat(Bpost, testlab);

% Compute network classification
yt = mlpfwd(net, testdata);
% Convert single output to posteriors for both classes
testpost = [yt 1-yt];
[C trate]=confmat(testpost, testlab);
fh4 = figure;
set(fh4, 'Name', 'Decision Boundaries');
whitebg(fh4, 'k');
hold on
plot(testdata((testlabel==1),1),testdata((testlabel==1),2),...
  ClassSymbol1, 'MarkerSize', PointSize)
plot(testdata((testlabel>1),1),testdata((testlabel>1),2),...
  ClassSymbol2, 'MarkerSize', PointSize)
% Bayesian decision boundary
[cB, hB] = contour(xrange,yrange,p1_x,[0.5 0.5],'b-');
set(hB, 'LineWidth', 2);
% Network decision boundary
[cN, hN] = contour(xrange,yrange,yg,[0.5 0.5],'r-');
set(hN, 'LineWidth', 2);
Chandles = [hB(1) hN(1)];
% legend position -1: place the legend outside the axes (legacy syntax)
legend(Chandles, 'Bayes decision boundary', ...
  'Network decision boundary', -1);
axis([x0 x1 y0 y1])
title('Test Data')
set(gca,'Box','On','Xtick',[],'YTick',[])

clc
disp('This figure shows the test data with the decision boundary')
disp('produced by the trained network and the optimal Bayes rule.')
disp(' ')
disp('Press any key to continue.')
pause
fh5 = figure;
set(fh5, 'Name', 'Test Set Performance');
whitebg(fh5, 'w');
% Bayes rule performance
subplot(1,2,1)
plotmat(Bcon,'b','k',12)
set(gca,'XTick',[0.5 1.5])
set(gca,'YTick',[0.5 1.5])
grid('off')
% Rows of a char matrix must have equal length, hence the padding on 'Red'
set(gca,'XTickLabel',['Red   ' ; 'Yellow'])
set(gca,'YTickLabel',['Yellow' ; 'Red   '])
ylabel('True')
xlabel('Predicted')
title(['Bayes Confusion Matrix (' num2str(Brate(1)) '%)'])

% Network performance
subplot(1,2, 2)
plotmat(C,'b','k',12)
set(gca,'XTick',[0.5 1.5])
set(gca,'YTick',[0.5 1.5])
grid('off')
set(gca,'XTickLabel',['Red   ' ; 'Yellow'])
set(gca,'YTickLabel',['Yellow' ; 'Red   '])
ylabel('True')
xlabel('Predicted')
title(['Network Confusion Matrix (' num2str(trate(1)) '%)'])

disp('The final figure shows the confusion matrices for the')
disp('two rules on the test set.')
disp(' ')
disp('Press any key to exit.')
pause
% Restore white backgrounds before closing so saved figures print cleanly
whitebg(fh1, 'w');
whitebg(fh2, 'w');
whitebg(fh3, 'w');
whitebg(fh4, 'w');
whitebg(fh5, 'w');
close(fh1); close(fh2); close(fh3);
close(fh4); close(fh5);
clear all;