toolboxes/FullBNT-1.0.7/netlab3.3/demev3.m @ 0:e9a9cd732c1e (tip)
first hg version after svn
author: wolffd
date:   Tue, 10 Feb 2015 15:05:51 +0000
%DEMEV3 Demonstrate Bayesian regression for the RBF.
%
% Description
% The problem consists of an input variable X sampled from a
% Gaussian distribution, and a target variable T generated by computing
% SIN(2*PI*X) and adding Gaussian noise. An RBF network with linear
% outputs is trained by minimizing a sum-of-squares error function with
% an isotropic Gaussian regularizer, using the scaled conjugate gradient
% optimizer. The hyperparameters ALPHA and BETA are re-estimated using
% the function EVIDENCE. A graph is plotted of the original function,
% the training data, the trained network function, and the error bars.
%
% See also
% DEMEV1, EVIDENCE, RBF, SCG, NETEVFWD
%

% Copyright (c) Ian T Nabney (1996-2001)

clc;
disp('This demonstration illustrates the application of Bayesian')
disp('re-estimation to determine the hyperparameters in a simple regression')
disp('problem using an RBF network. It is based on the fact that the')
disp('posterior distribution for the output weights of an RBF is Gaussian')
disp('and uses the evidence maximization framework of MacKay.')
disp(' ')
disp('First, we generate a synthetic data set consisting of a single input')
disp('variable x sampled from a Gaussian distribution, and a target variable')
disp('t obtained by evaluating sin(2*pi*x) and adding Gaussian noise.')
disp(' ')
disp('Press any key to see a plot of the data together with the sine function.')
pause;
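
% Background note: because the RBF output is linear in the second-layer
% weights w, a Gaussian prior with inverse variance alpha and Gaussian
% noise with inverse variance beta give a Gaussian posterior over w. With
% Phi the matrix of basis-function activations, the standard
% linear-in-parameters result that EVIDENCE relies on is
%
%   A    = alpha*I + beta*Phi'*Phi      (posterior inverse covariance)
%   w_MP = beta*(A \ (Phi'*t))          (posterior mean)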

% Generate the matrix of inputs x and targets t.

ndata = 16;        % Number of data points.
noise = 0.1;       % Standard deviation of noise distribution.
randn('state', 0);
rand('state', 0);
x = 0.25 + 0.07*randn(ndata, 1);
t = sin(2*pi*x) + noise*randn(size(x));

% Plot the data and the original sine function.
h = figure;
nplot = 200;
plotvals = linspace(0, 1, nplot)';
plot(x, t, 'ok')
xlabel('Input')
ylabel('Target')
hold on
axis([0 1 -1.5 1.5])
fplot('sin(2*pi*x)', [0 1], '-g')
legend('data', 'function');

disp(' ')
disp('Press any key to continue')
pause; clc;

disp('Next we create an RBF network having 3 hidden units and one')
disp('linear output. The model assumes Gaussian target noise governed by an')
disp('inverse variance hyperparameter beta, and uses a simple Gaussian prior')
disp('distribution governed by an inverse variance hyperparameter alpha.')
disp(' ');
disp('The network weights and the hyperparameters are initialised and then')
disp('the output layer weights are optimized with the scaled conjugate gradient')
disp('algorithm using the SCG function, with the hyperparameters kept')
disp('fixed. After a maximum of 50 iterations, the hyperparameters are')
disp('re-estimated using the EVIDENCE function. The process of optimizing')
disp('the weights with fixed hyperparameters and then re-estimating the')
disp('hyperparameters is repeated for a total of 5 cycles.')
disp(' ')
disp('Press any key to train the network and determine the hyperparameters.')
pause;
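
% For reference, EVIDENCE implements MacKay's fixed-point updates: with
% eigenvalues lambda_i of the data Hessian (beta times the Hessian of the
% unregularized error), the number of well-determined parameters is
%
%   gamma = sum_i lambda_i / (lambda_i + alpha)
%
% and the hyperparameters are re-estimated as
%
%   alpha_new = gamma / (2*Ew),   beta_new = (N - gamma) / (2*Ed)
%
% where Ew = 0.5*w'*w, Ed is the unregularized sum-of-squares error, and
% N is the number of data points.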

% Set up network parameters.
nin = 1;             % Number of inputs.
nhidden = 3;         % Number of hidden units.
nout = 1;            % Number of outputs.
alpha = 0.01;        % Initial prior hyperparameter.
beta_init = 50.0;    % Initial noise hyperparameter.

% Create and initialize network weight vector.
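% (RBFPRIOR returns both a Gaussian prior and a mask: for the 'tps' basis
% the Bayesian treatment applies to the output-layer weights and biases
% only, while the basis functions themselves are set from the data below.)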
net = rbf(nin, nhidden, nout, 'tps', 'linear', alpha, beta_init);
[net.mask, prior] = rbfprior('tps', nin, nhidden, nout, alpha, alpha);
net = netinit(net, prior);

options = foptions;
options(14) = 5;     % At most 5 EM iterations for basis functions
options(1) = -1;     % Turn off all messages
net = rbfsetbf(net, options, x);   % Initialise the basis functions
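% (RBFSETBF positions the basis-function centres by fitting a Gaussian
% mixture model to the inputs with EM; options(14) above caps the number
% of EM steps used.)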

% Now train the network
nouter = 5;
ninner = 2;
options = foptions;
options(1) = 1;
options(2) = 1.0e-5;   % Absolute precision for weights.
options(3) = 1.0e-5;   % Precision for objective function.
options(14) = 50;      % Number of training cycles in inner loop.

% Train using scaled conjugate gradients, re-estimating alpha and beta.
for k = 1:nouter
  net = netopt(net, options, x, t, 'scg');
  [net, gamma] = evidence(net, x, t, ninner);
  fprintf(1, '\nRe-estimation cycle %d:\n', k);
  fprintf(1, '  alpha = %8.5f\n', net.alpha);
  fprintf(1, '  beta  = %8.5f\n', net.beta);
  fprintf(1, '  gamma = %8.5f\n\n', gamma);
  disp(' ')
  disp('Press any key to continue.')
  pause;
end

fprintf(1, 'true beta: %f\n', 1/(noise*noise));

disp(' ')
disp('Network training and hyperparameter re-estimation are now complete.')
disp('Compare the final value for the hyperparameter beta with the true')
disp('value.')
disp(' ')
disp('Notice that the final error value is close to the number of data')
disp(['points (', num2str(ndata),') divided by two.'])
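
% A quick numerical check of that claim (a sketch; NETERR evaluates the
% regularized error alpha*Ew + beta*Ed for the current weights, which the
% evidence conditions drive towards ndata/2):
fprintf(1, 'final error = %f (ndata/2 = %f)\n', ...
    neterr(netpak(net), net, x, t), ndata/2);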
disp(' ')
disp('Press any key to continue.')
pause; clc;
disp('We can now plot the function represented by the trained network. This')
disp('corresponds to the mean of the predictive distribution. We can also')
disp('plot ''error bars'' representing one standard deviation of the')
disp('predictive distribution around the mean.')
disp(' ')
disp('Press any key to add the network function and error bars to the plot.')
pause;

% Evaluate error bars.
[y, sig2] = netevfwd(netpak(net), net, x, t, plotvals);
sig = sqrt(sig2);
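
% The variance from NETEVFWD combines the output noise with the weight
% uncertainty (roughly 1/beta plus a term from the posterior covariance),
% so the predicted standard deviation should never drop below the noise
% floor; a quick sanity check:
fprintf(1, 'min predicted std = %f, noise floor 1/sqrt(beta) = %f\n', ...
    min(sig), 1/sqrt(net.beta));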

% Plot the data, the original function, and the trained network function.
[y, z] = rbffwd(net, plotvals);
figure(h); hold on;
plot(plotvals, y, '-r')
xlabel('Input')
ylabel('Target')
plot(plotvals, y + sig, '-b');
plot(plotvals, y - sig, '-b');
legend('data', 'function', 'network', 'error bars');

disp(' ')
disp('Notice how the confidence interval spanned by the ''error bars'' is')
disp('smaller in the region of input space where the data density is high,')
disp('and becomes larger in regions away from the data.')
disp(' ')
disp('Press any key to end.')
pause; clc; close(h);