comparison toolboxes/FullBNT-1.0.7/netlab3.3/demgpard.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
%DEMGPARD Demonstrate ARD using a Gaussian Process.
%
% Description
% The data consists of three input variables X1, X2 and X3, and one
% target variable T. The target data is generated by computing
% SIN(2*PI*X1) and adding Gaussian noise; X2 is a copy of X1 with a
% higher level of added noise, and X3 is sampled randomly from a
% Gaussian distribution. A Gaussian Process is trained by optimising
% the hyperparameters using the scaled conjugate gradient algorithm.
% The final values of the hyperparameters show that the model
% successfully identifies the importance of each input.
%
% See also
% DEMGP, GP, GPERR, GPFWD, GPGRAD, GPINIT, SCG
%

% Copyright (c) Ian T Nabney (1996-2001)

clc;
randn('state', 1729);
rand('state', 1729);
disp('This demonstration illustrates the technique of automatic relevance')
disp('determination (ARD) using a Gaussian Process.')
disp(' ');
disp('First, we set up a synthetic data set involving three input variables:')
disp('x1 is sampled uniformly from the range (0,1) and determines the target')
disp('up to a low level of added Gaussian noise, x2 is a copy of x1 with a')
disp('higher level of added noise, and x3 is sampled randomly from a Gaussian')
disp('distribution. The single target variable is given by t = sin(2*pi*x1)')
disp('with additive Gaussian noise. Thus x1 is very relevant for determining')
disp('the target value, x2 is of some relevance, while x3 should in principle')
disp('be irrelevant.')
disp(' ');
disp('Press any key to see a plot of t against x1.')
pause;

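% Generate the data: x1 drives the target through sin(2*pi*x1), x2 is a
% noisy copy of x1, and x3 is independent of the target.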
ndata = 100;
x1 = rand(ndata, 1);
x2 = x1 + 0.05*randn(ndata, 1);
x3 = 0.5 + 0.5*randn(ndata, 1);
x = [x1, x2, x3];
t = sin(2*pi*x1) + 0.1*randn(ndata, 1);

% Plot the data and the original function.
h = figure;
plotvals = linspace(0, 1, 200)';
plot(x1, t, 'ob')
hold on
xlabel('Input x1')
ylabel('Target')
axis([0 1 -1.5 1.5])
[fx, fy] = fplot('sin(2*pi*x)', [0 1]);
plot(fx, fy, '-g', 'LineWidth', 2);
legend('data', 'function');

disp(' ');
disp('Press any key to continue')
pause; clc;

disp('The Gaussian Process has a separate hyperparameter for each input.')
disp('The hyperparameters are trained by error minimisation using the scaled')
disp('conjugate gradient optimiser.')
disp(' ');
disp('Press any key to create and train the model.')
disp(' ');
pause;

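% For reference: the 'sqexp' covariance constructed below has (approximately)
% the form
%   C(x_i, x_j) = exp(fpar(1)) * exp(-0.5 * sum_k exp(inweights(k)) * (x_i(k) - x_j(k))^2)
%                 + exp(bias) + exp(noise)*delta_ij,
% a sketch based on Netlab's gpcovarp/gpcovar (see those files for the exact
% definition). The exp(inweights(k)) factors are the per-input scales (the
% 'inverse lengthscales' reported below) that implement ARD: a large value
% lets that input influence the covariance strongly, while a small value
% effectively switches the input off.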
net = gp(3, 'sqexp');
% Initialise the parameters.
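% (When given a prior argument, gpinit draws the initial log hyperparameters
% from a Gaussian with mean pr_mean and variance pr_var; this is a rough
% description of its behaviour, see gpinit.m for the details.)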
prior.pr_mean = 0;
prior.pr_var = 0.1;
net = gpinit(net, x, t, prior);

% Now train to find the hyperparameters.
options = foptions;
options(1) = 1;
options(14) = 30;

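% netopt minimises the model's error function (gperr, essentially the negative
% log likelihood of the training data) with the scaled conjugate gradient
% optimiser 'scg'. In the standard foptions vector used above, options(1) = 1
% requests a display of the error during training and options(14) = 30 caps
% the number of iterations.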
[net, options] = netopt(net, options, x, t, 'scg');

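% Netlab stores the GP hyperparameters as logs, so they are exponentiated
% before display; exp(net.inweights) gives the per-input ARD relevances
% (inverse lengthscales), one for each of x1, x2 and x3.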
rel = exp(net.inweights);

fprintf(1, ...
  '\nFinal hyperparameters:\n\n bias:\t\t%10.6f\n noise:\t%10.6f\n', ...
  exp(net.bias), exp(net.noise));
fprintf(1, ' Vertical scale: %8.6f\n', exp(net.fpar(1)));
fprintf(1, ' Input 1:\t%10.6f\n Input 2:\t%10.6f\n', ...
  rel(1), rel(2));
fprintf(1, ' Input 3:\t%10.6f\n\n', rel(3));
disp(' ');
disp('We see that the inverse lengthscale associated with input x1 is')
disp('large, that of x2 has an intermediate value, and that of x3 is')
disp('small.')
disp(' ');
disp('This implies that the Gaussian Process is giving greatest emphasis')
disp('to x1 and least emphasis to x3, with intermediate emphasis on')
disp('x2 in the covariance function.')
disp(' ')
disp('Since the target t is statistically independent of x3 we might')
disp('expect the hyperparameter associated with this input to go to')
disp('zero. However, for any finite data set there may be some chance')
disp('correlation between x3 and t, and so the corresponding hyperparameter')
disp('remains finite.')
disp('Press any key to continue.')
pause

disp('Finally, we plot the output of the Gaussian Process along the line')
disp('x1 = x2 = x3, together with the true underlying function.')
xt = linspace(0, 1, 50);
xtest = [xt', xt', xt'];

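% gpfwd computes the predictive mean and variance at the test inputs; it
% requires the inverse of the training covariance matrix, which gpcovar
% returns (including the output noise on the diagonal). The explicit inv()
% follows the original demo; for larger data sets a Cholesky-based solve
% would be preferable. The curves plotted below show the predictive mean
% with error bars of two standard deviations.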
cn = gpcovar(net, x);
cninv = inv(cn);
[ytest, sigsq] = gpfwd(net, xtest, cninv);
sig = sqrt(sigsq);

figure(h); hold on;
plot(xt, ytest, '-k');
plot(xt, ytest+(2*sig), '-b', xt, ytest-(2*sig), '-b');
axis([0 1 -1.5 1.5]);
fplot('sin(2*pi*x)', [0 1], '--m');

disp(' ');
disp('Press any key to end.')
pause; clc; close(h); clear all
