diff toolboxes/FullBNT-1.0.7/netlab3.3/demmdn1.m @ 0:e9a9cd732c1e tip
first hg version after svn
author:   wolffd
date:     Tue, 10 Feb 2015 15:05:51 +0000
parents:
children:
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/toolboxes/FullBNT-1.0.7/netlab3.3/demmdn1.m	Tue Feb 10 15:05:51 2015 +0000
@@ -0,0 +1,211 @@
+%DEMMDN1 Demonstrate fitting a multi-valued function using a Mixture Density Network.
+%
+% Description
+% The problem consists of one input variable X and one target variable
+% T with data generated by sampling T at equal intervals and then
+% generating input data by computing T + 0.3*SIN(2*PI*T) and adding
+% uniform noise. A Mixture Density Network with 3 centres in the
+% mixture model is trained by minimizing a negative log likelihood
+% error function using the scaled conjugate gradient optimizer.
+%
+% The conditional means, mixing coefficients and variances are plotted
+% as a function of X, and a contour plot of the full conditional
+% density is also generated.
+%
+% See also
+% MDN, MDNERR, MDNGRAD, SCG
+%
+
+% Copyright (c) Ian T Nabney (1996-2001)
+
+
+% Generate the matrix of inputs x and targets t.
+seedn = 42;
+seed = 42;
+randn('state', seedn);
+rand('state', seed);
+ndata = 300;    % Number of data points.
+noise = 0.2;    % Range of noise distribution.
+t = [0:1/(ndata - 1):1]';
+x = t + 0.3*sin(2*pi*t) + noise*rand(ndata, 1) - noise/2;
+axis_limits = [-0.2 1.2 -0.2 1.2];
+
+clc
+disp('This demonstration illustrates the use of a Mixture Density Network')
+disp('to model multi-valued functions. The data is generated from the')
+disp('mapping x = t + 0.3 sin(2 pi t) + e, where e is a noise term.')
+disp('We begin by plotting the data.')
+disp(' ')
+disp('Press any key to continue')
+pause
+% Plot the data
+fh1 = figure;
+p1 = plot(x, t, 'ob');
+axis(axis_limits);
+hold on
+disp('Note that for x in the range 0.35 to 0.65, there are three possible')
+disp('branches of the function.')
+disp(' ')
+disp('Press any key to continue')
+pause
+
+% Set up network parameters.
+nin = 1;          % Number of inputs.
+nhidden = 5;      % Number of hidden units.
+ncentres = 3;     % Number of mixture components.
+dim_target = 1;   % Dimension of target space
+mdntype = '0';    % Currently unused: reserved for future use
+alpha = 100;      % Inverse variance for weight initialisation
+                  % Make variance small for good starting point
+
+% Create and initialize network weight vector.
+net = mdn(nin, nhidden, ncentres, dim_target, mdntype);
+init_options = zeros(1, 18);
+init_options(1) = -1;   % Suppress all messages
+init_options(14) = 10;  % 10 iterations of K means in gmminit
+net = mdninit(net, alpha, t, init_options);
+
+% Set up vector of options for the optimiser.
+options = foptions;
+options(1) = 1;     % This provides display of error values.
+options(14) = 200;  % Number of training cycles.
+
+clc
+disp('We initialise the neural network model, which is an MLP with a')
+disp('Gaussian mixture model with three components and spherical variance')
+disp('as the error function. This enables us to model the complete')
+disp('conditional density function.')
+disp(' ')
+disp('Next we train the model for 200 epochs using a scaled conjugate gradient')
+disp('optimizer. The error function is the negative log likelihood of the')
+disp('training data.')
+disp(' ')
+disp('Press any key to continue.')
+pause
+
+% Train using scaled conjugate gradients.
+[net, options] = netopt(net, options, x, t, 'scg');
+
+disp(' ')
+disp('Press any key to continue.')
+pause
+
+clc
+disp('We can also train a conventional MLP with sum of squares error function.')
+disp('This will approximate the conditional mean, which is not always a')
+disp('good representation of the data. Note that the error function is the')
+disp('sum of squares error on the training data, which accounts for the')
+disp('different values from training the MDN.')
+disp(' ')
+disp('We train the network with the quasi-Newton optimizer for 80 epochs.')
+disp(' ')
+disp('Press any key to continue.')
+pause
+mlp_nhidden = 8;
+net2 = mlp(nin, mlp_nhidden, dim_target, 'linear');
+options(14) = 80;
+[net2, options] = netopt(net2, options, x, t, 'quasinew');
+disp(' ')
+disp('Press any key to continue.')
+pause
+
+clc
+disp('Now we plot the underlying function, the MDN prediction,')
+disp('represented by the mode of the conditional distribution, and the')
+disp('prediction of the conventional MLP.')
+disp(' ')
+disp('Press any key to continue.')
+pause
+
+% Plot the original function, and the trained network function.
+plotvals = [0:0.01:1]';
+mixes = mdn2gmm(mdnfwd(net, plotvals));
+axis(axis_limits);
+yplot = t+0.3*sin(2*pi*t);
+p2 = plot(yplot, t, '--y');
+
+% Use the mode to represent the function
+y = zeros(1, length(plotvals));
+priors = zeros(length(plotvals), ncentres);
+c = zeros(length(plotvals), 3);
+widths = zeros(length(plotvals), ncentres);
+for i = 1:length(plotvals)
+  [m, j] = max(mixes(i).priors);
+  y(i) = mixes(i).centres(j,:);
+  c(i,:) = mixes(i).centres';
+end
+p3 = plot(plotvals, y, '*r');
+p4 = plot(plotvals, mlpfwd(net2, plotvals), 'g');
+set(p4, 'LineWidth', 2);
+legend([p1 p2 p3 p4], 'data', 'function', 'MDN mode', 'MLP mean', 4);
+hold off
+
+clc
+disp('We can also plot how the mixture model parameters depend on x.')
+disp('First we plot the mixture centres, then the priors and finally')
+disp('the variances.')
+disp(' ')
+disp('Press any key to continue.')
+pause
+fh2 = figure;
+subplot(3, 1, 1)
+plot(plotvals, c)
+hold on
+title('Mixture centres')
+legend('centre 1', 'centre 2', 'centre 3')
+hold off
+
+priors = reshape([mixes.priors], mixes(1).ncentres, size(mixes, 2))';
+%%fh3 = figure;
+subplot(3, 1, 2)
+plot(plotvals, priors)
+hold on
+title('Mixture priors')
+legend('centre 1', 'centre 2', 'centre 3')
+hold off
+
+variances = reshape([mixes.covars], mixes(1).ncentres, size(mixes, 2))';
+%%fh4 = figure;
+subplot(3, 1, 3)
+plot(plotvals, variances)
+hold on
+title('Mixture variances')
+legend('centre 1', 'centre 2', 'centre 3')
+hold off
+
+disp('The last figure is a contour plot of the conditional probability')
+disp('density generated by the Mixture Density Network. Note how it')
+disp('is well matched to the regions of high data density.')
+disp(' ')
+disp('Press any key to continue.')
+pause
+% Contour plot for MDN.
+i = 0:0.01:1.0;
+j = 0:0.01:1.0;
+
+[I, J] = meshgrid(i,j);
+I = I(:);
+J = J(:);
+li = length(i);
+lj = length(j);
+Z = zeros(li, lj);
+for k = 1:li;
+  Z(:,k) = gmmprob(mixes(k), j');
+end
+fh5 = figure;
+% Set up levels by hand to make a good figure
+v = [2 2.5 3 3.5 5:3:18];
+contour(i, j, Z, v)
+hold on
+title('Contour plot of conditional density')
+hold off
+
+disp(' ')
+disp('Press any key to exit.')
+pause
+close(fh1);
+close(fh2);
+%%close(fh3);
+%%close(fh4);
+close(fh5);
+%%clear all;
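
For reference, the "negative log likelihood error function" named in the header has the standard mixture-density-network form; the notation below (mixing coefficients alpha_j, centres mu_j, spherical variances sigma_j^2, all functions of x computed by the MLP) is the usual MDN formulation and is not taken from the file itself:

    p(t \mid x) = \sum_{j=1}^{3} \alpha_j(x)\, \mathcal{N}\bigl(t \mid \mu_j(x), \sigma_j^2(x)\bigr)

and training with scg minimises

    E = -\sum_{n=1}^{N} \ln p(t_n \mid x_n)

over the N = 300 training pairs.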
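Once the script has run, the trained network can also be queried directly. The sketch below is not part of the original demo; the names x0, tgrid, mix0 and p are illustrative, and it reuses only calls that already appear above (mdnfwd, mdn2gmm, gmmprob). It evaluates the fitted conditional density at a single input inside the multi-valued region:

% Sketch: evaluate the fitted conditional density p(t|x) at one input.
% Assumes the demo above has just run, so net is the trained MDN.
x0 = 0.5;                         % Query point in the multi-valued region.
tgrid = (0:0.01:1)';              % Grid of candidate target values.
mix0 = mdn2gmm(mdnfwd(net, x0));  % Forward pass gives a Gaussian mixture over t.
p = gmmprob(mix0, tgrid);         % Density p(t | x = x0) on the grid.
figure
plot(tgrid, p)
title('Conditional density p(t|x = 0.5)')

If training has gone well, the plotted density should show three modes at x0 = 0.5, one for each branch visible in the scatter plot of the data.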