comparison toolboxes/FullBNT-1.0.7/netlab3.3/demmdn1.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
equal deleted inserted replaced
-1:000000000000 0:e9a9cd732c1e
1 %DEMMDN1 Demonstrate fitting a multi-valued function using a Mixture Density Network.
2 %
3 % Description
% The problem consists of one input variable X and one target variable
% T with data generated by sampling T at equal intervals and then
% generating the input data X by computing T + 0.3*SIN(2*PI*T) and adding
% uniformly distributed noise. A Mixture Density Network with 3 centres in
% the mixture model is trained by minimizing a negative log likelihood
% error function using the scaled conjugate gradient optimizer.
10 %
11 % The conditional means, mixing coefficients and variances are plotted
12 % as a function of X, and a contour plot of the full conditional
13 % density is also generated.
14 %
15 % See also
17 %
19 % Copyright (c) Ian T Nabney (1996-2001)
% Generate the matrix of inputs x and targets t.
% Both generators are seeded with the legacy 'state' syntax so that every
% run of the demo sees the same data and the same initial weights.
seedn = 42;
seed = 42;
randn('state', seedn);
rand('state', seed);

ndata = 300;          % Number of data points.
noise = 0.2;          % Width of the (uniform) noise interval.

% Targets are equally spaced on [0, 1]; inputs are the inverse-problem
% mapping of the targets plus noise drawn uniformly from [-noise/2, noise/2].
t = [0:1/(ndata - 1):1]';
x = t + 0.3*sin(2*pi*t);
x = x + noise*rand(ndata, 1) - noise/2;

axis_limits = [-0.2 1.2 -0.2 1.2];
% Introduce the demonstration, then show the raw training data.
clc
fprintf('%s\n', ...
  'This demonstration illustrates the use of a Mixture Density Network', ...
  'to model multi-valued functions. The data is generated from the', ...
  'mapping x = t + 0.3 sin(2 pi t) + e, where e is a noise term.', ...
  'We begin by plotting the data.', ...
  ' ', ...
  'Press any key to continue');
pause

% Scatter plot of the samples (x horizontal, t vertical). HOLD stays on so
% the model predictions can be overlaid on this figure later.
fh1 = figure;
p1 = plot(x, t, 'ob');
axis(axis_limits);
hold on

fprintf('%s\n', ...
  'Note that for x in the range 0.35 to 0.65, there are three possible', ...
  'branches of the function.', ...
  ' ', ...
  'Press any key to continue');
pause
% Network architecture: one input, a one-dimensional target, an MLP with
% NHIDDEN hidden units producing the parameters of NCENTRES mixture
% components.
nin = 1;            % Number of inputs.
nhidden = 5;        % Number of hidden units.
ncentres = 3;       % Number of mixture components.
dim_target = 1;     % Dimension of target space
mdntype = '0';      % Currently unused: reserved for future use
alpha = 100;        % Inverse variance for weight initialisation; a large
                    % value keeps the starting weights small.

% Create the network and initialise its weight vector.
net = mdn(nin, nhidden, ncentres, dim_target, mdntype);
% Element 1 = -1 suppresses all messages; element 14 = 10 asks for
% 10 iterations of K means in gmminit.
init_options = zeros(1, 18);
init_options([1 14]) = [-1 10];
net = mdninit(net, alpha, t, init_options);

% Optimiser options: element 1 = 1 displays error values, element 14 sets
% the number of training cycles.
options = foptions;
options([1 14]) = [1 200];
% Explain the MDN, then fit it to the data.
clc
fprintf('%s\n', ...
  'We initialise the neural network model, which is an MLP with a', ...
  'Gaussian mixture model with three components and spherical variance', ...
  'as the error function. This enables us to model the complete', ...
  'conditional density function.', ...
  ' ', ...
  'Next we train the model for 200 epochs using a scaled conjugate gradient', ...
  'optimizer. The error function is the negative log likelihood of the', ...
  'training data.', ...
  ' ', ...
  'Press any key to continue.');
pause

% Scaled conjugate gradient training; netopt returns both the trained
% network and the updated options vector.
[net, options] = netopt(net, options, x, t, 'scg');

fprintf('%s\n', ' ', 'Press any key to continue.');
pause
% Baseline for comparison: a conventional sum-of-squares MLP.
clc
fprintf('%s\n', ...
  'We can also train a conventional MLP with sum of squares error function.', ...
  'This will approximate the conditional mean, which is not always a', ...
  'good representation of the data. Note that the error function is the', ...
  'sum of squares error on the training data, which accounts for the', ...
  'different values from training the MDN.', ...
  ' ', ...
  'We train the network with the quasi-Newton optimizer for 80 epochs.', ...
  ' ', ...
  'Press any key to continue.');
