comparison toolboxes/FullBNT-1.0.7/netlab3.3/demmlp2.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:e9a9cd732c1e
%DEMMLP2 Demonstrate simple classification using a multi-layer perceptron
%
% Description
% The problem consists of input data in two dimensions drawn from a
% mixture of three Gaussians: two of which are assigned to a single
% class. An MLP with logistic outputs trained with a quasi-Newton
% optimisation algorithm is compared with the optimal Bayesian decision
% rule.
%
% See also
% MLP, MLPFWD, NETERR, QUASINEW
%

% Copyright (c) Ian T Nabney (1996-2001)


% Set up some figure parameters
AxisShift = 0.05;     % amount the left data axes of figure 1 is moved left and widened
ClassSymbol1 = 'r.';  % plot marker for class 1 (red)
ClassSymbol2 = 'y.';  % plot marker for class 2 (yellow)
PointSize = 12;       % marker size for data points
titleSize = 10;       % font size for titles in the network-output figure

% Fix the random number generator seeds so that every run of the demo
% draws the same random numbers.
rand('state', 423);
randn('state', 423);

clc
disp('This demonstration shows how an MLP with logistic outputs and')
disp('cross entropy error function can be trained to model the')
disp('posterior class probabilities in a classification problem.')
disp('The results are compared with the optimal Bayes rule classifier,')
disp('which can be computed exactly as we know the form of the generating')
disp('distribution.')
disp(' ')
disp('Press any key to continue.')
pause
38
% Figure 1 holds the sampled data and the unconditional density; it is
% drawn on a black background.
fh1 = figure('Name', 'True Data Distribution');
whitebg(fh1, 'k');

%
% Generate the data: n samples from a 2-d mixture of three full-covariance
% Gaussians. Class 1 is the first centre, class 2 the other two.
%
n = 200;
mix = gmm(2, 3, 'full');
mix.priors  = [0.5 0.25 0.25];
mix.centres = [0 -0.1; 1 1; 1 -1];
mix.covars  = cat(3, ...
    [0.625  -0.2165; -0.2165 0.875 ], ...
    [0.2241 -0.1368; -0.1368 0.9759], ...
    [0.2375  0.1516;  0.1516 0.4125]);

[data, label] = gmmsamp(mix, n);
58
%
% Plotting region: the bounding box of the samples, padded by 5 percent
% of the data span on every side.
%
dmin = min(data, [], 1);    % [min x, min y] of the samples
dmax = max(data, [], 1);    % [max x, max y] of the samples
dspan = dmax - dmin;
pad = 5/100;                % Add on 5 percent each way
x0 = dmin(1) - dspan(1)*pad;
x1 = dmax(1) + dspan(1)*pad;
y0 = dmin(2) - dspan(2)*pad;
y1 = dmax(2) + dspan(2)*pad;

% Grid spacing is 1/100 of the unpadded x span and is used for both axes,
% so the grid cells are square.
gridstep = dspan(1)/100;
xrange = x0:gridstep:x1;
yrange = y0:gridstep:y1;
[X, Y] = meshgrid(xrange, yrange);

