diff toolboxes/FullBNT-1.0.7/netlab3.3/demev3.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/toolboxes/FullBNT-1.0.7/netlab3.3/demev3.m	Tue Feb 10 15:05:51 2015 +0000
@@ -0,0 +1,154 @@
+%DEMEV3	Demonstrate Bayesian regression for the RBF.
+%
+%	Description
+%	The problem consists an input variable X which sampled from a
+%	Gaussian distribution, and a target variable T generated by computing
+%	SIN(2*PI*X) and adding Gaussian noise. An RBF network with linear
+%	outputs is trained by minimizing a sum-of-squares error function with
+%	isotropic Gaussian regularizer, using the scaled conjugate gradient
+%	optimizer. The hyperparameters ALPHA and BETA are re-estimated using
+%	the function EVIDENCE. A graph  is plotted of the original function,
+%	the training data, the trained network function, and the error bars.
+%
+%	See also
+%	DEMEV1, EVIDENCE, RBF, SCG, NETEVFWD
+%
+
+%	Copyright (c) Ian T Nabney (1996-2001)
+
+clc;
+disp('This demonstration illustrates the application of Bayesian')
+disp('re-estimation to determine the hyperparameters in a simple regression')
+disp('problem using an RBF netowk. It is based on a the fact that the')
+disp('posterior distribution for the output weights of an RBF is Gaussian')
+disp('and uses the evidence maximization framework of MacKay.')
+disp(' ')
+disp('First, we generate a synthetic data set consisting of a single input')
+disp('variable x sampled from a Gaussian distribution, and a target variable')
+disp('t obtained by evaluating sin(2*pi*x) and adding Gaussian noise.')
+disp(' ')
+disp('Press any key to see a plot of the data together with the sine function.')
+pause;
+
+% Generate the matrix of inputs x and targets t.
+
+ndata = 16;			% Number of data points.
+noise = 0.1;			% Standard deviation of noise distribution.
+randn('state', 0);
+rand('state', 0);
+x = 0.25 + 0.07*randn(ndata, 1);
+t = sin(2*pi*x) + noise*randn(size(x));
+
+% Plot the data and the original sine function.
+h = figure;
+nplot = 200;
+plotvals = linspace(0, 1, nplot)';
+plot(x, t, 'ok')
+xlabel('Input')
+ylabel('Target')
+hold on
+axis([0 1 -1.5 1.5])
+fplot('sin(2*pi*x)', [0 1], '-g')
+legend('data', 'function');
+
+disp(' ')
+disp('Press any key to continue')
+pause; clc;
+
+disp('Next we create a two-layer MLP network having 3 hidden units and one')
+disp('linear output. The model assumes Gaussian target noise governed by an')
+disp('inverse variance hyperparmeter beta, and uses a simple Gaussian prior')
+disp('distribution governed by an inverse variance hyperparameter alpha.')
+disp(' ');
+disp('The network weights and the hyperparameters are initialised and then')
+disp('the output layer weights are optimized with the scaled conjugate gradient')
+disp('algorithm using the SCG function, with the hyperparameters kept')
+disp('fixed. After a maximum of 50 iterations, the hyperparameters are')
+disp('re-estimated using the EVIDENCE function. The process of optimizing')
+disp('the weights with fixed hyperparameters and then re-estimating the')
+disp('hyperparameters is repeated for a total of 3 cycles.')
+disp(' ')
+disp('Press any key to train the network and determine the hyperparameters.')
+pause;
+
+% Set up network parameters.
+nin = 1;		% Number of inputs.
+nhidden = 3;		% Number of hidden units.
+nout = 1;		% Number of outputs.
+alpha = 0.01;		% Initial prior hyperparameter. 
+beta_init = 50.0;	% Initial noise hyperparameter.
+
+% Create and initialize network weight vector.
+net = rbf(nin, nhidden, nout, 'tps', 'linear', alpha, beta_init);
+[net.mask, prior] = rbfprior('tps', nin, nhidden, nout, alpha, alpha);
+net = netinit(net, prior);
+
+options = foptions;
+options(14) = 5;  % At most 5 EM iterations for basis functions
+options(1) = -1;  % Turn off all messages
+net = rbfsetbf(net, options, x);  % Initialise the basis functions
+
+% Now train the network
+nouter = 5;
+ninner = 2;
+options = foptions;
+options(1) = 1;
+options(2) = 1.0e-5;		% Absolute precision for weights.
+options(3) = 1.0e-5;		% Precision for objective function.
+options(14) = 50;		% Number of training cycles in inner loop. 
+
+% Train using scaled conjugate gradients, re-estimating alpha and beta.
+for k = 1:nouter
+  net = netopt(net, options, x, t, 'scg');
+  [net, gamma] = evidence(net, x, t, ninner);
+  fprintf(1, '\nRe-estimation cycle %d:\n', k);
+  fprintf(1, '  alpha =  %8.5f\n', net.alpha);
+  fprintf(1, '  beta  =  %8.5f\n', net.beta);
+  fprintf(1, '  gamma =  %8.5f\n\n', gamma);
+  disp(' ')
+  disp('Press any key to continue.')
+  pause;
+end
+
+fprintf(1, 'true beta: %f\n', 1/(noise*noise));
+
+disp(' ')
+disp('Network training and hyperparameter re-estimation are now complete.') 
+disp('Compare the final value for the hyperparameter beta with the true') 
+disp('value.')
+disp(' ')
+disp('Notice that the final error value is close to the number of data')
+disp(['points (', num2str(ndata),') divided by two.'])
+disp(' ')
+disp('Press any key to continue.')
+pause; clc;
+disp('We can now plot the function represented by the trained network. This')
+disp('corresponds to the mean of the predictive distribution. We can also')
+disp('plot ''error bars'' representing one standard deviation of the')
+disp('predictive distribution around the mean.')
+disp(' ')
+disp('Press any key to add the network function and error bars to the plot.')
+pause;
+
+% Evaluate error bars.
+[y, sig2] = netevfwd(netpak(net), net, x, t, plotvals);
+sig = sqrt(sig2);
+
+% Plot the data, the original function, and the trained network function.
+[y, z] = rbffwd(net, plotvals);
+figure(h); hold on;
+plot(plotvals, y, '-r')
+xlabel('Input')
+ylabel('Target')
+plot(plotvals, y + sig, '-b');
+plot(plotvals, y - sig, '-b');
+legend('data', 'function', 'network', 'error bars');
+
+disp(' ')
+disp('Notice how the confidence interval spanned by the ''error bars'' is')
+disp('smaller in the region of input space where the data density is high,')
+disp('and becomes larger in regions away from the data.')
+disp(' ')
+disp('Press any key to end.')
+pause; clc; close(h); 
+