comparison toolboxes/FullBNT-1.0.7/netlab3.3/demgp.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:e9a9cd732c1e
%DEMGP Demonstrate simple regression using a Gaussian Process.
%
% Description
% The problem consists of one input variable X and one target variable
% T. The values in X are chosen in two separated clusters and the
% target data is generated by computing SIN(2*PI*X) and adding Gaussian
% noise. Two Gaussian Processes, each with a different covariance
% function, are trained by optimising the hyperparameters using the
% scaled conjugate gradient algorithm. The final predictions are
% plotted together with 2 standard deviation error bars.
%
% See also
% GP, GPERR, GPFWD, GPGRAD, GPINIT, SCG
%

% Copyright (c) Ian T Nabney (1996-2001)


% FLOPS_WORKS records whether the FLOPS operation counter may be used.
% Start by assuming it can, then clear the flag when the major version
% number (the text of VERSION before the first '.') is 6 or greater.
% LOGICAL(...) is used instead of TRUE/FALSE because older releases,
% which this check is aimed at, may not provide TRUE/FALSE.
flops_works = logical(1);
v = version;
if (str2num(strtok(v, '.')) >= 6)
  flops_works = logical(0);
end
26
% Training data: ten inputs split into two clusters, with targets taken
% from sin(2*pi*x) plus Gaussian noise of standard deviation 0.05.  The
% generator is seeded so the demo is reproducible.
randn('state', 42);
x = [0.1; 0.15; 0.2; 0.25; 0.65; 0.7; 0.75; 0.8; 0.85; 0.9];
ndata = size(x, 1);
t = sin(2*pi*x) + 0.05*randn(ndata, 1);

% Test inputs: 50 evenly spaced points covering [0, 1].
xtest = linspace(0, 1, 50)';

% Introductory text for the user.
clc
disp('This demonstration illustrates the use of a Gaussian Process')
disp('model for regression problems. The data is generated from a noisy')
disp('sine function.')
disp(' ')
disp('Press any key to continue.')
pause
41
% First model: a GP with the squared exponential covariance function.
% Its hyperparameters are initialised from the Gaussian prior below.
net = gp(1, 'sqexp');
prior.pr_mean = 0;
prior.pr_var = 1;
net = gpinit(net, x, t, prior);

clc
disp('The first GP uses the squared exponential covariance function.')
disp('The hyperparameters are initialised by sampling from a Gaussian with a')
disp(['mean of ', num2str(prior.pr_mean), ' and variance ', ...
  num2str(prior.pr_var), '.'])
disp('After initializing the network, we train it using the scaled conjugate')
disp('gradients algorithm for 20 cycles.')
disp(' ')
disp('Press any key to continue')
pause

% Now train to find the hyperparameters.
options = foptions;
options(1) = 1;     % Display training error values
options(14) = 20;   % Number of training cycles
% FLOPS is only usable on old Matlab releases (see FLOPS_WORKS), so every
% access to the counter is guarded; calling FLOPS(0) unconditionally fails
% where the function is obsolete or missing.  Resetting just before
% training means the count covers the optimisation only.
if flops_works
  flops(0);
end
[net, options] = netopt(net, options, x, t, 'scg');
if flops_works
  sflops = flops;
end
% Second model: a GP with the rational quadratic covariance function,
% initialised from the same prior and trained with the same OPTIONS.
disp('The second GP uses the rational quadratic covariance function.')
disp('The hyperparameters are initialised by sampling from a Gaussian with a')
disp(['mean of ', num2str(prior.pr_mean), ' and variance ', ...
  num2str(prior.pr_var), '.'])
disp('After initializing the network, we train it using the scaled conjugate')
disp('gradients algorithm for 20 cycles.')
disp(' ')
disp('Press any key to continue')
pause
net2 = gp(1, 'ratquad');
net2 = gpinit(net2, x, t, prior);
% Only touch the FLOPS counter where it exists (see FLOPS_WORKS); reset
% immediately before training so RFLOPS measures the optimisation only.
if flops_works
  flops(0);