%
% Class conditional densities, the unconditional density and the
% posterior probabilities, all evaluated on the grid.
%
gridpts = [X(:) Y(:)];
px_j = gmmactiv(mix, gridpts);                  % one column per mixture component
px = reshape(px_j*(mix.priors)', size(X));      % prior-weighted sum: p(x)
post = gmmpost(mix, gridpts);
p1_x = reshape(post(:, 1), size(X));            % p(red|x): component 1
p2_x = reshape(sum(post(:, 2:3), 2), size(X));  % p(yellow|x): components 2 and 3
90
%
% Generate some pretty pictures !!
%
% The hot colormap applies to the whole figure. The left panel shows the
% samples with the Bayes decision boundary; the right panel shows p(x).
colormap(hot)
colorbar
subplot(1,2,1)
hold on
isRed = (label == 1);
isYellow = (label > 1);
plot(data(isRed,1), data(isRed,2), ClassSymbol1, 'MarkerSize', PointSize)
plot(data(isYellow,1), data(isYellow,2), ClassSymbol2, 'MarkerSize', PointSize)
% Bayes decision boundary: the p(red|x) = 0.5 contour
contour(xrange, yrange, p1_x, [0.5 0.5], 'w-');
axis([x0 x1 y0 y1])
set(gca, 'Box', 'On')
title('The Sampled Data');
% Move the left edge of this axes left by AxisShift and widen it by the same amount.
axpos = get(gca, 'Position');
axpos([1 3]) = axpos([1 3]) + [-AxisShift AxisShift];
set(gca, 'Position', axpos)
hold off

subplot(1,2,2)
imagesc(X(:), Y(:), px);
hold on
[cB, hB] = contour(xrange, yrange, p1_x, [0.5 0.5], 'w:');
set(hB, 'LineWidth', 2);
axis([x0 x1 y0 y1])
set(gca, 'YDir', 'normal')
title('Probability Density p(x)')
hold off

drawnow;
clc;
disp('The first figure shows the data sampled from a mixture of three')
disp('Gaussians, the first of which (whose centre is near the origin) is')
disp('labelled red and the other two are labelled yellow. The second plot')
disp('shows the unconditional density of the data with the optimal Bayesian')
disp('decision boundary superimposed.')
disp(' ')
disp('Press any key to continue.')
pause
% Figure 2: class-conditional densities (top row) and posterior
% probabilities (bottom row), each with the mixture centres marked.
fh2 = figure('Name', 'Class-conditional Densities and Posterior Probabilities');
whitebg(fh2, 'w');

p1 = reshape(px_j(:,1), size(X));                % p(x|red): component 1
p2 = reshape((px_j(:,2) + px_j(:,3)), size(X));  % p(x|yellow): components 2 and 3

panelImages = {p1, p2, p1_x, p2_x};
panelTitles = {'Density p(x|red)', 'Density p(x|yellow)', ...
               'Posterior Probability p(red|x)', 'Posterior Probability p(yellow|x)'};

for k = 1:4
  subplot(2,2,k)
  imagesc(X(:), Y(:), panelImages{k});
  if k == 1
    % The colormap is set once for the figure.
    colormap hot
  end
  colorbar
  if k == 1
    % Freeze the first panel's limits before overlaying the centres.
    axis(axis)
  end
  set(gca, 'YDir', 'normal')
  hold on
  % Blue crosses at the three mixture centres
  plot(mix.centres(:,1), mix.centres(:,2), 'b+', 'MarkerSize', 8, 'LineWidth', 2)
  title(panelTitles{k})
  hold off
end
173
% Now set up the MLP: 2 inputs, nhidden hidden units and nout logistic
% output(s), with weight-decay coefficient alpha.
nhidden = 6;
nout = 1;
alpha = 0.2;      % Weight decay
ncycles = 60;     % Number of training cycles.
net = mlp(2, nhidden, nout, 'logistic', alpha);

% Optimiser options: element 1 = 1 prints error values,
% element 14 is the number of training cycles.
options = zeros(1, 18);
options([1 14]) = [1 ncycles];

% Text describing the network, printed below.
mlpstring = sprintf('We now set up an MLP with %s hidden units, logistic output and cross', ...
                    num2str(nhidden));
trainstring = sprintf('entropy error function, and train it for %s cycles using the', ...
                      num2str(ncycles));
wdstring = sprintf('quasi-Newton optimisation algorithm with weight decay of %s.', ...
                   num2str(alpha));

% Force out the figure before training the MLP
drawnow;
disp(' ')
disp('The second figure shows the class conditional densities and posterior')
disp('probabilities for each class. The blue crosses mark the centres of')
disp('the three Gaussians.')
disp(' ')
disp(mlpstring)
disp(trainstring)
disp(wdstring)
disp(' ')
disp('Press any key to continue.')
pause
205
% Convert targets to 0-1 encoding: 1 for class 1 (red, mixture component 1),
% 0 for class 2 (yellow, components 2 and 3).
target = double(label == 1);

% Train using quasi-Newton.
net = netopt(net, options, data, target, 'quasinew');

% Evaluate the trained network over the plotting grid and reshape its
% single output column back to the grid shape for contour/imagesc.
% (The unused forward pass on the training data has been removed.)
yg = mlpfwd(net, [X(:) Y(:)]);
yg = reshape(yg(:,1), size(X));
214
% Figure 3: training data with the Bayes and network decision boundaries
% (left) and the network's estimate of p(red|x) over the grid (right).
fh3 = figure;
set(fh3, 'Name', 'Network Output');
whitebg(fh3, 'k')
subplot(1, 2, 1)
hold on
% Use the shared class markers, consistent with the other figures.
plot(data((label==1),1), data(label==1,2), ClassSymbol1, 'MarkerSize', PointSize)
plot(data((label>1),1), data(label>1,2), ClassSymbol2, 'MarkerSize', PointSize)
% Bayesian decision boundary: p(red|x) = 0.5
[cB, hB] = contour(xrange,yrange,p1_x,[0.5 0.5],'b-');
% Network decision boundary: network output = 0.5
[cN, hN] = contour(xrange,yrange,yg,[0.5 0.5],'r-');
set(hB, 'LineWidth', 2);
set(hN, 'LineWidth', 2);
% hB/hN may hold several handles; the first of each is enough for the legend.
Chandles = [hB(1) hN(1)];
% 'SouthWest' is the named form of the obsolete numeric legend position 3
% (lower left corner).
legend(Chandles, 'Bayes', ...
  'Network', 'Location', 'SouthWest');

axis([x0 x1 y0 y1])
set(gca,'Box','on','XTick',[],'YTick',[])

title('Training Data','FontSize',titleSize);
hold off

subplot(1, 2, 2)
imagesc(X(:),Y(:),yg);
colormap hot
colorbar
axis(axis)
set(gca,'YDir','normal','XTick',[],'YTick',[])
title('Network Output','FontSize',titleSize)

clc
disp('This figure shows the training data with the decision boundary')
disp('produced by the trained network and the network''s prediction of')
disp('the posterior probability of the red class.')
disp(' ')
disp('Press any key to continue.')
pause
252
%
% Now generate and classify a test data set
%
[testdata, testlabel] = gmmsamp(mix, n);
% Target matrix with one column per class: column 1 red (component 1),
% column 2 yellow (components 2 and 3). Shared by both confusion matrices.
testlab = [testlabel==1 testlabel>1];

% This is the Bayesian classification: the red posterior is that of
% component 1, the yellow posterior is the sum over components 2 and 3.
tpx_j = gmmpost(mix, testdata);
Bpost = [tpx_j(:,1), tpx_j(:,2)+tpx_j(:,3)];
[Bcon, Brate] = confmat(Bpost, testlab);

% Compute network classification
yt = mlpfwd(net, testdata);
% Convert single output to posteriors for both classes
testpost = [yt 1-yt];
[C, trate] = confmat(testpost, testlab);
269
% Figure 4: test data with the Bayes and network decision boundaries.
fh4 = figure;
set(fh4, 'Name', 'Decision Boundaries');
whitebg(fh4, 'k');
hold on
plot(testdata((testlabel==1),1),testdata((testlabel==1),2),...
  ClassSymbol1, 'MarkerSize', PointSize)
plot(testdata((testlabel>1),1),testdata((testlabel>1),2),...
  ClassSymbol2, 'MarkerSize', PointSize)
% Bayesian decision boundary
[cB, hB] = contour(xrange,yrange,p1_x,[0.5 0.5],'b-');
set(hB, 'LineWidth', 2);
% Network decision boundary
[cN, hN] = contour(xrange,yrange,yg,[0.5 0.5],'r-');
set(hN, 'LineWidth', 2);
% hB/hN may hold several handles; the first of each is enough for the legend.
Chandles = [hB(1) hN(1)];
% 'NorthEastOutside' is the named form of the obsolete numeric legend
% position -1 (outside the axes, to the right).
legend(Chandles, 'Bayes decision boundary', ...
  'Network decision boundary', 'Location', 'NorthEastOutside');
axis([x0 x1 y0 y1])
title('Test Data')
set(gca,'Box','On','XTick',[],'YTick',[])

clc
disp('This figure shows the test data with the decision boundary')
disp('produced by the trained network and the optimal Bayes rule.')
disp(' ')
disp('Press any key to continue.')
pause
297
% Figure 5: test-set confusion matrices for the Bayes rule (left) and the
% network (right); each title shows the percentage classified correctly.
fh5 = figure;
set(fh5, 'Name', 'Test Set Performance');
whitebg(fh5, 'w');

% One row per panel: {confusion matrix, title}.
confPanels = { ...
  Bcon, ['Bayes Confusion Matrix (' num2str(Brate(1)) '%)']; ...
  C,    ['Network Confusion Matrix (' num2str(trate(1)) '%)']};

for k = 1:size(confPanels, 1)
  subplot(1, 2, k)
  plotmat(confPanels{k, 1}, 'b', 'k', 12)
  set(gca, 'XTick', [0.5 1.5], 'YTick', [0.5 1.5])
  grid('off')
  % Cell arrays avoid the row-length mismatch of the char-matrix form
  % ['Red ' ; 'Yellow'], which fails to concatenate.
  set(gca, 'XTickLabel', {'Red', 'Yellow'})
  set(gca, 'YTickLabel', {'Yellow', 'Red'})
  ylabel('True')
  xlabel('Predicted')
  title(confPanels{k, 2})
end

disp('The final figure shows the confusion matrices for the')
disp('two rules on the test set.')
disp(' ')
disp('Press any key to exit.')
pause

% Restore white backgrounds before closing, then close all demo figures
% and clear the workspace.
whitebg(fh1, 'w');
whitebg(fh2, 'w');
whitebg(fh3, 'w');
whitebg(fh4, 'w');
whitebg(fh5, 'w');
close(fh1); close(fh2); close(fh3);
close(fh4); close(fh5);
clear all;