comparison toolboxes/FullBNT-1.0.7/netlab3.3/demrbf1.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:e9a9cd732c1e
1 %DEMRBF1 Demonstrate simple regression using a radial basis function network.
2 %
3 % Description
4 % The problem consists of one input variable X and one target variable
5 % T with data generated by sampling X at equal intervals and then
6 % generating target data by computing SIN(2*PI*X) and adding Gaussian
7 % noise. This data is the same as that used in demmlp1.
8 %
9 % Three different RBF networks (with different activation functions)
10 % are trained in two stages. First, a Gaussian mixture model is trained
11 % using the EM algorithm, and the centres of this model are used to set
12 % the centres of the RBF. Second, the output weights (and biases) are
13 % determined using the pseudo-inverse of the design matrix.
14 %
15 % See also
16 % DEMMLP1, RBF, RBFFWD, GMM, GMMEM
17 %
18
19 % Copyright (c) Ian T Nabney (1996-2001)
20
21
% Build the training set: input column x and noisy sine targets t.
% Seed both generators so every run of the demo is reproducible.
randn('state', 42);
rand('state', 42);

ndata = 20;    % number of training samples
noise = 0.2;   % standard deviation of the additive Gaussian noise

% Sample the input uniformly on [0, 1]; targets are sin(2*pi*x) plus noise.
x = linspace(0, 1, ndata)';
t = sin(2*pi*x) + noise*randn(ndata, 1);

% Standardise the inputs (zero mean, unit standard deviation) for training.
mu = mean(x);
sigma = std(x);
tr_in = (x - mu)./sigma;
% Introduce the demo and wait for the user before continuing.
clc
disp('This demonstration illustrates the use of a Radial Basis Function')
disp('network for regression problems. The data is generated from a noisy')
disp('sine function.')
disp(' ')
disp('Press any key to continue.')
pause

% Architecture shared by all three RBF networks in this demo.
nin = 1;       % dimensionality of the input
nout = 1;      % dimensionality of the output
nhidden = 7;   % number of basis functions in the hidden layer
44
% --- Network 1: Gaussian basis functions ---
clc
disp('We assess the effect of three different activation functions.')
disp('First we create a network with Gaussian activations.')
disp(' ')
disp('Press any key to continue.')
pause

% Build the first network with randomly initialised parameters.
net = rbf(nin, nhidden, nout, 'gaussian');

disp('A two-stage training algorithm is used: it uses a small number of')
disp('iterations of EM to position the centres, and then the pseudo-inverse')
disp('of the design matrix to find the second layer weights.')
disp(' ')
disp('Press any key to continue.')
pause
disp('Error values from EM training.')

% Two-stage (fast) training: EM places the centres, then the second-layer
% weights are found by a linear solve inside rbftrain.
options = foptions;
options(1) = 1;     % show the EM error at each iteration
options(14) = 10;   % cap EM at 10 iterations
net = rbftrain(net, options, tr_in, t);

disp(' ')
disp('Press any key to continue.')
pause
% --- Network 2: thin plate spline basis functions ---
clc
disp('The second RBF network has thin plate spline activations.')
disp('The same centres are used again, so we just need to calculate')
disp('the second layer weights.')
disp(' ')
disp('Press any key to continue.')
pause

% Build the second network and copy the centres found by EM above,
% so only the linear output layer needs to be fitted.
net2 = rbf(nin, nhidden, nout, 'tps');
net2.c = net.c;

% Hidden-unit activations on the training inputs.
[y, act2] = rbffwd(net2, tr_in);

% Least-squares fit of output weights and bias via the pseudo-inverse
% of the design matrix (activations augmented with a bias column).
wtemp = pinv([act2 ones(ndata, 1)]) * t;
net2.w2 = wtemp(1:nhidden, :);
net2.b2 = wtemp(nhidden+1, :);
% --- Network 3: r^4 log r basis functions ---
disp('The third RBF network has r^4 log r activations.')
disp(' ')
disp('Press any key to continue.')
pause

% Same recipe as the thin plate spline network: reuse the EM centres,
% then solve for the output layer with the pseudo-inverse.
net3 = rbf(nin, nhidden, nout, 'r4logr');
net3.c = net.c;
[y, act3] = rbffwd(net3, tr_in);
wtemp = pinv([act3 ones(ndata, 1)]) * t;
net3.w2 = wtemp(1:nhidden, :);
net3.b2 = wtemp(nhidden+1, :);
% --- Compare the three networks against the data and the true function ---
disp('Now we plot the data, underlying function, and network outputs')
disp('on a single graph to compare the results.')
disp(' ')
disp('Press any key to continue.')
pause

% Evaluate all three networks on a dense grid (standardised the same
% way as the training inputs).
plotvals = (x(1):0.01:x(end))';
inputvals = (plotvals - mu)./sigma;
y = rbffwd(net, inputvals);
y2 = rbffwd(net2, inputvals);
y3 = rbffwd(net3, inputvals);

fh1 = figure;
plot(x, t, 'ob')
hold on
xlabel('Input')
ylabel('Target')
axis([x(1) x(end) -1.5 1.5])

% Noise-free underlying function for reference.
[fx, fy] = fplot('sin(2*pi*x)', [x(1) x(end)]);
plot(fx, fy, '-r', 'LineWidth', 2)

% One curve per network.
plot(plotvals, y, '--g', 'LineWidth', 2)
plot(plotvals, y2, 'k--', 'LineWidth', 2)
plot(plotvals, y3, '-.c', 'LineWidth', 2)
legend('data', 'function', 'Gaussian RBF', 'Thin plate spline RBF', ...
  'r^4 log r RBF');
hold off

% Report the training-set error of each network.
disp('RBF training errors are');
disp(['Gaussian ', num2str(rbferr(net, tr_in, t)), ' TPS ', ...
num2str(rbferr(net2, tr_in, t)), ' R4logr ', num2str(rbferr(net3, tr_in, t))]);

disp(' ')
disp('Press any key to end.')
pause

% Tidy up: close the figure and clear the workspace.
close(fh1);
clear all;