toolboxes/FullBNT-1.0.7/netlab3.3/demev2.m @ 0:e9a9cd732c1e (tip)

changeset: first hg version after svn
author:    wolffd
date:      Tue, 10 Feb 2015 15:05:51 +0000
%DEMEV2 Demonstrate Bayesian classification for the MLP.
%
% Description
% A synthetic two-class two-dimensional dataset X is sampled from a
% mixture of four Gaussians. Each class is associated with two of the
% Gaussians so that the optimal decision boundary is non-linear. A
% 2-layer network with logistic outputs is trained by minimizing the
% cross-entropy error function with an isotropic Gaussian regularizer
% (one hyperparameter for each of the four standard weight groups),
% using the scaled conjugate gradient optimizer. The hyperparameter
% vectors ALPHA and BETA are re-estimated using the function EVIDENCE.
% A graph is plotted of the optimal, regularised, and unregularised
% decision boundaries. A further plot of the moderated versus
% unmoderated contours is generated.
%
% See also
% EVIDENCE, MLP, SCG, DEMARD, DEMMLP2
%

% Copyright (c) Ian T Nabney (1996-2001)


clc;

disp('This program demonstrates the use of the evidence procedure on')
disp('a two-class problem. It also shows the improved generalisation')
disp('performance that can be achieved with moderated outputs; that is,')
disp('predictions where an approximate integration over the true')
disp('posterior distribution is carried out.')
disp(' ')
disp('First we generate a synthetic dataset with two-dimensional input')
disp('sampled from a mixture of four Gaussians. Each class is')
disp('associated with two of the Gaussians so that the optimal decision')
disp('boundary is non-linear.')
disp(' ')
disp('Press any key to see a plot of the data.')
pause;

% Generate the matrix of inputs x and targets t.

rand('state', 423);
randn('state', 423);

ClassSymbol1 = 'r.';
ClassSymbol2 = 'y.';
PointSize = 12;
titleSize = 10;

fh1 = figure;
set(fh1, 'Name', 'True Data Distribution');
whitebg(fh1, 'k');

%
% Generate the data
%
n = 200;

% Set up mixture model: 2d data with four centres.
% Class 1 comes from the first two centres, class 2 from the other two.
mix = gmm(2, 4, 'full');
mix.priors = [0.25 0.25 0.25 0.25];
mix.centres = [0 -0.1; 1.5 0; 1 1; 1 -1];
mix.covars(:,:,1) = [0.625 -0.2165; -0.2165 0.875];
mix.covars(:,:,2) = [0.25 0; 0 0.25];
mix.covars(:,:,3) = [0.2241 -0.1368; -0.1368 0.9759];
mix.covars(:,:,4) = [0.2375 0.1516; 0.1516 0.4125];

[data, label] = gmmsamp(mix, n);
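
% (Added sketch, not part of the original demo.) With equal priors of 0.25
% on each centre, roughly half the sample should fall in each class; this
% quick count checks that gmmsamp produced a sensibly balanced dataset.
fprintf('Class 1: %d points, Class 2: %d points\n', ...
  sum(label <= 2), sum(label > 2));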

%
% Calculate some useful axis limits
%
x0 = min(data(:,1));
x1 = max(data(:,1));
y0 = min(data(:,2));
y1 = max(data(:,2));
dx = x1 - x0;
dy = y1 - y0;
expand = 5/100;  % Add on 5 percent each way
x0 = x0 - dx*expand;
x1 = x1 + dx*expand;
y0 = y0 - dy*expand;
y1 = y1 + dy*expand;
resolution = 100;
step = dx/resolution;
xrange = [x0:step:x1];
yrange = [y0:step:y1];
%
% Generate the grid
%
[X, Y] = meshgrid([x0:step:x1], [y0:step:y1]);
%
% Calculate the class conditional densities, the unconditional densities
% and the posterior probabilities
%
px_j = gmmactiv(mix, [X(:) Y(:)]);
px = reshape(px_j*(mix.priors)', size(X));
post = gmmpost(mix, [X(:) Y(:)]);
p1_x = reshape(post(:, 1) + post(:, 2), size(X));
p2_x = reshape(post(:, 3) + post(:, 4), size(X));
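
% (Added check, not part of the original demo.) The class posteriors are
% sums of the component posteriors from gmmpost, so p1_x + p2_x should be
% 1 everywhere on the grid up to rounding error; warn if the identity fails.
if max(max(abs(p1_x + p2_x - 1))) > 1e-6
  warning('Grid posteriors do not sum to one');
end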

plot(data((label<=2),1), data((label<=2),2), ClassSymbol1, ...
  'MarkerSize', PointSize)
hold on
axis([x0 x1 y0 y1])
plot(data((label>2),1), data((label>2),2), ClassSymbol2, ...
  'MarkerSize', PointSize)

% Convert targets to 0-1 encoding
target = (label <= 2);
disp(' ')
disp('Press any key to continue')
pause; clc;

disp('Next we create a two-layer MLP network with 6 hidden units and')
disp('one logistic output. We use a separate inverse variance')
disp('hyperparameter for each group of weights (inputs, input bias,')
disp('outputs, output bias) and the weights are optimised with the')
disp('scaled conjugate gradient algorithm. After every 100 iterations')
disp('the hyperparameters are re-estimated twice. There are eight')
disp('cycles of the whole algorithm.')
disp(' ')
disp('Press any key to train the network and determine the hyperparameters.')
pause;

% Set up network parameters.
nin = 2;        % Number of inputs.
nhidden = 6;    % Number of hidden units.
nout = 1;       % Number of outputs.
alpha = 0.01;   % Initial prior hyperparameter.
aw1 = 0.01;     % Initial hyperparameters for the four weight groups:
ab1 = 0.01;     % first-layer weights, first-layer biases,
aw2 = 0.01;     % second-layer weights and second-layer biases.
ab2 = 0.01;

% Create and initialize network weight vector.
prior = mlpprior(nin, nhidden, nout, aw1, ab1, aw2, ab2);
net = mlp(nin, nhidden, nout, 'logistic', prior);
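
% (Added note, not part of the original demo.) mlpprior groups the weights
% into the four blocks listed above, each with its own alpha; net.nwts
% gives the total parameter count that the evidence re-estimation will
% operate on.
fprintf('Total number of network parameters: %d\n', net.nwts);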

% Set up vector of options for the optimiser.
nouter = 8;          % Number of outer loops.
ninner = 2;          % Number of inner loops.
options = foptions;  % Default options vector.
options(1) = 1;      % This provides display of error values.
options(2) = 1.0e-5; % Absolute precision for weights.
options(3) = 1.0e-5; % Precision for objective function.
options(14) = 100;   % Number of training cycles in inner loop.
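
% (Added note, not part of the original demo.) foptions returns the old
% 18-element MATLAB options vector that netlab optimisers read; only the
% entries set above matter here, e.g. options(14) is the iteration budget
% scg uses on each inner training run.
fprintf('scg will run %d iterations per outer cycle\n', options(14));
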
% Train using scaled conjugate gradients, re-estimating alpha and beta.
for k = 1:nouter
  net = netopt(net, options, data, target, 'scg');
  [net, gamma] = evidence(net, data, target, ninner);
  fprintf(1, '\nRe-estimation cycle %d:\n', k);
  disp(['  alpha = ', num2str(net.alpha')]);
  fprintf(1, '  gamma = %8.5f\n\n', gamma);
  disp(' ')
  disp('Press any key to continue.')
  pause;
end
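
% (Added note, not part of the original demo.) gamma from the final call
% to evidence estimates the number of well-determined parameters;
% comparing it with net.nwts indicates how strongly the Gaussian priors
% are constraining the network.
fprintf('gamma = %.2f of %d total parameters\n', gamma, net.nwts);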

disp(' ')
disp('Network training and hyperparameter re-estimation are now complete.')
disp('Notice that the final error value is close to the number of data')
disp(['points (', num2str(n), ') divided by two.'])
disp('Also, the hyperparameter values differ, which suggests that a single')
disp('hyperparameter would not be so effective.')
disp(' ')
disp('Next we train an MLP without Bayesian regularisation on the')
disp('same dataset, using 400 iterations of scaled conjugate gradient.')
disp(' ')
disp('Press any key to train the network by maximum likelihood.')
pause;
% Train a standard (unregularised) network for comparison.
net2 = mlp(nin, nhidden, nout, 'logistic');
options(14) = 400;
net2 = netopt(net2, options, data, target, 'scg');
y2g = mlpfwd(net2, [X(:), Y(:)]);
y2g = reshape(y2g(:, 1), size(X));

disp(' ')
disp('We can now plot the function represented by the trained networks.')
disp('We show the decision boundaries (output = 0.5) and the optimal')
disp('decision boundary given by applying Bayes'' theorem to the true')
disp('data model.')
disp(' ')
disp('Press any key to add the boundaries to the plot.')
pause;

% Evaluate predictions.
[yg, ymodg] = mlpevfwd(net, data, target, [X(:) Y(:)]);
yg = reshape(yg(:,1), size(X));
ymodg = reshape(ymodg(:,1), size(X));
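
% (Added note, not part of the original demo.) For a logistic output,
% netlab's moderated prediction follows MacKay's approximation: the output
% activation a is shrunk by kappa = 1/sqrt(1 + pi*s2/8), where s2 is the
% posterior variance of a, before the sigmoid is applied. A hand-rolled
% sketch of that step, assuming hypothetical vectors a and s2:
%   kappa = 1 ./ sqrt(1 + pi .* s2 / 8);
%   ymod  = 1 ./ (1 + exp(-kappa .* a));
% This pulls uncertain predictions towards 0.5, which is why the moderated
% contours plotted below are more widely spaced.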

% Plot the optimal Bayesian decision boundary together with the
% regularised and unregularised network boundaries.
[cB, hB] = contour(xrange, yrange, p1_x, [0.5 0.5], 'b-');
[cNb, hNb] = contour(xrange, yrange, yg, [0.5 0.5], 'r-');
[cN, hN] = contour(xrange, yrange, y2g, [0.5 0.5], 'g-');
set(hB, 'LineWidth', 2);
set(hNb, 'LineWidth', 2);
set(hN, 'LineWidth', 2);
Chandles = [hB(1) hNb(1) hN(1)];
legend(Chandles, 'Bayes', ...
  'Reg. Network', 'Network', 3);

disp(' ')
disp('Note how the regularised network predictions are closer to the')
disp('optimal decision boundary, while the unregularised network is')
disp('overtrained.')

disp(' ')
disp('We will now compare moderated and unmoderated outputs for the');
disp('regularised network by showing the contour plot of the posterior')
disp('probability estimates.')
disp(' ')
disp('The first plot shows the regularised (moderated) predictions')
disp('and the second shows the standard predictions from the same network.')
disp('These agree at the level 0.5.')
disp('Press any key to continue')
pause
levels = 0:0.1:1;
fh4 = figure;
set(fh4, 'Name', 'Moderated outputs');
hold on
plot(data((label<=2),1), data((label<=2),2), 'r.', 'MarkerSize', PointSize)
plot(data((label>2),1), data((label>2),2), 'y.', 'MarkerSize', PointSize)

[cNby, hNby] = contour(xrange, yrange, ymodg, levels, 'k-');
set(hNby, 'LineWidth', 1);

fh5 = figure;
set(fh5, 'Name', 'Unmoderated outputs');
hold on
plot(data((label<=2),1), data((label<=2),2), 'r.', 'MarkerSize', PointSize)
plot(data((label>2),1), data((label>2),2), 'y.', 'MarkerSize', PointSize)

[cNbm, hNbm] = contour(xrange, yrange, yg, levels, 'k-');
set(hNbm, 'LineWidth', 1);

disp(' ')
disp('Note how the moderated contours are more widely spaced. This shows')
disp('that there is a larger region where the outputs are close to 0.5')
disp('and a smaller region where the outputs are close to 0 or 1.')
disp(' ')
disp('Press any key to exit')
pause
close(fh1);
close(fh4);
close(fh5);