toolboxes/FullBNT-1.0.7/netlab3.3/demard.m @ 0:e9a9cd732c1e tip

first hg version after svn
author: wolffd
date:   Tue, 10 Feb 2015 15:05:51 +0000
%DEMARD Automatic relevance determination using the MLP.
%
% Description
% This script demonstrates the technique of automatic relevance
% determination (ARD) using a synthetic problem having three input
% variables: X1 is sampled uniformly from the range (0,1) and has a low
% level of added Gaussian noise, X2 is a copy of X1 with a higher level
% of added noise, and X3 is sampled randomly from a Gaussian
% distribution. The single target variable is determined by
% SIN(2*PI*X1) with additive Gaussian noise. Thus X1 is very relevant
% for determining the target value, X2 is of some relevance, while X3
% is irrelevant. The prior over weights is given by the ARD Gaussian
% prior with a separate hyper-parameter for the group of weights
% associated with each input. A multi-layer perceptron is trained on
% this data, with re-estimation of the hyper-parameters using EVIDENCE.
% The final values for the hyper-parameters reflect the relative
% importance of the three inputs.
%
% See also
% DEMMLP1, DEMEV1, MLP, EVIDENCE
%

% Copyright (c) Ian T Nabney (1996-2001)

clc;
disp('This demonstration illustrates the technique of automatic relevance')
disp('determination (ARD) using a multi-layer perceptron.')
disp(' ');
disp('First, we set up a synthetic data set involving three input variables:')
disp('x1 is sampled uniformly from the range (0,1) and has a low level of')
disp('added Gaussian noise, x2 is a copy of x1 with a higher level of added')
disp('noise, and x3 is sampled randomly from a Gaussian distribution. The')
disp('single target variable is given by t = sin(2*pi*x1) with additive')
disp('Gaussian noise. Thus x1 is very relevant for determining the target')
disp('value, x2 is of some relevance, while x3 should in principle be')
disp('irrelevant.')
disp(' ');
disp('Press any key to see a plot of t against x1.')
pause;

% Generate the data set.
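% Note: rand('state', ...) and randn('state', ...) are the legacy
% (pre-rng) seeding interface; fixing both generator states makes the
% synthetic data set reproducible from run to run.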
randn('state', 0);
rand('state', 0);
ndata = 100;
noise = 0.05;
x1 = rand(ndata, 1) + 0.002*randn(ndata, 1);
x2 = x1 + 0.02*randn(ndata, 1);
x3 = 0.5 + 0.2*randn(ndata, 1);
x = [x1, x2, x3];
t = sin(2*pi*x1) + noise*randn(ndata, 1);

% Plot the data and the original function.
h = figure;
plotvals = linspace(0, 1, 200)';
plot(x1, t, 'ob')
hold on
axis([0 1 -1.5 1.5])
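% The two-output form of fplot evaluates the function over the interval
% and returns the sampled points without drawing them (legacy syntax,
% removed in recent MATLAB releases).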
[fx, fy] = fplot('sin(2*pi*x)', [0 1]);
plot(fx, fy, '-g', 'LineWidth', 2);
legend('data', 'function');

disp(' ');
disp('Press any key to continue')
pause; clc;

disp('The prior over weights is given by the ARD Gaussian prior with a')
disp('separate hyper-parameter for the group of weights associated with each')
disp('input. This prior is set up using the utility MLPPRIOR. The network is')
disp('trained by error minimization using the scaled conjugate gradient')
disp('function SCG. There are two cycles of training, and at the end of each')
disp('cycle the hyper-parameters are re-estimated using EVIDENCE.')
disp(' ');
disp('Press any key to create and train the network.')
disp(' ');
pause;

% Set up network parameters.
nin = 3;                  % Number of inputs.
nhidden = 2;              % Number of hidden units.
nout = 1;                 % Number of outputs.
aw1 = 0.01*ones(1, nin);  % First-layer ARD hyperparameters.
ab1 = 0.01;               % Hyperparameter for hidden unit biases.
aw2 = 0.01;               % Hyperparameter for second-layer weights.
ab2 = 0.01;               % Hyperparameter for output unit biases.
beta = 50.0;              % Coefficient of data error.

% Create and initialize network.
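% mlpprior packs the hyperparameters into a prior structure: an alpha
% vector stacking aw1, ab1, aw2 and ab2, together with an index matrix
% recording which group of weights each alpha governs, giving the ARD
% prior p(w) proportional to exp(-0.5*sum_k alpha_k*||w_k||^2).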
prior = mlpprior(nin, nhidden, nout, aw1, ab1, aw2, ab2);
net = mlp(nin, nhidden, nout, 'linear', prior, beta);

% Set up vector of options for the optimiser.
nouter = 2;               % Number of outer loops (evidence re-estimations).
ninner = 10;              % Number of hyperparameter updates per outer loop.
options = zeros(1,18);    % Default options vector.
options(1) = 1;           % Display error values during training.
options(2) = 1.0e-7;      % Precision required of the weights for convergence.
options(3) = 1.0e-7;      % Precision required of the error for convergence.
options(14) = 300;        % Number of training cycles in inner loop.

% Train using scaled conjugate gradients, re-estimating alpha and beta.
for k = 1:nouter
  net = netopt(net, options, x, t, 'scg');
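  % evidence re-estimates alpha and beta ninner times under the Gaussian
  % posterior approximation; gamma is the effective number of
  % well-determined parameters in MacKay's evidence framework.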
  [net, gamma] = evidence(net, x, t, ninner);
  fprintf(1, '\n\nRe-estimation cycle %d:\n', k);
  disp('The first three alphas are the hyperparameters for the corresponding');
  disp('input to hidden unit weights. The remainder are the hyperparameters');
  disp('for the hidden unit biases, second layer weights and output unit')
  disp('biases, respectively.')
  fprintf(1, '  alpha = %8.5f\n', net.alpha);
  fprintf(1, '  beta  = %8.5f\n', net.beta);
  fprintf(1, '  gamma = %8.5f\n\n', gamma);
  disp(' ')
  disp('Press any key to continue.')
  pause
end

% Plot the function corresponding to the trained network.
figure(h); hold on;
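% mlpfwd returns the network outputs y (z holds the hidden unit
% activations). The same ramp of values is fed to all three inputs, so
% the curve shows the learned mapping along the line x1 = x2 = x3.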
[y, z] = mlpfwd(net, plotvals*ones(1,3));
plot(plotvals, y, '-r', 'LineWidth', 2)
legend('data', 'function', 'network');

disp('Press any key to continue.');
pause; clc;

disp('We can now read off the hyperparameter values corresponding to the')
disp('three inputs x1, x2 and x3:')
disp(' ');
fprintf(1, '    alpha1: %8.5f\n', net.alpha(1));
fprintf(1, '    alpha2: %8.5f\n', net.alpha(2));
fprintf(1, '    alpha3: %8.5f\n', net.alpha(3));
disp(' ');
disp('Since each alpha corresponds to an inverse variance, we see that the')
disp('posterior variance for weights associated with input x1 is large, that')
disp('of x2 has an intermediate value and the variance of weights associated')
disp('with x3 is small.')
disp(' ')
disp('Press any key to continue.')
disp(' ')
pause
disp('This is confirmed by looking at the corresponding weight values:')
disp(' ');
fprintf(1, '    %8.5f    %8.5f\n', net.w1');
disp(' ');
disp('where the three rows correspond to weights associated with x1, x2 and')
disp('x3 respectively. We see that the network is giving greatest emphasis')
disp('to x1 and least emphasis to x3, with intermediate emphasis on x2.')
disp('Since the target t is statistically independent of x3, we might expect')
disp('the weights associated with this input to go to zero. However, for any')
disp('finite data set there may be some chance correlation between x3 and t,')
disp('and so the corresponding alpha remains finite.')

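% A minimal extra sketch (not part of the original demo): a smaller alpha
% means a broader prior and hence a more relevant input, so the inputs
% can be ranked directly from the first three hyperparameters.
[~, relevance_order] = sort(net.alpha(1:3));
fprintf(1, 'Inputs ranked from most to least relevant: x%d x%d x%d\n', ...
    relevance_order);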
disp(' ');
disp('Press any key to end.')
pause; clc; close(h); clear all