%DEMMDN1 Demonstrate fitting a multi-valued function using a Mixture Density Network.
%
%	Description
%	The problem consists of one input variable X and one target variable
%	T with data generated by sampling T at equal intervals and then
%	generating target data by computing T + 0.3*SIN(2*PI*T) and adding
%	Gaussian noise. A Mixture Density Network with 3 centres in the
%	mixture model is trained by minimizing a negative log likelihood
%	error function using the scaled conjugate gradient optimizer.
%
%	The conditional means, mixing coefficients and variances are plotted
%	as a function of X, and a contour plot of the full conditional
%	density is also generated.
%
%	See also
%	MDN, MDNERR, MDNGRAD, SCG
%

%	Copyright (c) Ian T Nabney (1996-2001)


% Generate the matrix of inputs x and targets t.
% Fixed seeds make the demo reproducible from run to run.
seedn = 42;
seed = 42;
randn('state', seedn);
rand('state', seed);
ndata = 300;			% Number of data points.
noise = 0.2;			% Range of noise distribution.
t = [0:1/(ndata - 1):1]';
% Inverse-style data: x is a (noisy) single-valued function of t, so the
% map from x back to t is multi-valued in the middle of the range.
x = t + 0.3*sin(2*pi*t) + noise*rand(ndata, 1) - noise/2;
axis_limits = [-0.2 1.2 -0.2 1.2];

clc
disp('This demonstration illustrates the use of a Mixture Density Network')
disp('to model multi-valued functions.  The data is generated from the')
disp('mapping x = t + 0.3 sin(2 pi t) + e, where e is a noise term.')
disp('We begin by plotting the data.')
disp(' ')
disp('Press any key to continue')
pause
% Plot the data
fh1 = figure;
p1 = plot(x, t, 'ob');
axis(axis_limits);
hold on
disp('Note that for x in the range 0.35 to 0.65, there are three possible')
disp('branches of the function.')
disp(' ')
disp('Press any key to continue')
pause

% Set up network parameters.
nin = 1;			% Number of inputs.
nhidden = 5;			% Number of hidden units.
ncentres = 3;			% Number of mixture components.
dim_target = 1;			% Dimension of target space
mdntype = '0';			% Currently unused: reserved for future use
alpha = 100;			% Inverse variance for weight initialisation
				% Make variance small for good starting point

% Create and initialize network weight vector.
net = mdn(nin, nhidden, ncentres, dim_target, mdntype);
init_options = zeros(1, 18);
init_options(1) = -1;	% Suppress all messages
init_options(14) = 10;  % 10 iterations of K means in gmminit
net = mdninit(net, alpha, t, init_options);

% Set up vector of options for the optimiser.
options = foptions;
options(1) = 1;			% This provides display of error values.
options(14) = 200;		% Number of training cycles.

clc
disp('We initialise the neural network model, which is an MLP with a')
disp('Gaussian mixture model with three components and spherical variance')
disp('as the error function.  This enables us to model the complete')
disp('conditional density function.')
disp(' ')
disp('Next we train the model for 200 epochs using a scaled conjugate gradient')
disp('optimizer.  The error function is the negative log likelihood of the')
disp('training data.')
disp(' ')
disp('Press any key to continue.')
pause

% Train using scaled conjugate gradients.
[net, options] = netopt(net, options, x, t, 'scg');

disp(' ')
disp('Press any key to continue.')
pause

clc
disp('We can also train a conventional MLP with sum of squares error function.')
disp('This will approximate the conditional mean, which is not always a')
disp('good representation of the data.  Note that the error function is the')
disp('sum of squares error on the training data, which accounts for the')
disp('different values from training the MDN.')
disp(' ')
disp('We train the network with the quasi-Newton optimizer for 80 epochs.')
disp(' ')
disp('Press any key to continue.')
pause
mlp_nhidden = 8;
net2 = mlp(nin, mlp_nhidden, dim_target, 'linear');
options(14) = 80;
[net2, options] = netopt(net2, options, x, t, 'quasinew');
disp(' ')
disp('Press any key to continue.')
pause

clc
disp('Now we plot the underlying function, the MDN prediction,')
disp('represented by the mode of the conditional distribution, and the')
disp('prediction of the conventional MLP.')
disp(' ')
disp('Press any key to continue.')
pause

% Plot the original function, and the trained network function.
plotvals = [0:0.01:1]';
% Forward-propagate the plotting grid and convert the MDN outputs into an
% array of gmm structures, one conditional mixture per input value.
mixes = mdn2gmm(mdnfwd(net, plotvals));
axis(axis_limits);
yplot = t+0.3*sin(2*pi*t);
p2 = plot(yplot, t, '--y');

% Use the mode to represent the function: for each x take the centre of
% the mixture component with the largest prior.
y = zeros(1, length(plotvals));
c = zeros(length(plotvals), ncentres);
for i = 1:length(plotvals)
  [m, j] = max(mixes(i).priors);
  y(i) = mixes(i).centres(j,:);
  c(i,:) = mixes(i).centres';
end
p3 = plot(plotvals, y, '*r');
p4 = plot(plotvals, mlpfwd(net2, plotvals), 'g');
set(p4, 'LineWidth', 2);
legend([p1 p2 p3 p4], 'data', 'function', 'MDN mode', 'MLP mean', 4);
hold off

clc
disp('We can also plot how the mixture model parameters depend on x.')
disp('First we plot the mixture centres, then the priors and finally')
disp('the variances.')
disp(' ')
disp('Press any key to continue.')
pause
fh2 = figure;
subplot(3, 1, 1)
plot(plotvals, c)
hold on
title('Mixture centres')
legend('centre 1', 'centre 2', 'centre 3')
hold off

% Gather the per-input priors into a (ninputs x ncentres) matrix.
priors = reshape([mixes.priors], mixes(1).ncentres, size(mixes, 2))';
subplot(3, 1, 2)
plot(plotvals, priors)
hold on
title('Mixture priors')
legend('centre 1', 'centre 2', 'centre 3')
hold off

% Gather the per-input (spherical) variances in the same layout.
variances = reshape([mixes.covars], mixes(1).ncentres, size(mixes, 2))';
subplot(3, 1, 3)
plot(plotvals, variances)
hold on
title('Mixture variances')
legend('centre 1', 'centre 2', 'centre 3')
hold off

disp('The last figure is a contour plot of the conditional probability')
disp('density generated by the Mixture Density Network.  Note how it')
disp('is well matched to the regions of high data density.')
disp(' ')
disp('Press any key to continue.')
pause
% Contour plot for MDN.
i = 0:0.01:1.0;
j = 0:0.01:1.0;

[I, J] = meshgrid(i,j);
I = I(:);
J = J(:);
li = length(i);
lj = length(j);
Z = zeros(li, lj);
% Column k of Z is the conditional density p(t|x) evaluated on the t-grid
% for the k-th x value.
for k = 1:li
  Z(:,k) = gmmprob(mixes(k), j');
end
fh5 = figure;
% Set up levels by hand to make a good figure
v = [2 2.5 3 3.5 5:3:18];
contour(i, j, Z, v)
hold on
title('Contour plot of conditional density')
hold off

disp(' ')
disp('Press any key to exit.')
pause
close(fh1);
close(fh2);
close(fh5);
|