end
[net2, options] = netopt(net2, options, x, t, 'scg');
if flops_works
  rflops = flops;
end
86
disp(' ')
disp('Press any key to continue')
disp(' ')
pause
clc

% Print the trained hyperparameters of both models (and the FLOPS counts
% when available).  EXP is applied to each stored value before printing;
% the fields presumably hold log-scale parameters -- see GP/GPINIT.
fprintf(1, 'For squared exponential covariance function,');
if flops_works
  % Leading space so the count does not run into the comma above.
  fprintf(1, ' flops = %d', sflops);
end
fprintf(1, '\nfinal hyperparameters:\n')
% Build the format with [] rather than STRCAT, which strips trailing
% whitespace from character-array arguments.
format_string = [' bias:\t\t\t%10.6f\n noise:\t\t%10.6f\n', ...
  ' inverse lengthscale:\t%10.6f\n vertical scale:\t%10.6f\n'];
fprintf(1, format_string, ...
  exp(net.bias), exp(net.noise), exp(net.inweights(1)), exp(net.fpar(1)));
fprintf(1, '\n\nFor rational quadratic covariance function,');
if flops_works
  fprintf(1, ' flops = %d', rflops);
end
fprintf(1, '\nfinal hyperparameters:\n')
% The rational quadratic model reports one extra parameter, FPAR(2).
format_string = [format_string ' cov decay order:\t%10.6f\n'];
fprintf(1, format_string, ...
  exp(net2.bias), exp(net2.noise), exp(net2.inweights(1)), ...
  exp(net2.fpar(1)), exp(net2.fpar(2)));
disp(' ')
disp('Press any key to continue')
pause
114
disp(' ')
disp('Now we plot the data, underlying function, model outputs and two')
disp('standard deviation error bars on a single graph to compare the results.')
disp(' ')
disp('Press any key to continue.')
pause

% Predictions of the first GP at the test inputs.  GPFWD is given the
% inverse of the training covariance matrix from GPCOVAR.
cn = gpcovar(net, x);
cninv = inv(cn);
[ytest, sigsq] = gpfwd(net, xtest, cninv);
sig = sqrt(sigsq);

% Evaluate the true function on a fine grid and draw it with PLOT instead
% of passing a character vector to FPLOT, which current Matlab releases
% mark as not recommended.
xfine = linspace(0, 1, 200)';

fh1 = figure;
hold on
plot(x, t, 'ok');
xlabel('Input')
ylabel('Target')
plot(xfine, sin(2*pi*xfine), '--m');
plot(xtest, ytest, '-k');
plot(xtest, ytest+(2*sig), '-b', xtest, ytest-(2*sig), '-b');
axis([0 1 -1.5 1.5]);
title('Squared exponential covariance function')
legend('data', 'function', 'GP', 'error bars');
hold off
138
% Predictions of the second (rational quadratic) GP at the test inputs.
cninv2 = inv(gpcovar(net2, x));
[ytest2, sigsq2] = gpfwd(net2, xtest, cninv2);
sig2 = sqrt(sigsq2);

% True function on a fine grid, drawn with PLOT rather than a
% character-vector FPLOT call (not recommended in current Matlab).
xfine = linspace(0, 1, 200)';

fh2 = figure;
hold on
plot(x, t, 'ok');
xlabel('Input')
ylabel('Target')
plot(xfine, sin(2*pi*xfine), '--m');
plot(xtest, ytest2, '-k');
plot(xtest, ytest2+(2*sig2), '-b', xtest, ytest2-(2*sig2), '-b');
axis([0 1 -1.5 1.5]);
title('Rational quadratic covariance function')
legend('data', 'function', 'GP', 'error bars');
hold off

disp(' ')
disp('Press any key to end.')
pause
close(fh1);
close(fh2);
clear all;