camir-aes2014 (Mercurial): comparison of toolboxes/FullBNT-1.0.7/netlab3.3/demolgd1.m @ 0:e9a9cd732c1e (tip)

author:    wolffd
date:      Tue, 10 Feb 2015 15:05:51 +0000
changeset: "first hg version after svn" (file added: -1:000000000000 -> 0:e9a9cd732c1e)
%DEMOLGD1 Demonstrate simple MLP optimisation with on-line gradient descent
%
% Description
% The problem consists of one input variable X and one target variable
% T with data generated by sampling X at equal intervals and then
% generating target data by computing SIN(2*PI*X) and adding Gaussian
% noise. A 2-layer network with linear outputs is trained by minimizing
% a sum-of-squares error function using on-line gradient descent.
%
% See also
% DEMMLP1, OLGD
%

% Copyright (c) Ian T Nabney (1996-2001)


% Generate the matrix of inputs x and targets t.

ndata = 20;  % Number of data points.
noise = 0.2; % Standard deviation of noise distribution.
x = [0:1/(ndata - 1):1]';
randn('state', 42);
rand('state', 42);
t = sin(2*pi*x) + noise*randn(ndata, 1);
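
% (For reference, the generative model above is t = sin(2*pi*x) + e with
% e ~ N(0, noise^2); the fixed seeds make the sample reproducible.)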

clc
disp('This demonstration illustrates the use of the on-line gradient')
disp('descent algorithm to train a Multi-Layer Perceptron network for')
disp('regression problems. It is intended to illustrate the drawbacks')
disp('of this algorithm compared to more powerful non-linear optimisation')
disp('algorithms, such as conjugate gradients.')
disp(' ')
disp('First we generate the data from a noisy sine function and construct')
disp('the network.')
disp(' ')
disp('Press any key to continue.')
pause

% Set up network parameters.
nin = 1;      % Number of inputs.
nhidden = 3;  % Number of hidden units.
nout = 1;     % Number of outputs.
alpha = 0.01; % Coefficient of weight-decay prior (defined here but not used below).

% Create and initialize network weight vector.
net = mlp(nin, nhidden, nout, 'linear');
% Initialise weights reasonably close to 0
net = mlpinit(net, 10);
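% (MLPINIT draws the initial weights from a zero-mean Gaussian whose
% inverse variance is its second argument, so the value 10 here gives
% initial weights with standard deviation 1/sqrt(10).)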

% Set up vector of options for the optimiser.
options = foptions;
options(1) = 1;    % This provides display of error values.
options(14) = 20;  % Number of training cycles.
options(18) = 0.1; % Learning rate
options(17) = 0.4; % Momentum
options(5) = 1;    % Do randomise pattern order
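% (On-line gradient descent updates the weights after every pattern rather
% than after a full pass through the data, which is why options(5) shuffles
% the pattern order; entries not set here keep their FOPTIONS defaults.)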
clc
disp('Then we set the options for the training algorithm.')
disp(['In the first phase of training, which lasts for ',...
    num2str(options(14)), ' cycles,'])
disp(['the learning rate is ', num2str(options(18)), ...
    ' and the momentum is ', num2str(options(17)), '.'])
disp('The error values are displayed at the end of each pass through the')
disp('entire pattern set.')
disp(' ')
disp('Press any key to continue.')
pause

% Train using on-line gradient descent
[net, options] = olgd(net, options, x, t);
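% (By Netlab convention the returned options vector reports the final
% value of the error function in options(8).)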

% Now allow the learning rate to decay (the momentum is kept at 0.4).
options(2) = 0;    % Termination tolerance on weight changes
options(3) = 0;    % Termination tolerance on error changes
options(17) = 0.4; % Momentum
options(5) = 1;    % Randomise pattern order
options(6) = 1;    % Set learning rate decay on
options(14) = 200; % Maximum number of training cycles
options(18) = 0.1; % Initial learning rate
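% (With options(6) set, OLGD shrinks the learning rate over successive
% passes, giving the roughly 1/t decay described in the message below.)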

disp(['In the second phase of training, which lasts for up to ',...
    num2str(options(14)), ' cycles,'])
disp(['the learning rate starts at ', num2str(options(18)), ...
    ', decaying at 1/t, and the momentum is ', num2str(options(17)), '.'])
disp(' ')
disp('Press any key to continue.')
pause
[net, options] = olgd(net, options, x, t);

clc
disp('Now we plot the data, underlying function, and network outputs')
disp('on a single graph to compare the results.')
disp(' ')
disp('Press any key to continue.')
pause

% Plot the data, the original function, and the trained network function.
plotvals = [0:0.01:1]';
y = mlpfwd(net, plotvals);
fh1 = figure;
plot(x, t, 'ob')
hold on
axis([0 1 -1.5 1.5])
fplot('sin(2*pi*x)', [0 1], '--g')
plot(plotvals, y, '-r')
legend('data', 'function', 'network');
hold off

disp('Note the very poor fit to the data: this should be compared with')
disp('the results obtained in demmlp1.')
disp(' ')
disp('Press any key to exit.')
pause
close(fh1);
clear all;
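
% For the comparison with DEMMLP1 mentioned above, here is a minimal,
% commented-out sketch of training the same kind of network with scaled
% conjugate gradients through Netlab's NETOPT interface (assumptions: net2,
% opt and y2 are illustrative names, and the lines would need to run before
% the CLEAR ALL above, while x, t and plotvals are still in scope):
%
%   net2 = mlp(nin, nhidden, nout, 'linear'); % Fresh 1-3-1 network
%   net2 = mlpinit(net2, 10);                 % Small initial weights
%   opt = foptions;
%   opt(1) = 1;                               % Display error values
%   opt(14) = 100;                            % Number of training cycles
%   [net2, opt] = netopt(net2, opt, x, t, 'scg');
%   y2 = mlpfwd(net2, plotvals);              % Typically a much closer fit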