%DEMHMC2 Demonstrate Bayesian regression with Hybrid Monte Carlo sampling.
%
% Description
% The problem consists of one input variable X and one target variable
% T with data generated by sampling X from a Gaussian distribution
% (mean 0.25, standard deviation 0.1) and then generating target data
% by computing SIN(2*PI*X) and adding Gaussian noise. The model is a
% 2-layer network with linear outputs, and the hybrid Monte Carlo
% algorithm (without persistence) is used to sample from the posterior
% distribution of the weights. The graph shows the underlying function,
% 100 samples from the function given by the posterior distribution of
% the weights, and their average prediction. Because the samples are
% drawn from the posterior, this plain average is a Monte Carlo estimate
% of the posterior-weighted average prediction.
%
% See also
% DEMHMC3, HMC, MLP, MLPERR, MLPGRAD
%

% Copyright (c) Ian T Nabney (1996-2001)

% Generate the matrix of inputs x and targets t.
ndata = 20;     % Number of data points.
noise = 0.1;    % Standard deviation of the additive target noise.
nin = 1;        % Number of inputs.
nout = 1;       % Number of outputs.

% Fix the state of both legacy generators. This makes the generated data
% reproducible, as well as any later RAND/RANDN draws (e.g. the random
% weight initialisation).
seed = 42;
randn('state', seed);
rand('state', seed);

% Inputs are Gaussian: mean 0.25, standard deviation 0.1 (variance 0.01).
x = 0.25 + 0.1*randn(ndata, nin);

% Targets are the noise-free sine plus Gaussian noise with std NOISE.
t = sin(2*pi*x) + noise*randn(size(x));

% Introductory explanation; waits for a key press before continuing.
clc
fprintf('%s\n', ...
    'This demonstration illustrates the use of the hybrid Monte Carlo', ...
    'algorithm to sample from the posterior weight distribution of a', ...
    'multi-layer perceptron.', ...
    ' ', ...
    'A regression problem is used, with the one-dimensional data drawn', ...
    'from a noisy sine function. The x values are sampled from a normal', ...
    'distribution with mean 0.25 and variance 0.01.', ...
    ' ', ...
    'First we initialise the network.', ...
    ' ', ...
    'Press any key to continue.');
pause

% Network hyperparameters.
nhidden = 5;      % Number of hidden units.
alpha = 0.001;    % Coefficient of weight-decay prior.
beta = 100.0;     % Coefficient of data error.

% Build a 2-layer MLP with linear outputs and immediately draw its
% initial weights with MLPINIT (prior argument 10). The intent is to
% start the weights reasonably close to 0.
% NOTE(review): assumes MLPINIT treats a scalar prior as an inverse
% variance for a zero-mean Gaussian; confirm against mlpinit.m.
net = mlpinit(mlp(nin, nhidden, nout, 'linear', alpha, beta), 10);

% Explain the sampling run that follows; waits for a key press.
clc
fprintf('%s\n', ...
    'Next we take 100 samples from the posterior distribution. The first', ...
    '200 samples at the start of the chain are omitted. As persistence', ...
    'is not used, the momentum is randomised at each step. 100 iterations', ...
    'are used at each step. The new state is accepted if the threshold', ...
    'value is greater than a random number between 0 and 1.', ...
    ' ', ...
    'Negative step numbers indicate samples discarded from the start of the', ...
    'chain.', ...
    ' ', ...
    'Press any key to continue.');
pause
nsamples = 100;   % Number of retained samples.

% Pack the initial network weights into a single vector for the sampler.
w = mlppak(net);

% Options for hybrid Monte Carlo: start from the FOPTIONS defaults, then
% override these entries:
%   options(1)  = 1         switch on diagnostics
%   options(7)  = 100       number of steps in trajectory
%   options(14) = nsamples  number of Monte Carlo samples returned
%   options(15) = 200       number of samples omitted at start of chain
%   options(18) = 0.002     step size
options = foptions;
options([1 7 14 15 18]) = [1 100 nsamples 200 0.002];

% Initialise the HMC sampler's state, then draw the samples.
hmc('state', 42);
[samples, energies] = hmc('neterr', w, options, 'netgrad', net, x, t);

% Describe the upcoming plot; waits for a key press.
clc
fprintf('%s\n', ...
    'The plot shows the underlying noise free function, the 100 samples', ...
    'produced from the MLP, and their average as a Monte Carlo estimate', ...
    'of the true posterior average.', ...
    ' ', ...
    'Press any key to continue.');
pause
% Evaluate each retained weight sample on a uniform grid over [0, 1].
% Each sample's prediction is drawn, and the predictions are summed to
% form their mean: a Monte Carlo estimate of the posterior average.
nplot = 300;
plotvals = linspace(0, 1, nplot)';   % Column of exactly NPLOT grid points.
pred = zeros(size(plotvals));
fh = figure;
for k = 1:nsamples
  % Row k of SAMPLES is one weight vector; unpack it into a network.
  net2 = mlpunpak(net, samples(k, :));
  y = mlpfwd(net2, plotvals);
  % Accumulate for the Monte Carlo estimate of the true integral.
  pred = pred + y;
  % H4 ends up holding the handle of the last sample curve (used in the legend).
  h4 = plot(plotvals, y, '-r', 'LineWidth', 1);
  if k == 1
    % Keep the axes after the first curve so later samples overlay it.
    hold on
  end
end
pred = pred ./ nsamples;

% Plot the training data.
h1 = plot(x, t, 'ob', 'LineWidth', 2, 'MarkerFaceColor', 'blue');
axis([0 1 -3 3])

% Plot the underlying noise-free function. It is evaluated directly on
% the plotting grid. The previous FPLOT call relied on character-vector
% function input and [X, Y] output arguments, both deprecated in recent
% MATLAB releases.
fx = plotvals;
fy = sin(2*pi*fx);
h2 = plot(fx, fy, '--g', 'LineWidth', 2);
set(gca, 'box', 'on');

% Plot the averaged (Monte Carlo) prediction.
h3 = plot(plotvals, pred, '-c', 'LineWidth', 2);
hold off

% Numeric legend position 3 (lower-left corner) is obsolete; use the
% equivalent named location.
lstrings = {'Data', 'Function', 'Prediction', 'Samples'};
legend([h1 h2 h3 h4], lstrings, 'Location', 'SouthWest');

% Closing remark. After a key press, close the figure and clear everything
% the demo created. CLEAR ALL also resets globals, not just local variables.
fprintf('%s\n', ...
    'Note how the predictions become much further from the true function', ...
    'away from the region of high data density.', ...
    ' ', ...
    'Press any key to exit.');
pause
close(fh);
clear all;