%DEMGP Demonstrate simple regression using a Gaussian Process.
%
% Description
% The problem consists of one input variable X and one target variable
% T. The values in X are chosen in two separated clusters and the
% target data is generated by computing SIN(2*PI*X) and adding Gaussian
% noise. Two Gaussian Processes, each with a different covariance
% function, are trained by optimising the hyperparameters using the
% scaled conjugate gradient algorithm. The final predictions are
% plotted together with 2 standard deviation error bars.
%
% See also
% GP, GPERR, GPFWD, GPGRAD, GPINIT, SCG
%

% Copyright (c) Ian T Nabney (1996-2001)
% Find out if flops is available (i.e. pre-version 6 Matlab). Assume it
% is, then switch it off when the major version number is 6 or later.
v = version;
flops_works = logical(1);
if (str2num(strtok(v, '.')) >= 6)
  flops_works = logical(0);
end

% Training data: ten inputs in two separated clusters, with targets
% sin(2*pi*x) corrupted by Gaussian noise of standard deviation 0.05.
% Seeding randn makes the noise, and hence the whole demo, repeatable.
randn('state', 42);
x = [0.1; 0.15; 0.2; 0.25; 0.65; 0.7; 0.75; 0.8; 0.85; 0.9];
ndata = size(x, 1);
t = sin(2*pi*x) + 0.05*randn(ndata, 1);

% Test inputs: 50 evenly spaced points covering [0, 1].
xtest = linspace(0, 1, 50)';

% Introductory screen: explain the demo, then wait for a key press.
clc
intro_text = {
  'This demonstration illustrates the use of a Gaussian Process'
  'model for regression problems. The data is generated from a noisy'
  'sine function.'
  ' '
  'Press any key to continue.'};
for line_idx = 1:length(intro_text)
  disp(intro_text{line_idx})
end
pause

% --- First model: squared exponential covariance function ---------------
% Create a one-input GP and initialise its hyperparameters from a Gaussian
% prior with the mean and variance below (as the message displayed next
% explains).
net = gp(1, 'sqexp');
prior.pr_mean = 0;
prior.pr_var = 1;
net = gpinit(net, x, t, prior);

clc
disp('The first GP uses the squared exponential covariance function.')
disp('The hyperparameters are initialised by sampling from a Gaussian with a')
disp(['mean of ', num2str(prior.pr_mean), ' and variance ', ...
    num2str(prior.pr_var), '.'])
disp('After initializing the network, we train it using the scaled conjugate')
disp('gradients algorithm for 20 cycles.')
disp(' ')
disp('Press any key to continue')
pause

% Now train to find the hyperparameters.
options = foptions;
options(1) = 1;    % Display training error values
options(14) = 20;  % Training cycles; keep in step with the message above
% flops is only available before MATLAB 6 (see flops_works), so every call
% to it is guarded; calling it where it is unavailable would abort the demo.
if flops_works
  flops(0);        % reset so the count covers only the optimisation
end
[net, options] = netopt(net, options, x, t, 'scg');
if flops_works
  sflops = flops;
end

% --- Second model: rational quadratic covariance function ----------------
disp('The second GP uses the rational quadratic covariance function.')
disp('The hyperparameters are initialised by sampling from a Gaussian with a')
disp(['mean of ', num2str(prior.pr_mean), ' and variance ', ...
    num2str(prior.pr_var), '.'])
disp('After initializing the network, we train it using the scaled conjugate')
disp('gradients algorithm for 20 cycles.')
disp(' ')
disp('Press any key to continue')
pause
% Initialise with the same prior as the first model so that the two fits
% are directly comparable.
net2 = gp(1, 'ratquad');
net2 = gpinit(net2, x, t, prior);
% flops is only available before MATLAB 6 (see flops_works); guard every
% call so the demo still runs where it is unavailable.
if flops_works
  flops(0);        % reset so the count covers only the optimisation
end
% Reuses the options vector returned by the previous training run
% (options(1) and options(14) were set before that run).
[net2, options] = netopt(net2, options, x, t, 'scg');
if flops_works
  rflops = flops;
