%DEMMDN1 Demonstrate fitting a multi-valued function using a Mixture Density Network.
%
%	Description
%	The problem consists of one input variable X and one target variable
%	T, with data generated by sampling T at equal intervals and then
%	computing the input as X = T + 0.3*SIN(2*PI*T) plus additive uniform
%	noise. A Mixture Density Network with 3 centres in the mixture model
%	is trained by minimizing a negative log likelihood error function
%	using the scaled conjugate gradient optimizer.
%
%	The conditional means, mixing coefficients and variances are plotted
%	as a function of X, and a contour plot of the full conditional
%	density is also generated.
%
%	See also
%	MDN, MDNERR, MDNGRAD, SCG
%

%	Copyright (c) Ian T Nabney (1996-2001)


% Generate the matrix of inputs x and targets t.
seedn = 42;
seed = 42;
randn('state', seedn);
rand('state', seed);
ndata = 300; % Number of data points.
noise = 0.2; % Range of the uniform noise distribution.
t = [0:1/(ndata - 1):1]';
x = t + 0.3*sin(2*pi*t) + noise*rand(ndata, 1) - noise/2;
axis_limits = [-0.2 1.2 -0.2 1.2];

clc
disp('This demonstration illustrates the use of a Mixture Density Network')
disp('to model multi-valued functions. The data is generated from the')
disp('mapping x = t + 0.3 sin(2 pi t) + e, where e is a uniform noise term.')
disp('We begin by plotting the data.')
disp(' ')
disp('Press any key to continue')
pause
% Plot the data
fh1 = figure;
p1 = plot(x, t, 'ob');
axis(axis_limits);
hold on
disp('Note that for x in the range 0.35 to 0.65, there are three possible')
disp('branches of the function.')
disp(' ')
disp('Press any key to continue')
pause

% Set up network parameters.
nin = 1; % Number of inputs.
nhidden = 5; % Number of hidden units.
ncentres = 3; % Number of mixture components.
dim_target = 1; % Dimension of target space
mdntype = '0'; % Currently unused: reserved for future use
alpha = 100; % Inverse variance for weight initialisation
             % Make variance small for good starting point

% Create and initialize network weight vector.
net = mdn(nin, nhidden, ncentres, dim_target, mdntype);
init_options = zeros(1, 18);
init_options(1) = -1; % Suppress all messages
init_options(14) = 10; % 10 iterations of K means in gmminit
net = mdninit(net, alpha, t, init_options);

% Set up vector of options for the optimiser.
options = foptions;
options(1) = 1; % This provides display of error values.
options(14) = 200; % Number of training cycles.
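
% Not part of the original demo: since the MDN error function is the
% negative log likelihood of the data, its value at the freshly
% initialised network can be inspected directly with MDNERR before
% training starts. A minimal sketch; 'initial_error' is an illustrative
% name, not from the demo.
initial_error = mdnerr(net, x, t);
fprintf(1, 'Negative log likelihood before training: %.4f\n', initial_error);
disp('Press any key to continue.')
pause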
clc
disp('We initialise the neural network model, which is an MLP with a')
disp('Gaussian mixture model with three components and spherical variance')
disp('as the error function. This enables us to model the complete')
disp('conditional density function.')
disp(' ')
disp('Next we train the model for 200 epochs using a scaled conjugate gradient')
disp('optimizer. The error function is the negative log likelihood of the')
disp('training data.')
disp(' ')
disp('Press any key to continue.')
pause

% Train using scaled conjugate gradients.
[net, options] = netopt(net, options, x, t, 'scg');

disp(' ')
disp('Press any key to continue.')
pause

clc
disp('We can also train a conventional MLP with a sum-of-squares error function.')
disp('This approximates the conditional mean, which is not always a good')
disp('representation of multi-valued data. Note that the error function is the')
disp('sum-of-squares error on the training data, which accounts for the')
disp('different error values compared with training the MDN.')
disp(' ')
disp('We train the network with the quasi-Newton optimizer for 80 epochs.')
disp(' ')
disp('Press any key to continue.')
pause
mlp_nhidden = 8;
net2 = mlp(nin, mlp_nhidden, dim_target, 'linear');
options(14) = 80;
[net2, options] = netopt(net2, options, x, t, 'quasinew');
disp(' ')
disp('Press any key to continue.')
pause

clc
disp('Now we plot the underlying function, the MDN prediction,')
disp('represented by the mode of the conditional distribution, and the')
disp('prediction of the conventional MLP.')
disp(' ')
disp('Press any key to continue.')
pause
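
% Not part of the original demo: a minimal sketch of the conditional mean
% of the trained MDN, E[t|x] = sum_j prior_j(x) * centre_j(x), computed
% directly from the mixture parameters. It should roughly track the MLP
% output trained above. 'mdn_mixes' and 'cond_mean' are illustrative
% names, not from the demo.
mdn_mixes = mdn2gmm(mdnfwd(net, x));
cond_mean = zeros(ndata, 1);
for n = 1:ndata
  cond_mean(n) = mdn_mixes(n).priors * mdn_mixes(n).centres;
end
% One could, for example, overlay plot(x, cond_mean) on the data figure
% to compare the two conditional-mean estimates.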
% Plot the original function, the trained MDN prediction and the
% trained MLP prediction.
plotvals = [0:0.01:1]';
mixes = mdn2gmm(mdnfwd(net, plotvals));
axis(axis_limits);
yplot = t + 0.3*sin(2*pi*t);
p2 = plot(yplot, t, '--y');

% Use the mode to represent the function: take the centre of the
% component with the largest mixing coefficient at each input value.
y = zeros(1, length(plotvals));
c = zeros(length(plotvals), ncentres);
for i = 1:length(plotvals)
  [m, j] = max(mixes(i).priors);
  y(i) = mixes(i).centres(j,:);
  c(i,:) = mixes(i).centres';
end
p3 = plot(plotvals, y, '*r');
p4 = plot(plotvals, mlpfwd(net2, plotvals), 'g');
set(p4, 'LineWidth', 2);
legend([p1 p2 p3 p4], 'data', 'function', 'MDN mode', 'MLP mean', 4);
hold off

clc
disp('We can also plot how the mixture model parameters depend on x.')
disp('First we plot the mixture centres, then the priors and finally')
disp('the variances.')
disp(' ')
disp('Press any key to continue.')
pause
fh2 = figure;
subplot(3, 1, 1)
plot(plotvals, c)
hold on
title('Mixture centres')
legend('centre 1', 'centre 2', 'centre 3')
hold off

priors = reshape([mixes.priors], mixes(1).ncentres, size(mixes, 2))';
subplot(3, 1, 2)
plot(plotvals, priors)
hold on
title('Mixture priors')
legend('centre 1', 'centre 2', 'centre 3')
hold off

variances = reshape([mixes.covars], mixes(1).ncentres, size(mixes, 2))';
subplot(3, 1, 3)
plot(plotvals, variances)
hold on
title('Mixture variances')
legend('centre 1', 'centre 2', 'centre 3')
hold off

disp('The last figure is a contour plot of the conditional probability')
disp('density generated by the Mixture Density Network. Note how it')
disp('is well matched to the regions of high data density.')
disp(' ')
disp('Press any key to continue.')
pause
% Contour plot for MDN.
i = 0:0.01:1.0;
j = 0:0.01:1.0;
li = length(i);
lj = length(j);
% Each column of Z holds the conditional density p(t|x) evaluated along
% t (the j grid) for one value of x (the k-th entry of the i grid).
Z = zeros(li, lj);
for k = 1:li
  Z(:,k) = gmmprob(mixes(k), j');
end
fh5 = figure;
% Set up contour levels by hand to make a good figure
v = [2 2.5 3 3.5 5:3:18];
contour(i, j, Z, v)
hold on
title('Contour plot of conditional density')
hold off

disp(' ')
disp('Press any key to exit.')
pause
close(fh1);
close(fh2);
close(fh5);
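
% Not part of the original demo: a minimal sketch of drawing samples from
% the conditional density p(t|x) at a single input value, using GMMSAMP.
% In the multi-valued region (x near 0.5) the samples should fall into
% three clusters, one per branch. 'x_query', 'nsamp' and 'tsamp' are
% illustrative names, not from the demo.
x_query = 0.5;
cond_mix = mdn2gmm(mdnfwd(net, x_query));
nsamp = 1000;
tsamp = gmmsamp(cond_mix(1), nsamp);
% e.g. hist(tsamp, 30) should show three modes, one per branch.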