pause

% OPTIONS is reused from the MDN run; only the cycle count changes.
options(14) = 80;
mlp_nhidden = 8;
net2 = mlp(nin, mlp_nhidden, dim_target, 'linear');
[net2, options] = netopt(net2, options, x, t, 'quasinew');

fprintf('%s\n', ' ', 'Press any key to continue.');
pause
clc
disp('Now we plot the underlying function, the MDN prediction,')
disp('represented by the mode of the conditional distribution, and the')
disp('prediction of the conventional MLP.')
disp(' ')
disp('Press any key to continue.')
pause

% Plot the original function, and the trained network function, on top of
% the data figure (made current explicitly rather than relying on it
% still being the current figure).
figure(fh1);
plotvals = [0:0.01:1]';
mixes = mdn2gmm(mdnfwd(net, plotvals));
axis(axis_limits);
yplot = t+0.3*sin(2*pi*t);
p2 = plot(yplot, t, '--y');

% Use the mode (the centre with the largest mixing coefficient) to
% represent the function, and record every centre for later plots.
% C has one column per mixture component (previously hard-coded to 3).
nplot = length(plotvals);
y = zeros(1, nplot);
c = zeros(nplot, ncentres);
for i = 1:nplot
  [m, j] = max(mixes(i).priors);
  y(i) = mixes(i).centres(j,:);
  c(i,:) = mixes(i).centres';
end
p3 = plot(plotvals, y, '*r');
p4 = plot(plotvals, mlpfwd(net2, plotvals), 'g');
set(p4, 'LineWidth', 2);
% 'SouthEast' is the named equivalent of the deprecated numeric location 4
% (lower right-hand corner).
legend([p1 p2 p3 p4], 'data', 'function', 'MDN mode', 'MLP mean', ...
  'Location', 'SouthEast');
hold off
clc
disp('We can also plot how the mixture model parameters depend on x.')
disp('First we plot the mixture centres, then the priors and finally')
disp('the variances.')
disp(' ')
disp('Press any key to continue.')
pause
fh2 = figure;

% One legend label per mixture component ('centre 1', 'centre 2', ...),
% derived from the model instead of being hard-coded for three centres.
ncomp = mixes(1).ncentres;
npts = numel(mixes);
centre_labels = cell(1, ncomp);
for k = 1:ncomp
  centre_labels{k} = sprintf('centre %d', k);
end

% Mixture centres as a function of x (one column per component).
subplot(3, 1, 1)
plot(plotvals, c)
title('Mixture centres')
legend(centre_labels{:})

% Mixing coefficients: row n holds the priors at plotvals(n). NUMEL is used
% so the reshape does not depend on MIXES being a row struct array.
priors = reshape([mixes.priors], ncomp, npts)';
subplot(3, 1, 2)
plot(plotvals, priors)
title('Mixture priors')
legend(centre_labels{:})

% Variances: the reshape requires exactly one covariance value per
% component, i.e. spherical covariances.
variances = reshape([mixes.covars], ncomp, npts)';
subplot(3, 1, 3)
plot(plotvals, variances)
title('Mixture variances')
legend(centre_labels{:})
disp('The last figure is a contour plot of the conditional probability')
disp('density generated by the Mixture Density Network. Note how it')
disp('is well matched to the regions of high data density.')
disp(' ')
disp('Press any key to continue.')
pause
% Contour plot for MDN.
% Evaluation grids: XGRID over the input x, TGRID over the target t
% (named to avoid shadowing the imaginary units i and j).
xgrid = 0:0.01:1.0;
tgrid = 0:0.01:1.0;
nx = length(xgrid);
nt = length(tgrid);
% Evaluate the MDN on XGRID itself rather than reusing MIXES, which were
% computed at PLOTVALS and only matched this grid by coincidence.
gridmixes = mdn2gmm(mdnfwd(net, xgrid'));
% contour(X, Y, Z) expects Z to be length(Y)-by-length(X): column k holds
% the mixture density at x = xgrid(k), evaluated at every point of TGRID.
Z = zeros(nt, nx);
for k = 1:nx
  Z(:, k) = gmmprob(gridmixes(k), tgrid');
end
fh5 = figure;
% Set up levels by hand to make a good figure
v = [2 2.5 3 3.5 5:3:18];
contour(xgrid, tgrid, Z, v)
title('Contour plot of conditional density')
% Wait for the user, then close each figure this demo opened.
fprintf('%s\n', ' ', 'Press any key to exit.');
pause
close(fh1);
close(fh2);
close(fh5);