end

disp(' ')
disp('Press any key to continue')
disp(' ')
pause
clc

% Report the flop counts (only where flops is available) and the trained
% hyperparameters of both models. Each parameter field is passed through
% exp() before printing -- presumably GP stores them in log form; confirm
% against gp.m.
fprintf(1, 'For squared exponential covariance function,');
if flops_works
  % Leading space separates the count from the preceding comma.
  fprintf(1, ' flops = %d', sflops);
end
fprintf(1, '\nfinal hyperparameters:\n')
% Plain concatenation keeps the format string exactly as written (strcat
% would strip trailing whitespace from its char inputs).
format_string = [' bias:\t\t\t%10.6f\n noise:\t\t%10.6f\n', ...
  ' inverse lengthscale:\t%10.6f\n vertical scale:\t%10.6f\n'];
fprintf(1, format_string, ...
  exp(net.bias), exp(net.noise), exp(net.inweights(1)), exp(net.fpar(1)));
fprintf(1, '\n\nFor rational quadratic covariance function,');
if flops_works
  fprintf(1, ' flops = %d', rflops);
end
fprintf(1, '\nfinal hyperparameters:\n')
% The rational quadratic model has one extra printed parameter, fpar(2).
format_string = [format_string ' cov decay order:\t%10.6f\n'];
fprintf(1, format_string, ...
  exp(net2.bias), exp(net2.noise), exp(net2.inweights(1)), ...
  exp(net2.fpar(1)), exp(net2.fpar(2)));
disp(' ')
disp('Press any key to continue')
pause

disp(' ')
disp('Now we plot the data, underlying function, model outputs and two')
disp('standard deviation error bars on a single graph to compare the results.')
disp(' ')
disp('Press any key to continue.')
pause
% Predictions of the first GP on the test grid; gpfwd takes the inverse of
% the training covariance matrix returned by gpcovar.
cn = gpcovar(net, x);
cninv = inv(cn);
[ytest, sigsq] = gpfwd(net, xtest, cninv);
sig = sqrt(sigsq);   % standard deviation used for the +/- 2 sigma bars

% Draw the true function from samples on a dense grid. This replaces
% fplot with a character-vector expression, which newer MATLAB releases
% deprecate, and works on every version.
xfun = linspace(0, 1, 200)';

fh1 = figure;
hold on
plot(x, t, 'ok');
xlabel('Input')
ylabel('Target')
plot(xfun, sin(2*pi*xfun), '--m');
plot(xtest, ytest, '-k');
plot(xtest, ytest+(2*sig), '-b', xtest, ytest-(2*sig), '-b');
axis([0 1 -1.5 1.5]);
title('Squared exponential covariance function')
legend('data', 'function', 'GP', 'error bars');
hold off

% Predictions of the second GP on the test grid; gpfwd takes the inverse of
% the training covariance matrix returned by gpcovar.
cninv2 = inv(gpcovar(net2, x));
[ytest2, sigsq2] = gpfwd(net2, xtest, cninv2);
sig2 = sqrt(sigsq2);   % standard deviation used for the +/- 2 sigma bars

% Draw the true function from samples on a dense grid rather than with
% fplot and a character-vector expression (deprecated in newer MATLAB).
xfun = linspace(0, 1, 200)';

fh2 = figure;
hold on
plot(x, t, 'ok');
xlabel('Input')
ylabel('Target')
plot(xfun, sin(2*pi*xfun), '--m');
plot(xtest, ytest2, '-k');
plot(xtest, ytest2+(2*sig2), '-b', xtest, ytest2-(2*sig2), '-b');
axis([0 1 -1.5 1.5]);
title('Rational quadratic covariance function')
legend('data', 'function', 'GP', 'error bars');
hold off

% Wait for a final key press, then close both figures and wipe the
% workspace.
disp(' ')
disp('Press any key to end.')
pause
close([fh1 fh2]);
clear all;