toolboxes/FullBNT-1.0.7/netlab3.3/demolgd1.m
%DEMOLGD1 Demonstrate simple MLP optimisation with on-line gradient descent
%
% Description
% The problem consists of one input variable X and one target variable
% T with data generated by sampling X at equal intervals and then
% generating target data by computing SIN(2*PI*X) and adding Gaussian
% noise. A 2-layer network with linear outputs is trained by minimizing
% a sum-of-squares error function using on-line gradient descent.
%
% See also
% DEMMLP1, OLGD
%

% Copyright (c) Ian T Nabney (1996-2001)


% Generate the matrix of inputs x and targets t.

ndata = 20;    % Number of data points.
noise = 0.2;   % Standard deviation of noise distribution.
x = [0:1/(ndata - 1):1]';
randn('state', 42);
rand('state', 42);
t = sin(2*pi*x) + noise*randn(ndata, 1);
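
% Note: randn('state', ...) and rand('state', ...) are the legacy seeding
% syntax of Netlab-era MATLAB; in current releases the equivalent is
% rng(42). The legacy calls are kept above for compatibility.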

clc
disp('This demonstration illustrates the use of the on-line gradient')
disp('descent algorithm to train a Multi-Layer Perceptron network for')
disp('regression problems. It is intended to illustrate the drawbacks')
disp('of this algorithm compared to more powerful non-linear optimisation')
disp('algorithms, such as conjugate gradients.')
disp(' ')
disp('First we generate the data from a noisy sine function and construct')
disp('the network.')
disp(' ')
disp('Press any key to continue.')
pause
% Set up network parameters.
nin = 1;        % Number of inputs.
nhidden = 3;    % Number of hidden units.
nout = 1;       % Number of outputs.
alpha = 0.01;   % Coefficient of weight-decay prior (not used further in this demo).

% Create and initialize network weight vector.
net = mlp(nin, nhidden, nout, 'linear');
% Initialise weights reasonably close to 0
net = mlpinit(net, 10);
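% (mlpinit samples the weights from a zero-mean Gaussian whose inverse
% variance is given by its second argument, so the value 10 above gives
% a standard deviation of sqrt(1/10), roughly 0.32.)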

% Set up vector of options for the optimiser.
options = foptions;
options(1) = 1;     % This provides display of error values.
options(14) = 20;   % Number of training cycles.
options(18) = 0.1;  % Learning rate
options(17) = 0.4;  % Momentum
options(5) = 1;     % Do randomise pattern order
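
% The options vector follows the old MATLAB foptions convention used
% throughout Netlab: olgd reads the display flag, cycle count, learning
% rate, momentum and pattern-order flag from the entries set above.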
clc
disp('Then we set the options for the training algorithm.')
disp(['In the first phase of training, which lasts for ',...
    num2str(options(14)), ' cycles,'])
disp(['the learning rate is ', num2str(options(18)), ...
    ' and the momentum is ', num2str(options(17)), '.'])
disp('The error values are displayed at the end of each pass through the')
disp('entire pattern set.')
disp(' ')
disp('Press any key to continue.')
pause

% Train using on-line gradient descent
[net, options] = olgd(net, options, x, t);
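% olgd returns the trained network and an updated options vector; as with
% the other Netlab optimisers, options(8) holds the final error value.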

% Now allow the learning rate to decay and remove momentum
options(2) = 0;     % Termination tolerance on the weights.
options(3) = 0;     % Termination tolerance on the error.
options(17) = 0;    % Turn off momentum
options(5) = 1;     % Randomise pattern order
options(6) = 1;     % Set learning rate decay on
options(14) = 200;  % Number of training cycles.
options(18) = 0.1;  % Initial learning rate
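% With options(6) set, olgd reduces the learning rate as training
% proceeds (the 1/t decay mentioned in the message below), so the steps
% start at options(18) and shrink towards zero.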

disp(['In the second phase of training, which lasts for up to ',...
    num2str(options(14)), ' cycles,'])
disp(['the learning rate starts at ', num2str(options(18)), ...
    ', decaying as 1/t, and the momentum is ', num2str(options(17)), '.'])
disp(' ')
disp('Press any key to continue.')
pause
[net, options] = olgd(net, options, x, t);

clc
disp('Now we plot the data, underlying function, and network outputs')
disp('on a single graph to compare the results.')
disp(' ')
disp('Press any key to continue.')
pause

% Plot the data, the original function, and the trained network function.
plotvals = [0:0.01:1]';
y = mlpfwd(net, plotvals);
fh1 = figure;
plot(x, t, 'ob')
hold on
axis([0 1 -1.5 1.5])
fplot('sin(2*pi*x)', [0 1], '--g')
plot(plotvals, y, '-r')
legend('data', 'function', 'network');
hold off
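
% Optional check (not in the original demo): quantify the fit by printing
% the sum-of-squares training error with Netlab's mlperr; a large value
% here backs up the poor fit visible in the plot.
fprintf('Sum-of-squares error on the training data: %g\n', ...
    mlperr(net, x, t));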

disp('Note the very poor fit to the data: this should be compared with')
disp('the results obtained in demmlp1.')
disp(' ')
disp('Press any key to exit.')
pause
close(fh1);
clear all;