%DEMMDN1 Demonstrate fitting a multi-valued function using a Mixture Density Network.
%
% Description
% The problem consists of one input variable X and one target variable
% T. The data are generated by sampling T at equal intervals and then
% computing the input as X = T + 0.3*SIN(2*PI*T) plus additive noise
% drawn from a uniform distribution, so that the inverse mapping from
% X to T is multi-valued. A Mixture Density Network with 3 centres in
% the mixture model is trained by minimizing a negative log likelihood
% error function using the scaled conjugate gradient optimizer.
%
% The conditional means, mixing coefficients and variances are plotted
% as a function of X, and a contour plot of the full conditional
% density is also generated.
%
% See also
% MDN, MDNERR, MDNGRAD, SCG
%

% Copyright (c) Ian T Nabney (1996-2001)

% Generate the matrix of inputs x and targets t.
seedn = 42;
seed = 42;
randn('state', seedn);
rand('state', seed);
ndata = 300;    % Number of data points.
noise = 0.2;    % Width of the uniform noise distribution.
t = [0:1/(ndata - 1):1]';
x = t + 0.3*sin(2*pi*t) + noise*rand(ndata, 1) - noise/2;
axis_limits = [-0.2 1.2 -0.2 1.2];
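% A quick sanity check (an addition, not part of the original demo): the
% noise added to x is uniform on [-noise/2, noise/2], so the data should
% deviate from the noise-free curve by at most noise/2.
max_dev = max(abs(x - (t + 0.3*sin(2*pi*t))));
fprintf(1, 'Max deviation from noise-free curve: %f (bound %f)\n', ...
  max_dev, noise/2);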

clc
disp('This demonstration illustrates the use of a Mixture Density Network')
disp('to model multi-valued functions.  The data is generated from the')
disp('mapping x = t + 0.3 sin(2 pi t) + e, where e is a noise term.')
disp('We begin by plotting the data.')
disp(' ')
disp('Press any key to continue')
pause
% Plot the data
fh1 = figure;
p1 = plot(x, t, 'ob');
axis(axis_limits);
hold on
disp('Note that for x in the range 0.35 to 0.65, there are three possible')
disp('branches of the function.')
disp(' ')
disp('Press any key to continue')
pause

% Set up network parameters.
nin = 1;          % Number of inputs.
nhidden = 5;      % Number of hidden units.
ncentres = 3;     % Number of mixture components.
dim_target = 1;   % Dimension of target space.
mdntype = '0';    % Currently unused: reserved for future use.
alpha = 100;      % Inverse variance for weight initialisation.
                  % Make variance small for good starting point.

% Create and initialize network weight vector.
net = mdn(nin, nhidden, ncentres, dim_target, mdntype);
init_options = zeros(1, 18);
init_options(1) = -1;   % Suppress all messages
init_options(14) = 10;  % 10 iterations of K means in gmminit
net = mdninit(net, alpha, t, init_options);
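% As a quick check (an addition, not in the original demo), the negative
% log likelihood of the freshly initialised network can be inspected with
% mdnerr; the training run below should reduce this value substantially.
initial_error = mdnerr(net, x, t);
fprintf(1, 'Initial negative log likelihood: %f\n', initial_error);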

% Set up vector of options for the optimiser.
options = foptions;
options(1) = 1;     % This provides display of error values.
options(14) = 200;  % Number of training cycles.

clc
disp('We initialise the neural network model, which is an MLP whose outputs')
disp('parameterise a Gaussian mixture model with three components and')
disp('spherical covariances.  This enables us to model the complete')
disp('conditional density function.')
disp(' ')
disp('Next we train the model for 200 epochs using a scaled conjugate gradient')
disp('optimizer.  The error function is the negative log likelihood of the')
disp('training data.')
disp(' ')
disp('Press any key to continue.')
pause

% Train using scaled conjugate gradients.
[net, options] = netopt(net, options, x, t, 'scg');
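% A hedged aside (not in the original demo): under Netlab's foptions
% convention, options(8) returned by netopt holds the final value of the
% error function, here the negative log likelihood of the training data.
fprintf(1, 'Final negative log likelihood: %f\n', options(8));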

disp(' ')
disp('Press any key to continue.')
pause

clc
disp('We can also train a conventional MLP with a sum-of-squares error')
disp('function.  This will approximate the conditional mean, which is not')
disp('always a good representation of the data.  Note that its error')
disp('function is the sum-of-squares error on the training data, which is')
disp('why its reported error values differ from those of the MDN.')
disp(' ')
disp('We train the network with the quasi-Newton optimizer for 80 epochs.')
disp(' ')
disp('Press any key to continue.')
pause
mlp_nhidden = 8;
net2 = mlp(nin, mlp_nhidden, dim_target, 'linear');
options(14) = 80;
[net2, options] = netopt(net2, options, x, t, 'quasinew');
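% For completeness (an addition, not in the original demo): mlperr gives
% the sum-of-squares error of the trained MLP, which is not directly
% comparable to the MDN's negative log likelihood.
mlp_error = mlperr(net2, x, t);
fprintf(1, 'MLP sum-of-squares error: %f\n', mlp_error);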

disp(' ')
disp('Press any key to continue.')
pause

clc
disp('Now we plot the underlying function, the MDN prediction,')
disp('represented by the mode of the conditional distribution, and the')
disp('prediction of the conventional MLP.')
disp(' ')
disp('Press any key to continue.')
pause

% Plot the original function, and the trained network function.
plotvals = [0:0.01:1]';
mixes = mdn2gmm(mdnfwd(net, plotvals));
axis(axis_limits);
yplot = t + 0.3*sin(2*pi*t);
p2 = plot(yplot, t, '--y');

% Use the mode of the conditional distribution to represent the function.
% The mode is approximated by the centre of the mixture component with
% the largest mixing coefficient at each value of x.
y = zeros(1, length(plotvals));
c = zeros(length(plotvals), ncentres);
for i = 1:length(plotvals)
  [m, j] = max(mixes(i).priors);
  y(i) = mixes(i).centres(j,:);
  c(i,:) = mixes(i).centres';
end
p3 = plot(plotvals, y, '*r');
p4 = plot(plotvals, mlpfwd(net2, plotvals), 'g');
set(p4, 'LineWidth', 2);
legend([p1 p2 p3 p4], 'data', 'function', 'MDN mode', 'MLP mean', 4);
hold off
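% An illustrative extra (not in the original demo): the conditional mean
% of the mixture, the sum over components of prior times centre, is the
% quantity that the MLP's sum-of-squares training approximates.  It can be
% recovered directly from the mixture parameters and compared with the
% MLP output.
cond_mean = zeros(length(plotvals), 1);
for i = 1:length(plotvals)
  cond_mean(i) = mixes(i).priors * mixes(i).centres;
end
fprintf(1, 'Max |mixture mean - MLP output|: %f\n', ...
  max(abs(cond_mean - mlpfwd(net2, plotvals))));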

clc
disp('We can also plot how the mixture model parameters depend on x.')
disp('First we plot the mixture centres, then the priors and finally')
disp('the variances.')
disp(' ')
disp('Press any key to continue.')
pause
fh2 = figure;
subplot(3, 1, 1)
plot(plotvals, c)
hold on
title('Mixture centres')
legend('centre 1', 'centre 2', 'centre 3')
hold off

priors = reshape([mixes.priors], mixes(1).ncentres, size(mixes, 2))';
subplot(3, 1, 2)
plot(plotvals, priors)
hold on
title('Mixture priors')
legend('centre 1', 'centre 2', 'centre 3')
hold off

variances = reshape([mixes.covars], mixes(1).ncentres, size(mixes, 2))';
subplot(3, 1, 3)
plot(plotvals, variances)
hold on
title('Mixture variances')
legend('centre 1', 'centre 2', 'centre 3')
hold off
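% A small consistency check (an addition to the demo): the mixing
% coefficients output by the network should sum to one at every x.
fprintf(1, 'Max deviation of sum of priors from 1: %g\n', ...
  max(abs(sum(priors, 2) - 1)));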

disp('The last figure is a contour plot of the conditional probability')
disp('density generated by the Mixture Density Network.  Note how it')
disp('is well matched to the regions of high data density.')
disp(' ')
disp('Press any key to continue.')
pause
% Contour plot for MDN.
i = 0:0.01:1.0;
j = 0:0.01:1.0;
li = length(i);
lj = length(j);
Z = zeros(lj, li);
for k = 1:li
  % mixes(k) is the conditional mixture at x = i(k) = plotvals(k).
  Z(:,k) = gmmprob(mixes(k), j');
end
fh5 = figure;
% Set up levels by hand to make a good figure
v = [2 2.5 3 3.5 5:3:18];
contour(i, j, Z, v)
hold on
title('Contour plot of conditional density')
hold off
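% A hedged check (not in the original demo): each column of Z is a
% conditional density over t, so integrating it over [0, 1] should give
% approximately one; j(51) = 0.5 picks out the density at x = 0.5.
fprintf(1, 'Trapezoidal integral of p(t|x=0.5): %f\n', trapz(j, Z(:, 51)));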

disp(' ')
disp('Press any key to exit.')
pause
close(fh1);
close(fh2);
close(fh5);