toolboxes/FullBNT-1.0.7/netlab3.3/demev1.m @ 0:e9a9cd732c1e (tip)
author:      wolffd
date:        Tue, 10 Feb 2015 15:05:51 +0000
description: first hg version after svn
%DEMEV1 Demonstrate Bayesian regression for the MLP.
%
%	Description
%	The problem consists of an input variable X, sampled from a
%	Gaussian distribution, and a target variable T generated by computing
%	SIN(2*PI*X) and adding Gaussian noise. A 2-layer network with linear
%	outputs is trained by minimizing a sum-of-squares error function with
%	an isotropic Gaussian regularizer, using the scaled conjugate gradient
%	optimizer. The hyperparameters ALPHA and BETA are re-estimated using
%	the function EVIDENCE. A graph is plotted of the original function,
%	the training data, the trained network function, and the error bars.
%
%	See also
%	EVIDENCE, MLP, SCG, DEMARD, DEMMLP1
%

%	Copyright (c) Ian T Nabney (1996-2001)

clc;
disp('This demonstration illustrates the application of Bayesian')
disp('re-estimation to determine the hyperparameters in a simple regression')
disp('problem. It is based on a local quadratic approximation to a mode of')
disp('the posterior distribution and the evidence maximization framework of')
disp('MacKay.')
disp(' ')
disp('First, we generate a synthetic data set consisting of a single input')
disp('variable x sampled from a Gaussian distribution, and a target variable')
disp('t obtained by evaluating sin(2*pi*x) and adding Gaussian noise.')
disp(' ')
disp('Press any key to see a plot of the data together with the sine function.')
pause;
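
% Background (a sketch for the reader, not part of the original demo):
% the evidence framework approximates the posterior over the weights w
% by a Gaussian centred on the most probable weights wMP, using the
% local quadratic expansion of the regularized error
%     S(w) = beta*ED(w) + alpha*EW(w)
%          ~= S(wMP) + 0.5*(w - wMP)'*A*(w - wMP),
% where A is the Hessian of S evaluated at wMP. Maximizing the resulting
% Gaussian approximation to the evidence p(D | alpha, beta) yields the
% re-estimation formulae applied by the EVIDENCE function below.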

% Generate the matrix of inputs x and targets t.

ndata = 16;			% Number of data points.
noise = 0.1;			% Standard deviation of noise distribution.
randn('state', 0);
x = 0.25 + 0.07*randn(ndata, 1);
t = sin(2*pi*x) + noise*randn(size(x));

% Plot the data and the original sine function.
h = figure;
nplot = 200;
plotvals = linspace(0, 1, nplot)';
plot(x, t, 'ok')
xlabel('Input')
ylabel('Target')
hold on
axis([0 1 -1.5 1.5])
fplot('sin(2*pi*x)', [0 1], '-g')
legend('data', 'function');

disp(' ')
disp('Press any key to continue')
pause; clc;

disp('Next we create a two-layer MLP network having 3 hidden units and one')
disp('linear output. The model assumes Gaussian target noise governed by an')
disp('inverse variance hyperparameter beta, and uses a simple Gaussian prior')
disp('distribution governed by an inverse variance hyperparameter alpha.')
disp(' ');
disp('The network weights and the hyperparameters are initialised and then')
disp('the weights are optimized with the scaled conjugate gradient')
disp('algorithm using the SCG function, with the hyperparameters kept')
disp('fixed. After a maximum of 500 iterations, the hyperparameters are')
disp('re-estimated using the EVIDENCE function. The process of optimizing')
disp('the weights with fixed hyperparameters and then re-estimating the')
disp('hyperparameters is repeated for a total of 3 cycles.')
disp(' ')
disp('Press any key to train the network and determine the hyperparameters.')
pause;
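
% For reference, MacKay's re-estimation equations (which EVIDENCE
% implements, up to its internal eigenvalue computation) are:
%     gamma = sum_i lambda_i / (lambda_i + alpha)
%     alpha <- gamma / (2*EW(wMP))
%     beta  <- (N - gamma) / (2*ED(wMP))
% Here lambda_i are the eigenvalues of beta*H, with H the Hessian of the
% unregularized data error ED; EW = 0.5*sum(w.^2) is the weight error;
% and N is the number of training points. gamma measures the effective
% number of well-determined parameters, and is returned by EVIDENCE.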

% Set up network parameters.
nin = 1;		% Number of inputs.
nhidden = 3;		% Number of hidden units.
nout = 1;		% Number of outputs.
alpha = 0.01;		% Initial prior hyperparameter.
beta_init = 50.0;	% Initial noise hyperparameter.

% Create and initialize network weight vector.
net = mlp(nin, nhidden, nout, 'linear', alpha, beta_init);

% Set up vector of options for the optimiser.
nouter = 3;			% Number of outer loops.
ninner = 1;			% Number of inner loops.
options = zeros(1,18);		% Default options vector.
options(1) = 1;			% This provides display of error values.
options(2) = 1.0e-7;		% Absolute precision for weights.
options(3) = 1.0e-7;		% Precision for objective function.
options(14) = 500;		% Number of training cycles in inner loop.

% Train using scaled conjugate gradients, re-estimating alpha and beta.
for k = 1:nouter
  net = netopt(net, options, x, t, 'scg');
  [net, gamma] = evidence(net, x, t, ninner);
  fprintf(1, '\nRe-estimation cycle %d:\n', k);
  fprintf(1, '  alpha =  %8.5f\n', net.alpha);
  fprintf(1, '  beta  =  %8.5f\n', net.beta);
  fprintf(1, '  gamma =  %8.5f\n\n', gamma);
  disp(' ')
  disp('Press any key to continue.')
  pause;
end

fprintf(1, 'true beta: %f\n', 1/(noise*noise));

disp(' ')
disp('Network training and hyperparameter re-estimation are now complete.')
disp('Compare the final value for the hyperparameter beta with the true')
disp('value.')
disp(' ')
disp('Notice that the final error value is close to the number of data')
disp(['points (', num2str(ndata),') divided by two.'])
disp(' ')
disp('Press any key to continue.')
pause; clc;
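
% Why the final error is close to ndata/2: at a self-consistent solution
% of the re-estimation equations above, 2*alpha*EW = gamma and
% 2*beta*ED = N - gamma, so the total regularized error is
%     S = alpha*EW + beta*ED = gamma/2 + (N - gamma)/2 = N/2,
% independently of gamma. A quick check of this identity, using Netlab's
% own functions (hypothetical, not part of the original demo):
%     w  = mlppak(net);
%     ew = 0.5*sum(w.^2);
%     ed = 0.5*sum((mlpfwd(net, x) - t).^2);
%     net.alpha*ew + net.beta*ed	% expect a value near ndata/2
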
disp('We can now plot the function represented by the trained network. This')
disp('corresponds to the mean of the predictive distribution. We can also')
disp('plot ''error bars'' representing one standard deviation of the')
disp('predictive distribution around the mean.')
disp(' ')
disp('Press any key to add the network function and error bars to the plot.')
pause;
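
% Under the same Gaussian posterior approximation, the error bars
% returned by NETEVFWD correspond to the predictive variance
%     sig2(x) = 1/beta + g(x)'*inv(A)*g(x),
% where g(x) is the derivative of the network output with respect to the
% weights and A is the Hessian of the regularized error. The first term
% is the intrinsic target noise; the second reflects the remaining
% uncertainty in the weights, and grows away from the training data.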

% Evaluate error bars.
[y, sig2] = netevfwd(mlppak(net), net, x, t, plotvals);
sig = sqrt(sig2);

% Plot the data, the original function, and the trained network function.
[y, z] = mlpfwd(net, plotvals);
figure(h); hold on;
plot(plotvals, y, '-r')
xlabel('Input')
ylabel('Target')
plot(plotvals, y + sig, '-b');
plot(plotvals, y - sig, '-b');
legend('data', 'function', 'network', 'error bars');

disp(' ')
disp('Notice how the confidence interval spanned by the ''error bars'' is')
disp('smaller in the region of input space where the data density is high,')
disp('and becomes larger in regions away from the data.')
disp(' ')
disp('Press any key to end.')
pause; clc; close(h);
%clear all