diff toolboxes/FullBNT-1.0.7/netlab3.3/demmdn1.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/toolboxes/FullBNT-1.0.7/netlab3.3/demmdn1.m	Tue Feb 10 15:05:51 2015 +0000
@@ -0,0 +1,211 @@
+%DEMMDN1 Demonstrate fitting a multi-valued function using a Mixture Density Network.
+%
+%	Description
+%	The problem consists of one input variable X and one target variable
+%	T with data generated by sampling T at equal intervals and then
+%	computing the input as X = T + 0.3*SIN(2*PI*T) plus additive
+%	uniform noise. A Mixture Density Network with 3 centres in the
+%	mixture model is trained by minimizing a negative log likelihood
+%	error function using the scaled conjugate gradient optimizer.
+%
+%	The conditional means, mixing coefficients and variances are plotted
+%	as a function of X, and a contour plot of the full conditional
+%	density is also generated.
+%
+%	See also
+%	MDN, MDNERR, MDNGRAD, SCG
+%
+
+%	Copyright (c) Ian T Nabney (1996-2001)
+
+
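+% For reference (Bishop, 1994), the MDN models the conditional density of
+% the target as a mixture of Gaussians whose parameters are the outputs
+% of the MLP:
+%
+%    p(t|x) = sum_{j=1..M} alpha_j(x) * N(t; mu_j(x), sigma_j(x)^2)
+%
+% where the mixing coefficients alpha_j, centres mu_j and spherical
+% variances sigma_j^2 all depend on the input x.
+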
+% Generate the matrix of inputs x and targets t.
+seedn = 42;
+seed = 42;
+randn('state', seedn);
+rand('state', seed);
+ndata = 300;			% Number of data points.
+noise = 0.2;			% Range of noise distribution.
+t = [0:1/(ndata - 1):1]';
+x = t + 0.3*sin(2*pi*t) + noise*rand(ndata, 1) - noise/2;
+axis_limits = [-0.2 1.2 -0.2 1.2];
+
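+% Because x = t + 0.3*sin(2*pi*t) is non-monotonic in t (its derivative
+% 1 + 0.6*pi*cos(2*pi*t) changes sign), the inverse map from x back to t
+% is multi-valued over part of the range; this is what the MDN captures.
+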
+clc
+disp('This demonstration illustrates the use of a Mixture Density Network')
+disp('to model multi-valued functions.  The data is generated from the')
+disp('mapping x = t + 0.3 sin(2 pi t) + e, where e is a noise term.')
+disp('We begin by plotting the data.')
+disp(' ')
+disp('Press any key to continue')
+pause
+% Plot the data
+fh1 = figure;
+p1 = plot(x, t, 'ob');
+axis(axis_limits);
+hold on
+disp('Note that for x in the range 0.35 to 0.65, there are three possible')
+disp('branches of the function.')
+disp(' ')
+disp('Press any key to continue')
+pause
+
+% Set up network parameters.
+nin = 1;			% Number of inputs.
+nhidden = 5;			% Number of hidden units.
+ncentres = 3;			% Number of mixture components.
+dim_target = 1;			% Dimension of target space
+mdntype = '0';			% Currently unused: reserved for future use
+alpha = 100;			% Inverse variance for weight initialisation
+				% Make variance small for good starting point
+
+% Create and initialize network weight vector.
+net = mdn(nin, nhidden, ncentres, dim_target, mdntype);
+init_options = zeros(1, 18);
+init_options(1) = -1;	% Suppress all messages
+init_options(14) = 10;  % 10 iterations of K means in gmminit
+net = mdninit(net, alpha, t, init_options);
+
+% Set up vector of options for the optimiser.
+options = foptions;
+options(1) = 1;			% This provides display of error values.
+options(14) = 200;		% Number of training cycles. 
+
+clc
+disp('We initialise the Mixture Density Network: an MLP whose outputs set')
+disp('the parameters of a Gaussian mixture model with three components and')
+disp('spherical variances.  This enables us to model the complete')
+disp('conditional density function of the target.')
+disp(' ')
+disp('Next we train the model for 200 epochs using a scaled conjugate gradient')
+disp('optimizer.  The error function is the negative log likelihood of the')
+disp('training data.')
+disp(' ')
+disp('Press any key to continue.')
+pause
+
+% Train using scaled conjugate gradients.
+[net, options] = netopt(net, options, x, t, 'scg');
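+
+% Optional check: the final negative log likelihood reported by netopt can
+% be recomputed directly with Netlab's MDN error function.
+err = mdnerr(net, x, t);
+fprintf(1, 'Final MDN negative log likelihood: %g\n', err);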
+
+disp(' ')
+disp('Press any key to continue.')
+pause
+
+clc
+disp('We can also train a conventional MLP with a sum-of-squares error function.')
+disp('This approximates the conditional mean, which is not always a good')
+disp('representation of multi-valued data.  Note that the error function is')
+disp('now the sum-of-squares error on the training data, which is why the')
+disp('reported error values differ from those of the MDN run.')
+disp(' ')
+disp('We train the network with the quasi-Newton optimizer for 80 epochs.')
+disp(' ')
+disp('Press any key to continue.')
+pause
+mlp_nhidden = 8;
+net2 = mlp(nin, mlp_nhidden, dim_target, 'linear');
+options(14) = 80; 
+[net2, options] = netopt(net2, options, x, t, 'quasinew');
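+
+% Optional check: the corresponding sum-of-squares error of the trained
+% MLP, as minimised by quasinew, can be recomputed with mlperr.
+err2 = mlperr(net2, x, t);
+fprintf(1, 'Final MLP sum-of-squares error: %g\n', err2);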
+disp(' ')
+disp('Press any key to continue.')
+pause
+
+clc
+disp('Now we plot the underlying function, the MDN prediction,')
+disp('represented by the mode of the conditional distribution, and the')
+disp('prediction of the conventional MLP.')
+disp(' ')
+disp('Press any key to continue.')
+pause
+
+% Plot the original function, and the trained network function.
+plotvals = [0:0.01:1]';
+mixes = mdn2gmm(mdnfwd(net, plotvals));
+axis(axis_limits);
+yplot = t+0.3*sin(2*pi*t);
+p2 = plot(yplot, t, '--y');
+
+% Use the mode to represent the function.  The mode is approximated here
+% by the centre of the component with the largest mixing coefficient,
+% which is adequate when the components are well separated.
+y = zeros(1, length(plotvals));
+c = zeros(length(plotvals), ncentres);
+for i = 1:length(plotvals)
+  [m, j] = max(mixes(i).priors);
+  y(i) = mixes(i).centres(j,:);
+  c(i,:) = mixes(i).centres';
+end
+p3 = plot(plotvals, y, '*r');
+p4 = plot(plotvals, mlpfwd(net2, plotvals), 'g');
+set(p4, 'LineWidth', 2);
+legend([p1 p2 p3 p4], 'data', 'function', 'MDN mode', 'MLP mean', 4);
+hold off
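+
+% A small extra comparison (not in the original demo): the MDN's own
+% conditional mean, sum_j alpha_j(x)*mu_j(x), also estimates E[t|x] and
+% should roughly track the MLP curve plotted above.
+mdn_mean = zeros(length(plotvals), 1);
+for i = 1:length(plotvals)
+  mdn_mean(i) = mixes(i).priors * mixes(i).centres;
+end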
+
+clc
+disp('We can also plot how the mixture model parameters depend on x.')
+disp('First we plot the mixture centres, then the priors and finally')
+disp('the variances.')
+disp(' ')
+disp('Press any key to continue.')
+pause
+fh2 = figure;
+subplot(3, 1, 1)
+plot(plotvals, c)
+hold on
+title('Mixture centres')
+legend('centre 1', 'centre 2', 'centre 3')
+hold off
+
+priors = reshape([mixes.priors], mixes(1).ncentres, size(mixes, 2))';
+subplot(3, 1, 2)
+plot(plotvals, priors)
+hold on
+title('Mixture priors')
+legend('centre 1', 'centre 2', 'centre 3')
+hold off
+
+variances = reshape([mixes.covars], mixes(1).ncentres, size(mixes, 2))';
+subplot(3, 1, 3)
+plot(plotvals, variances)
+hold on
+title('Mixture variances')
+legend('centre 1', 'centre 2', 'centre 3')
+hold off
+
+disp('The last figure is a contour plot of the conditional probability')
+disp('density generated by the Mixture Density Network.  Note how it')
+disp('is well matched to the regions of high data density.')
+disp(' ')
+disp('Press any key to continue.')
+pause
+% Contour plot for MDN.  The grid i matches plotvals, so mixes(k) is the
+% conditional mixture density at x = i(k); column k of Z holds that
+% density evaluated over the target values j.
+i = 0:0.01:1.0;
+j = 0:0.01:1.0;
+li = length(i);
+lj = length(j);
+Z = zeros(lj, li);
+for k = 1:li
+  Z(:,k) = gmmprob(mixes(k), j');
+end
+fh5 = figure;
+% Set up levels by hand to make a good figure
+v = [2 2.5 3 3.5 5:3:18];
+contour(i, j, Z, v)
+hold on
+title('Contour plot of conditional density')
+hold off
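+
+% Rough sanity check: each column of Z is a conditional density in t, so
+% summing over the grid (step 0.01) should give values near one, up to
+% truncation of the density outside [0, 1].
+approx_mass = sum(Z, 1) * 0.01;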
+
+disp(' ')
+disp('Press any key to exit.')
+pause
+close(fh1);
+close(fh2);
+close(fh5);