%DEMMDN1 Demonstrate fitting a multi-valued function using a Mixture Density Network.
%
%	Description
%	The problem consists of one input variable X and one target variable
%	T. The data are generated by sampling T at equal intervals and then
%	computing the input as X = T + 0.3*SIN(2*PI*T) plus additive uniform
%	noise. A Mixture Density Network with 3 centres in the mixture model
%	is trained by minimizing a negative log likelihood error function
%	using the scaled conjugate gradient optimizer.
%
%	The conditional means, mixing coefficients and variances are plotted
%	as a function of X, and a contour plot of the full conditional
%	density is also generated.
%
%	See also
%	MDN, MDNERR, MDNGRAD, SCG
%

%	Copyright (c) Ian T Nabney (1996-2001)


% Generate the matrix of inputs x and targets t.
seed = 42;
randn('state', seed);
rand('state', seed);
ndata = 300;			% Number of data points.
noise = 0.2;			% Range of the (uniform) noise distribution.
t = [0:1/(ndata - 1):1]';
x = t + 0.3*sin(2*pi*t) + noise*rand(ndata, 1) - noise/2;
axis_limits = [-0.2 1.2 -0.2 1.2];

clc
disp('This demonstration illustrates the use of a Mixture Density Network')
disp('to model multi-valued functions.  The data is generated from the')
disp('mapping x = t + 0.3 sin(2 pi t) + e, where e is a noise term.')
disp('We begin by plotting the data.')
disp(' ')
disp('Press any key to continue')
pause
% Plot the data
fh1 = figure;
p1 = plot(x, t, 'ob');
axis(axis_limits);
hold on
disp('Note that for x in the range 0.35 to 0.65, there are three possible')
disp('branches of the function.')
disp(' ')
disp('Press any key to continue')
pause

% Set up network parameters.
nin = 1;			% Number of inputs.
nhidden = 5;			% Number of hidden units.
ncentres = 3;			% Number of mixture components.
dim_target = 1;			% Dimension of target space
mdntype = '0';			% Currently unused: reserved for future use
alpha = 100;			% Inverse variance for weight initialisation
				% Make variance small for good starting point

% Create and initialize network weight vector.
net = mdn(nin, nhidden, ncentres, dim_target, mdntype);
init_options = zeros(1, 18);
init_options(1) = -1;	% Suppress all messages
init_options(14) = 10;	% 10 iterations of K means in gmminit
net = mdninit(net, alpha, t, init_options);

% Set up vector of options for the optimiser.
options = foptions;
options(1) = 1;			% This provides display of error values.
options(14) = 200;		% Number of training cycles.
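% Optional sanity check: MDNFWD computes the mixture parameters for each
% input and MDN2GMM converts them into standard Netlab GMM structures
% (both functions are used for plotting later in this demo), so the
% initialised mixing coefficients at an arbitrary input, here x = 0.5,
% should already sum to one.
mix0 = mdn2gmm(mdnfwd(net, 0.5));
fprintf('Initial mixing coefficients at x = 0.5 sum to %g\n', sum(mix0.priors));
disp(' ')
disp('Press any key to continue')
pause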
clc
disp('We initialise the neural network model, which is an MLP whose')
disp('outputs parameterise a Gaussian mixture model with three components')
disp('and spherical variances.  This enables us to model the complete')
disp('conditional density function of the target.')
disp(' ')
disp('Next we train the model for 200 epochs using a scaled conjugate gradient')
disp('optimizer.  The error function is the negative log likelihood of the')
disp('training data.')
disp(' ')
disp('Press any key to continue.')
pause

% Train using scaled conjugate gradients.
[net, options] = netopt(net, options, x, t, 'scg');

disp(' ')
disp('Press any key to continue.')
pause

clc
disp('We can also train a conventional MLP with a sum-of-squares error')
disp('function.  This approximates the conditional mean of the target,')
disp('which is not always a good representation of multi-valued data.')
disp('Note that the displayed error is now the sum-of-squares error on the')
disp('training data, which is why the values differ from those seen when')
disp('training the MDN.')
disp(' ')
disp('We train the network with the quasi-Newton optimizer for 80 epochs.')
disp(' ')
disp('Press any key to continue.')
pause
mlp_nhidden = 8;
net2 = mlp(nin, mlp_nhidden, dim_target, 'linear');
options(14) = 80;
[net2, options] = netopt(net2, options, x, t, 'quasinew');
disp(' ')
disp('Press any key to continue.')
pause
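% Optional comparison of the two trained models, each on its own error
% measure (MDNERR and MLPERR are the standard Netlab error functions for
% these models).  The two values are on different scales and are not
% directly comparable: one is a negative log likelihood, the other a sum
% of squares.
fprintf('MDN negative log likelihood: %g\n', mdnerr(net, x, t));
fprintf('MLP sum-of-squares error:    %g\n', mlperr(net2, x, t));
disp(' ')
disp('Press any key to continue.')
pause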
clc
disp('Now we plot the underlying function, the MDN prediction,')
disp('represented by the mode of the conditional distribution, and the')
disp('prediction of the conventional MLP.')
disp(' ')
disp('Press any key to continue.')
pause

% Plot the original function and the trained network functions.
plotvals = [0:0.01:1]';
mixes = mdn2gmm(mdnfwd(net, plotvals));
axis(axis_limits);
yplot = t + 0.3*sin(2*pi*t);
p2 = plot(yplot, t, '--y');

% Use the mode of the conditional distribution to represent the function:
% for each input, take the centre of the mixture component with the
% largest mixing coefficient.
y = zeros(1, length(plotvals));
c = zeros(length(plotvals), ncentres);
for i = 1:length(plotvals)
  [m, j] = max(mixes(i).priors);
  y(i) = mixes(i).centres(j,:);
  c(i,:) = mixes(i).centres';
end
p3 = plot(plotvals, y, '*r');
p4 = plot(plotvals, mlpfwd(net2, plotvals), 'g');
set(p4, 'LineWidth', 2);
legend([p1 p2 p3 p4], 'data', 'function', 'MDN mode', 'MLP mean', 4);
hold off

clc
disp('We can also plot how the mixture model parameters depend on x.')
disp('First we plot the mixture centres, then the priors and finally')
disp('the variances.')
disp(' ')
disp('Press any key to continue.')
pause
fh2 = figure;
subplot(3, 1, 1)
plot(plotvals, c)
hold on
title('Mixture centres')
legend('centre 1', 'centre 2', 'centre 3')
hold off

priors = reshape([mixes.priors], mixes(1).ncentres, size(mixes, 2))';
subplot(3, 1, 2)
plot(plotvals, priors)
hold on
title('Mixture priors')
legend('centre 1', 'centre 2', 'centre 3')
hold off

variances = reshape([mixes.covars], mixes(1).ncentres, size(mixes, 2))';
subplot(3, 1, 3)
plot(plotvals, variances)
hold on
title('Mixture variances')
legend('centre 1', 'centre 2', 'centre 3')
hold off

disp('The last figure is a contour plot of the conditional probability')
disp('density generated by the Mixture Density Network.  Note how it')
disp('is well matched to the regions of high data density.')
disp(' ')
disp('Press any key to continue.')
pause
% Contour plot for MDN: evaluate the conditional density p(t|x) on a grid,
% one column of Z per input value in plotvals.
i = 0:0.01:1.0;
j = 0:0.01:1.0;
li = length(i);
lj = length(j);
Z = zeros(lj, li);
for k = 1:li
  Z(:,k) = gmmprob(mixes(k), j');
end
fh5 = figure;
% Set up levels by hand to make a good figure
v = [2 2.5 3 3.5 5:3:18];
contour(i, j, Z, v)
hold on
title('Contour plot of conditional density')
hold off

disp(' ')
disp('Press any key to exit.')
pause
close(fh1);
close(fh2);
close(fh5);
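% A possible extension (a sketch only, not required by the demo): since
% MDN2GMM yields standard Netlab GMM structures, the conditional density
% at any input can be sampled directly with GMMSAMP.  At x = 0.5, where
% the mapping is three-valued, samples of t should fall near all three
% branches.  Uncomment to try:
% mix05 = mdn2gmm(mdnfwd(net, 0.5));
% samples = gmmsamp(mix05, 10)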