toolboxes/FullBNT-1.0.7/netlab3.3/demev3.m @ 0:e9a9cd732c1e (tip)

first hg version after svn
author: wolffd
date: Tue, 10 Feb 2015 15:05:51 +0000
%DEMEV3 Demonstrate Bayesian regression for the RBF.
%
%	Description
%	The problem consists of an input variable X sampled from a
%	Gaussian distribution, and a target variable T generated by computing
%	SIN(2*PI*X) and adding Gaussian noise. An RBF network with linear
%	outputs is trained by minimizing a sum-of-squares error function with
%	an isotropic Gaussian regularizer, using the scaled conjugate gradient
%	optimizer. The hyperparameters ALPHA and BETA are re-estimated using
%	the function EVIDENCE. A graph is plotted of the original function,
%	the training data, the trained network function, and the error bars.
%
%	See also
%	DEMEV1, EVIDENCE, RBF, SCG, NETEVFWD
%

% Copyright (c) Ian T Nabney (1996-2001)

clc;
disp('This demonstration illustrates the application of Bayesian')
disp('re-estimation to determine the hyperparameters in a simple regression')
disp('problem using an RBF network. It is based on the fact that the')
disp('posterior distribution for the output weights of an RBF is Gaussian')
disp('and uses the evidence maximization framework of MacKay.')
disp(' ')
disp('First, we generate a synthetic data set consisting of a single input')
disp('variable x sampled from a Gaussian distribution, and a target variable')
disp('t obtained by evaluating sin(2*pi*x) and adding Gaussian noise.')
disp(' ')
disp('Press any key to see a plot of the data together with the sine function.')
pause;

% Generate the matrix of inputs x and targets t.

ndata = 16;		% Number of data points.
noise = 0.1;		% Standard deviation of noise distribution.
randn('state', 0);	% Fix random seeds so the demo is reproducible.
rand('state', 0);
x = 0.25 + 0.07*randn(ndata, 1);
t = sin(2*pi*x) + noise*randn(size(x));

% Plot the data and the original sine function.
h = figure;
nplot = 200;
plotvals = linspace(0, 1, nplot)';
plot(x, t, 'ok')
xlabel('Input')
ylabel('Target')
hold on
axis([0 1 -1.5 1.5])
fplot('sin(2*pi*x)', [0 1], '-g')
legend('data', 'function');

disp(' ')
disp('Press any key to continue')
pause; clc;

disp('Next we create an RBF network having 3 hidden units with thin plate')
disp('spline basis functions and one linear output. The model assumes Gaussian')
disp('target noise governed by an inverse variance hyperparameter beta, and')
disp('uses a simple Gaussian prior distribution governed by an inverse')
disp('variance hyperparameter alpha.')
disp(' ');
disp('The network weights and the hyperparameters are initialised and then')
disp('the output layer weights are optimized with the scaled conjugate gradient')
disp('algorithm using the SCG function, with the hyperparameters kept')
disp('fixed. After a maximum of 50 iterations, the hyperparameters are')
disp('re-estimated using the EVIDENCE function. The process of optimizing')
disp('the weights with fixed hyperparameters and then re-estimating the')
disp('hyperparameters is repeated for a total of 5 cycles.')
disp(' ')
disp('Press any key to train the network and determine the hyperparameters.')
pause;
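
% For reference, EVIDENCE re-estimates the hyperparameters with MacKay's
% update formulas (sketched here in outline; see evidence.m for the exact
% implementation). With lambda_i the eigenvalues of beta times the Hessian
% of the data error, the number of well-determined parameters is
%   gamma = sum_i lambda_i/(lambda_i + alpha)
% and the updates are
%   alpha <- gamma/(2*Ew),   beta <- (ndata - gamma)/(2*Ed)
% where Ew = 0.5*w'*w and Ed = 0.5*sum((y - t).^2).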

% Set up network parameters.
nin = 1;		% Number of inputs.
nhidden = 3;		% Number of hidden units.
nout = 1;		% Number of outputs.
alpha = 0.01;		% Initial prior hyperparameter.
beta_init = 50.0;	% Initial noise hyperparameter.

% Create and initialize network weight vector.
net = rbf(nin, nhidden, nout, 'tps', 'linear', alpha, beta_init);
[net.mask, prior] = rbfprior('tps', nin, nhidden, nout, alpha, alpha);
net = netinit(net, prior);
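% The mask returned by RBFPRIOR confines the prior (and hence the Bayesian
% treatment) to the output-layer weights and biases; with the basis
% functions held fixed, the posterior over these weights is Gaussian.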

options = foptions;
options(14) = 5;	% At most 5 EM iterations for basis functions.
options(1) = -1;	% Turn off all messages.
net = rbfsetbf(net, options, x);	% Initialise the basis functions.
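% (RBFSETBF positions the basis function centres by fitting a Gaussian
% mixture model to the input data with the EM algorithm.)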

% Now train the network.
nouter = 5;
ninner = 2;
options = foptions;
options(1) = 1;
options(2) = 1.0e-5;	% Absolute precision for weights.
options(3) = 1.0e-5;	% Precision for objective function.
options(14) = 50;	% Number of training cycles in inner loop.

% Train using scaled conjugate gradients, re-estimating alpha and beta.
for k = 1:nouter
  net = netopt(net, options, x, t, 'scg');
  [net, gamma] = evidence(net, x, t, ninner);
  fprintf(1, '\nRe-estimation cycle %d:\n', k);
  fprintf(1, '  alpha = %8.5f\n', net.alpha);
  fprintf(1, '  beta  = %8.5f\n', net.beta);
  fprintf(1, '  gamma = %8.5f\n\n', gamma);
  disp(' ')
  disp('Press any key to continue.')
  pause;
end

fprintf(1, 'true beta: %f\n', 1/(noise*noise));
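
% An explicit check of the remark below: at the evidence maximum,
% 2*alpha*Ew is roughly gamma and 2*beta*Ed is roughly ndata - gamma, so
% the regularised error beta*Ed + alpha*Ew (computed here with NETERR)
% should be close to ndata/2.
fprintf(1, 'final error: %f (ndata/2 = %f)\n', ...
  neterr(netpak(net), net, x, t), ndata/2);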

disp(' ')
disp('Network training and hyperparameter re-estimation are now complete.')
disp('Compare the final value for the hyperparameter beta with the true')
disp('value.')
disp(' ')
disp('Notice that the final error value is close to the number of data')
disp(['points (', num2str(ndata),') divided by two.'])
disp(' ')
disp('Press any key to continue.')
pause; clc;
disp('We can now plot the function represented by the trained network. This')
disp('corresponds to the mean of the predictive distribution. We can also')
disp('plot ''error bars'' representing one standard deviation of the')
disp('predictive distribution around the mean.')
disp(' ')
disp('Press any key to add the network function and error bars to the plot.')
pause;

% Evaluate error bars.
[y, sig2] = netevfwd(netpak(net), net, x, t, plotvals);
sig = sqrt(sig2);
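% NETEVFWD computes the predictive variance as, in outline,
%   sig2 = 1/beta + g'*inv(A)*g
% where g is the derivative of the network output with respect to the
% weights and A is the Hessian of the regularised error, so sig combines
% the target noise with the posterior uncertainty in the weights.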

% Plot the data, the original function, and the trained network function.
[y, z] = rbffwd(net, plotvals);
figure(h); hold on;
plot(plotvals, y, '-r')
xlabel('Input')
ylabel('Target')
plot(plotvals, y + sig, '-b');
plot(plotvals, y - sig, '-b');
legend('data', 'function', 'network', 'error bars');

disp(' ')
disp('Notice how the confidence interval spanned by the ''error bars'' is')
disp('smaller in the region of input space where the data density is high,')
disp('and becomes larger in regions away from the data.')
disp(' ')
disp('Press any key to end.')
pause; clc; close(h);