toolboxes/FullBNT-1.0.7/netlab3.3/demev2.m @ 0:e9a9cd732c1e (tip)

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
%DEMEV2 Demonstrate Bayesian classification for the MLP.
%
% Description
% A synthetic two-class, two-dimensional dataset X is sampled from a
% mixture of four Gaussians. Each class is associated with two of the
% Gaussians so that the optimal decision boundary is non-linear. A
% 2-layer network with logistic outputs is trained by minimizing the
% cross-entropy error function with an isotropic Gaussian regularizer
% (one hyperparameter for each of the four standard weight groups),
% using the scaled conjugate gradient optimizer. The hyperparameter
% vector ALPHA is re-estimated using the function EVIDENCE. A graph
% is plotted of the optimal, regularised, and unregularised decision
% boundaries. A further plot of the moderated versus unmoderated
% contours is generated.
%
% See also
% EVIDENCE, MLP, SCG, DEMARD, DEMMLP2
%

% Copyright (c) Ian T Nabney (1996-2001)


clc;

disp('This program demonstrates the use of the evidence procedure on')
disp('a two-class problem. It also shows the improved generalisation')
disp('performance that can be achieved with moderated outputs; that is,')
disp('predictions where an approximate integration over the true')
disp('posterior distribution is carried out.')
disp(' ')
disp('First we generate a synthetic dataset with two-dimensional input')
disp('sampled from a mixture of four Gaussians. Each class is')
disp('associated with two of the Gaussians so that the optimal decision')
disp('boundary is non-linear.')
disp(' ')
disp('Press any key to see a plot of the data.')
pause;

% Generate the matrix of inputs x and targets t.

rand('state', 423);
randn('state', 423);

ClassSymbol1 = 'r.';
ClassSymbol2 = 'y.';
PointSize = 12;
titleSize = 10;

fh1 = figure;
set(fh1, 'Name', 'True Data Distribution');
whitebg(fh1, 'k');

%
% Generate the data
%
n = 200;

% Set up mixture model: 2d data with four centres
% Class 1 comes from the first two centres, class 2 from the other two
mix = gmm(2, 4, 'full');
mix.priors = [0.25 0.25 0.25 0.25];
mix.centres = [0 -0.1; 1.5 0; 1 1; 1 -1];
mix.covars(:,:,1) = [0.625 -0.2165; -0.2165 0.875];
mix.covars(:,:,2) = [0.25 0; 0 0.25];
mix.covars(:,:,3) = [0.2241 -0.1368; -0.1368 0.9759];
mix.covars(:,:,4) = [0.2375 0.1516; 0.1516 0.4125];

[data, label] = gmmsamp(mix, n);

%
% Calculate some useful axis limits
%
x0 = min(data(:,1));
x1 = max(data(:,1));
y0 = min(data(:,2));
y1 = max(data(:,2));
dx = x1 - x0;
dy = y1 - y0;
expand = 5/100;   % Add on 5 percent each way
x0 = x0 - dx*expand;
x1 = x1 + dx*expand;
y0 = y0 - dy*expand;
y1 = y1 + dy*expand;
resolution = 100;
step = dx/resolution;
xrange = [x0:step:x1];
yrange = [y0:step:y1];
%
% Generate the grid
%
[X, Y] = meshgrid([x0:step:x1], [y0:step:y1]);
%
% Calculate the class conditional densities, the unconditional densities
% and the posterior probabilities
%
px_j = gmmactiv(mix, [X(:) Y(:)]);
px = reshape(px_j*(mix.priors)', size(X));
post = gmmpost(mix, [X(:) Y(:)]);
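% Each class is built from two of the four mixture components, so the
% true class posteriors are sums of the corresponding component
% posteriors: P(class 1 | x) = P(comp 1 | x) + P(comp 2 | x), and
% similarly for class 2 with components 3 and 4.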
p1_x = reshape(post(:, 1) + post(:, 2), size(X));
p2_x = reshape(post(:, 3) + post(:, 4), size(X));

plot(data((label<=2),1), data((label<=2),2), ClassSymbol1, ...
  'MarkerSize', PointSize)
hold on
axis([x0 x1 y0 y1])
plot(data((label>2),1), data((label>2),2), ClassSymbol2, ...
  'MarkerSize', PointSize)

% Convert targets to 0-1 encoding: class 1 -> 1, class 2 -> 0.
target = (label <= 2);
disp(' ')
disp('Press any key to continue')
pause; clc;
disp('Next we create a two-layer MLP network with 6 hidden units and')
disp('one logistic output. We use a separate inverse variance')
disp('hyperparameter for each group of weights (inputs, input bias,')
disp('outputs, output bias) and the weights are optimised with the')
disp('scaled conjugate gradient algorithm. After every 100 iterations')
disp('the hyperparameters are re-estimated twice. There are eight')
disp('cycles of the whole algorithm.')
disp(' ')
disp('Press any key to train the network and determine the hyperparameters.')
pause;

% Set up network parameters.
nin = 2;           % Number of inputs.
nhidden = 6;       % Number of hidden units.
nout = 1;          % Number of outputs.
alpha = 0.01;      % Initial prior hyperparameter.
aw1 = 0.01;        % Initial hyperparameter for first-layer weights.
ab1 = 0.01;        % Initial hyperparameter for hidden unit biases.
aw2 = 0.01;        % Initial hyperparameter for second-layer weights.
ab2 = 0.01;        % Initial hyperparameter for output bias.

% Create and initialize network weight vector.
prior = mlpprior(nin, nhidden, nout, aw1, ab1, aw2, ab2);
net = mlp(nin, nhidden, nout, 'logistic', prior);

% Set up vector of options for the optimiser.
nouter = 8;          % Number of outer loops.
ninner = 2;          % Number of inner loops.
options = foptions;  % Default options vector.
options(1) = 1;      % This provides display of error values.
options(2) = 1.0e-5; % Absolute precision for weights.
options(3) = 1.0e-5; % Precision for objective function.
options(14) = 100;   % Number of training cycles in inner loop.

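% For reference, a sketch of what each EVIDENCE call does (MacKay's
% update, not a restatement of the implementation): with eigenvalues
% lambda_i of the data Hessian, the number of well-determined parameters
% is
%   gamma = sum_i lambda_i / (lambda_i + alpha)
% and each weight group's hyperparameter is re-estimated as
%   alpha = gamma_group / (2*Ew_group)
% where Ew_group is half the sum of squared weights in that group. With
% logistic outputs and a cross-entropy error there is no noise
% hyperparameter beta to estimate.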
% Train using scaled conjugate gradients, re-estimating alpha.
for k = 1:nouter
  net = netopt(net, options, data, target, 'scg');
  [net, gamma] = evidence(net, data, target, ninner);
  fprintf(1, '\nRe-estimation cycle %d:\n', k);
  disp(['  alpha = ', num2str(net.alpha')]);
  fprintf(1, '  gamma = %8.5f\n\n', gamma);
  disp(' ')
  disp('Press any key to continue.')
  pause;
end

disp(' ')
disp('Network training and hyperparameter re-estimation are now complete.')
disp('Notice that the final error value is close to the number of data')
disp(['points (', num2str(n), ') divided by two.'])
disp('Also, the hyperparameter values differ, which suggests that a single')
disp('hyperparameter would not be so effective.')
disp(' ')
disp('Next we train an MLP without Bayesian regularisation on the')
disp('same dataset, using 400 iterations of scaled conjugate gradient.')
disp(' ')
disp('Press any key to train the network by maximum likelihood.')
pause;
% Train standard network
net2 = mlp(nin, nhidden, nout, 'logistic');
options(14) = 400;
net2 = netopt(net2, options, data, target, 'scg');
y2g = mlpfwd(net2, [X(:), Y(:)]);
y2g = reshape(y2g(:, 1), size(X));

disp(' ')
disp('We can now plot the function represented by the trained networks.')
disp('We show the decision boundaries (output = 0.5) and the optimal')
disp('decision boundary given by applying Bayes'' theorem to the true')
disp('data model.')
disp(' ')
disp('Press any key to add the boundaries to the plot.')
pause;

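% MLPEVFWD returns two sets of grid predictions: the first is the usual
% forward pass through the trained (MAP) weights; the second is the
% moderated output, which approximately averages the logistic output
% over the Gaussian posterior of the weights.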
% Evaluate predictions over the grid.
[yg, ymodg] = mlpevfwd(net, data, target, [X(:) Y(:)]);
yg = reshape(yg(:,1), size(X));
ymodg = reshape(ymodg(:,1), size(X));

% Decision boundaries: Bayes-optimal, regularised and unregularised networks
[cB, hB] = contour(xrange, yrange, p1_x, [0.5 0.5], 'b-');
[cNb, hNb] = contour(xrange, yrange, yg, [0.5 0.5], 'r-');
[cN, hN] = contour(xrange, yrange, y2g, [0.5 0.5], 'g-');
set(hB, 'LineWidth', 2);
set(hNb, 'LineWidth', 2);
set(hN, 'LineWidth', 2);
Chandles = [hB(1) hNb(1) hN(1)];
legend(Chandles, 'Bayes', 'Reg. Network', 'Network', 3);

disp(' ')
disp('Note how the regularised network predictions are closer to the')
disp('optimal decision boundary, while the unregularised network is')
disp('overtrained.')
disp(' ')
disp('We will now compare moderated and unmoderated outputs for the')
disp('regularised network by showing the contour plot of the posterior')
disp('probability estimates.')
disp(' ')
disp('The first plot shows the regularised (moderated) predictions')
disp('and the second shows the standard predictions from the same network.')
disp('These agree at the level 0.5.')
disp('Press any key to continue')
pause
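% Moderation follows MacKay's approximation: the moderated prediction is
% g(kappa*a) rather than g(a), with kappa = 1/sqrt(1 + pi*s2/8), where a
% is the MAP output activation and s2 its posterior variance. kappa <= 1
% pulls predictions towards 0.5 where the weights are uncertain, but the
% 0.5 contour itself is unchanged, since g(0) = 0.5 for any kappa.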
levels = 0:0.1:1;
fh4 = figure;
set(fh4, 'Name', 'Moderated outputs');
hold on
plot(data((label<=2),1), data((label<=2),2), 'r.', 'MarkerSize', PointSize)
plot(data((label>2),1), data((label>2),2), 'y.', 'MarkerSize', PointSize)

[cNby, hNby] = contour(xrange, yrange, ymodg, levels, 'k-');
set(hNby, 'LineWidth', 1);

fh5 = figure;
set(fh5, 'Name', 'Unmoderated outputs');
hold on
plot(data((label<=2),1), data((label<=2),2), 'r.', 'MarkerSize', PointSize)
plot(data((label>2),1), data((label>2),2), 'y.', 'MarkerSize', PointSize)

[cNbm, hNbm] = contour(xrange, yrange, yg, levels, 'k-');
set(hNbm, 'LineWidth', 1);

disp(' ')
disp('Note how the moderated contours are more widely spaced. This shows')
disp('that there is a larger region where the outputs are close to 0.5')
disp('and a smaller region where the outputs are close to 0 or 1.')
disp(' ')
disp('Press any key to exit')
pause
close(fh1);
close(fh4);
close(fh5);