%DEMARD Automatic relevance determination using the MLP.
%
% Description
% This script demonstrates the technique of automatic relevance
% determination (ARD) using a synthetic problem having three input
% variables: X1 is sampled uniformly from the range (0,1) and has a low
% level of added Gaussian noise, X2 is a copy of X1 with a higher level
% of added noise, and X3 is sampled randomly from a Gaussian
% distribution. The single target variable is determined by
% SIN(2*PI*X1) with additive Gaussian noise. Thus X1 is very relevant
% for determining the target value, X2 is of some relevance, while X3
% is irrelevant. The prior over weights is given by the ARD Gaussian
% prior with a separate hyper-parameter for the group of weights
% associated with each input. A multi-layer perceptron is trained on
% this data, with re-estimation of the hyper-parameters using EVIDENCE.
% The final values for the hyper-parameters reflect the relative
% importance of the three inputs.
%
% See also
% DEMMLP1, DEMEV1, MLP, EVIDENCE
%

% Copyright (c) Ian T Nabney (1996-2001)

clc;
disp('This demonstration illustrates the technique of automatic relevance')
disp('determination (ARD) using a multi-layer perceptron.')
disp(' ');
disp('First, we set up a synthetic data set involving three input variables:')
disp('x1 is sampled uniformly from the range (0,1) and has a low level of')
disp('added Gaussian noise, x2 is a copy of x1 with a higher level of added')
disp('noise, and x3 is sampled randomly from a Gaussian distribution. The')
disp('single target variable is given by t = sin(2*pi*x1) with additive')
disp('Gaussian noise. Thus x1 is very relevant for determining the target')
disp('value, x2 is of some relevance, while x3 should in principle be')
disp('irrelevant.')
disp(' ');
disp('Press any key to see a plot of t against x1.')
pause;

% Generate the data set.
randn('state', 0);
rand('state', 0);
ndata = 100;
noise = 0.05;
x1 = rand(ndata, 1) + 0.002*randn(ndata, 1);
x2 = x1 + 0.02*randn(ndata, 1);
x3 = 0.5 + 0.2*randn(ndata, 1);
x = [x1, x2, x3];
t = sin(2*pi*x1) + noise*randn(ndata, 1);

% Plot the data and the original function.
h = figure;
plotvals = linspace(0, 1, 200)';
plot(x1, t, 'ob')
hold on
axis([0 1 -1.5 1.5])
[fx, fy] = fplot('sin(2*pi*x)', [0 1]);
plot(fx, fy, '-g', 'LineWidth', 2);
legend('data', 'function');

disp(' ');
disp('Press any key to continue')
pause; clc;

disp('The prior over weights is given by the ARD Gaussian prior with a')
disp('separate hyper-parameter for the group of weights associated with each')
disp('input. This prior is set up using the utility MLPPRIOR. The network is')
disp('trained by error minimization using the scaled conjugate gradient')
disp('function SCG. There are two cycles of training, and at the end of each')
disp('cycle the hyper-parameters are re-estimated using EVIDENCE.')
disp(' ');
disp('Press any key to create and train the network.')
disp(' ');
pause;

% Set up network parameters.
nin = 3;                   % Number of inputs.
nhidden = 2;               % Number of hidden units.
nout = 1;                  % Number of outputs.
aw1 = 0.01*ones(1, nin);   % First-layer ARD hyperparameters.
ab1 = 0.01;                % Hyperparameter for hidden unit biases.
aw2 = 0.01;                % Hyperparameter for second-layer weights.
ab2 = 0.01;                % Hyperparameter for output unit biases.
beta = 50.0;               % Coefficient of data error.

% Create and initialize network.
prior = mlpprior(nin, nhidden, nout, aw1, ab1, aw2, ab2);
net = mlp(nin, nhidden, nout, 'linear', prior, beta);

% Set up vector of options for the optimiser.
nouter = 2;                % Number of outer loops.
ninner = 10;               % Number of inner loops.
options = zeros(1,18);     % Default options vector.
options(1) = 1;            % This provides display of error values.
options(2) = 1.0e-7;       % This ensures that convergence must occur.
options(3) = 1.0e-7;
options(14) = 300;         % Number of training cycles in inner loop.

% Train using scaled conjugate gradients, re-estimating alpha and beta.
for k = 1:nouter
  net = netopt(net, options, x, t, 'scg');
  [net, gamma] = evidence(net, x, t, ninner);
  fprintf(1, '\n\nRe-estimation cycle %d:\n', k);
  disp('The first three alphas are the hyperparameters for the corresponding')
  disp('input to hidden unit weights. The remainder are the hyperparameters')
  disp('for the hidden unit biases, second layer weights and output unit')
  disp('biases, respectively.')
  fprintf(1, '  alpha =  %8.5f\n', net.alpha);
  fprintf(1, '  beta  =  %8.5f\n', net.beta);
  fprintf(1, '  gamma =  %8.5f\n\n', gamma);
  disp(' ')
  disp('Press any key to continue.')
  pause
end

% Plot the function corresponding to the trained network.
figure(h); hold on;
[y, z] = mlpfwd(net, plotvals*ones(1,3));
plot(plotvals, y, '-r', 'LineWidth', 2)
legend('data', 'function', 'network');

disp('Press any key to continue.');
pause; clc;

disp('We can now read off the hyperparameter values corresponding to the')
disp('three inputs x1, x2 and x3:')
disp(' ');
fprintf(1, '    alpha1: %8.5f\n', net.alpha(1));
fprintf(1, '    alpha2: %8.5f\n', net.alpha(2));
fprintf(1, '    alpha3: %8.5f\n', net.alpha(3));
disp(' ');
disp('Since each alpha corresponds to an inverse variance, we see that the')
disp('posterior variance for weights associated with input x1 is large, that')
disp('of x2 has an intermediate value and the variance of weights associated')
disp('with x3 is small.')
disp(' ')
disp('Press any key to continue.')
disp(' ')
pause
|
wolffd@0
|
141 disp('This is confirmed by looking at the corresponding weight values:')
|
wolffd@0
|
142 disp(' ');
|
wolffd@0
|
143 fprintf(1, ' %8.5f %8.5f\n', net.w1');
|
wolffd@0
|
144 disp(' ');
|
wolffd@0
|
145 disp('where the three rows correspond to weights asssociated with x1, x2 and')
|
wolffd@0
|
146 disp('x3 respectively. We see that the network is giving greatest emphasis')
|
wolffd@0
|
147 disp('to x1 and least emphasis to x3, with intermediate emphasis on')
|
wolffd@0
|
148 disp('x2. Since the target t is statistically independent of x3 we might')
|
wolffd@0
|
149 disp('expect the weights associated with this input would go to')
|
wolffd@0
|
150 disp('zero. However, for any finite data set there may be some chance')
|
wolffd@0
|
151 disp('correlation between x3 and t, and so the corresponding alpha remains')
|
wolffd@0
|
152 disp('finite.')
|
wolffd@0
|
153
|
wolffd@0
|
154 disp(' ');
|
wolffd@0
|
155 disp('Press any key to end.')
|
wolffd@0
|
156 pause; clc; close(h); clear all
|
wolffd@0
|
157
|