%DEMRBF1 Demonstrate simple regression using a radial basis function network.
%
%	Description
%	The problem consists of one input variable X and one target variable
%	T with data generated by sampling X at equal intervals and then
%	generating target data by computing SIN(2*PI*X) and adding Gaussian
%	noise. This data is the same as that used in demmlp1.
%
%	Three different RBF networks (with different activation functions)
%	are trained in two stages. First, a Gaussian mixture model is trained
%	using the EM algorithm, and the centres of this model are used to set
%	the centres of the RBF. Second, the output weights (and biases) are
%	determined using the pseudo-inverse of the design matrix.
%
%	See also
%	DEMMLP1, RBF, RBFFWD, GMM, GMMEM
%

%	Copyright (c) Ian T Nabney (1996-2001)


% Generate the matrix of inputs x and targets t.
% Both generators are seeded so that the demo produces the same data on
% every run.
randn('state', 42);
rand('state', 42);
ndata = 20;			% Number of data points.
noise = 0.2;			% Standard deviation of noise distribution.
x = (linspace(0, 1, ndata))';
t = sin(2*pi*x) + noise*randn(ndata, 1);

% Standardise the inputs (zero mean, unit standard deviation). mu and sigma
% are kept so that the plotting grid further down can be transformed in
% exactly the same way before it is passed to the networks.
mu = mean(x);
sigma = std(x);
tr_in = (x - mu)./(sigma);

clc
disp('This demonstration illustrates the use of a Radial Basis Function')
disp('network for regression problems.  The data is generated from a noisy')
disp('sine function.')
disp(' ')
disp('Press any key to continue.')
pause

% Set up network parameters.
nin = 1;			% Number of inputs.
nhidden = 7;			% Number of hidden units.
nout = 1;			% Number of outputs.

clc
disp('We assess the effect of three different activation functions.')
disp('First we create a network with Gaussian activations.')
disp(' ')
disp('Press any key to continue.')
pause

% Create and initialize network weight and parameter vectors.
net = rbf(nin, nhidden, nout, 'gaussian');

disp('A two-stage training algorithm is used: it uses a small number of')
disp('iterations of EM to position the centres, and then the pseudo-inverse')
disp('of the design matrix to find the second layer weights.')
disp(' ')
disp('Press any key to continue.')
pause
disp('Error values from EM training.')

% Use fast training method
options = foptions;
options(1) = 1;		% Display EM training
options(14) = 10;	% number of iterations of EM
net = rbftrain(net, options, tr_in, t);

disp(' ')
disp('Press any key to continue.')
pause
clc
disp('The second RBF network has thin plate spline activations.')
disp('The same centres are used again, so we just need to calculate')
disp('the second layer weights.')
disp(' ')
disp('Press any key to continue.')
pause

% Create a second RBF with thin plate spline functions
net2 = rbf(nin, nhidden, nout, 'tps');

% Re-use previous centres rather than calling rbftrain again
net2.c = net.c;
% Only the hidden-unit activations (second output) are needed here; the
% network output y is recomputed later once the output layer has been set.
[y, act2] = rbffwd(net2, tr_in);

% Solve for new output weights and biases from RBF activations.
% The column of ones appended to the activations is the bias term, so the
% last row of the least-squares solution is the bias and the first nhidden
% rows are the output weights.
temp = pinv([act2 ones(ndata, 1)]) * t;
net2.w2 = temp(1:nhidden, :);
net2.b2 = temp(nhidden+1, :);

disp('The third RBF network has r^4 log r activations.')
disp(' ')
disp('Press any key to continue.')
pause

% Create a third RBF with r^4 log r functions
net3 = rbf(nin, nhidden, nout, 'r4logr');

% Re-use the centres from the first RBF, then solve for the output layer
% exactly as was done for the thin plate spline network.
net3.c = net.c;
[y, act3] = rbffwd(net3, tr_in);
temp = pinv([act3 ones(ndata, 1)]) * t;
net3.w2 = temp(1:nhidden, :);
net3.b2 = temp(nhidden+1, :);

disp('Now we plot the data, underlying function, and network outputs')
disp('on a single graph to compare the results.')
disp(' ')
disp('Press any key to continue.')
pause

% Plot the data, the original function, and the trained network functions.
% The plotting grid is in the original input units, so it is standardised
% with the training mu and sigma before being fed to the networks.
plotvals = (x(1):0.01:x(end))';
inputvals = (plotvals-mu)./sigma;
y = rbffwd(net, inputvals);
y2 = rbffwd(net2, inputvals);
y3 = rbffwd(net3, inputvals);
fh1 = figure;

plot(x, t, 'ob')
hold on
xlabel('Input')
ylabel('Target')
axis([x(1) x(end) -1.5 1.5])
% Sample the noise-free generating function directly. The previous form,
% [fx, fy] = fplot('sin(2*pi*x)', ...), relied on the string-expression and
% output-argument syntaxes of fplot, which current MATLAB releases document
% as discouraged; evaluating the sine on a dense grid works everywhere.
fx = (linspace(x(1), x(end), 200))';
fy = sin(2*pi*fx);
plot(fx, fy, '-r', 'LineWidth', 2)
plot(plotvals, y, '--g', 'LineWidth', 2)
plot(plotvals, y2, 'k--', 'LineWidth', 2)
plot(plotvals, y3, '-.c', 'LineWidth', 2)
legend('data', 'function', 'Gaussian RBF', 'Thin plate spline RBF', ...
  'r^4 log r RBF');
hold off

disp('RBF training errors are');
disp(['Gaussian ', num2str(rbferr(net, tr_in, t)), ' TPS ', ...
num2str(rbferr(net2, tr_in, t)), ' R4logr ', num2str(rbferr(net3, tr_in, t))]);

disp(' ')
disp('Press any key to end.')
pause
close(fh1);
% NOTE: this clears every variable in the calling workspace, not only the
% ones created by this demo.
clear all;