Mercurial repository: camir-aes2014
File: toolboxes/FullBNT-1.0.7/netlab3.3/demmdn1.m @ 0:e9a9cd732c1e (tip)
Changeset description: first hg version after svn
Author: wolffd
Date: Tue, 10 Feb 2015 15:05:51 +0000
%DEMMDN1 Demonstrate fitting a multi-valued function using a Mixture Density Network.
%
% Description
% The problem consists of one input variable X and one target variable
% T, with data generated by sampling T at equal intervals and then
% computing the input as X = T + 0.3*SIN(2*PI*T) plus additive uniform
% noise. A Mixture Density Network with 3 centres in the mixture model
% is trained by minimizing a negative log likelihood error function
% using the scaled conjugate gradient optimizer.
%
% The conditional means, mixing coefficients and variances are plotted
% as a function of X, and a contour plot of the full conditional
% density is also generated.
%
% See also
% MDN, MDNERR, MDNGRAD, SCG
%

% Copyright (c) Ian T Nabney (1996-2001)


% Generate the matrix of inputs x and targets t.
seedn = 42;
seed = 42;
randn('state', seedn);
rand('state', seed);
ndata = 300;   % Number of data points.
noise = 0.2;   % Range of noise distribution.
t = [0:1/(ndata - 1):1]';
x = t + 0.3*sin(2*pi*t) + noise*rand(ndata, 1) - noise/2;
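% Note that rand gives noise uniform on [-noise/2, noise/2], and that x is
% computed from t; the inverse mapping from x back to t is therefore
% multi-valued in the middle of the range.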
axis_limits = [-0.2 1.2 -0.2 1.2];

clc
disp('This demonstration illustrates the use of a Mixture Density Network')
disp('to model multi-valued functions. The data is generated from the')
disp('mapping x = t + 0.3 sin(2 pi t) + e, where e is a noise term.')
disp('We begin by plotting the data.')
disp(' ')
disp('Press any key to continue')
pause
% Plot the data
fh1 = figure;
p1 = plot(x, t, 'ob');
axis(axis_limits);
hold on
disp('Note that for x in the range 0.35 to 0.65, there are three possible')
disp('branches of the function.')
disp(' ')
disp('Press any key to continue')
pause

% Set up network parameters.
nin = 1;          % Number of inputs.
nhidden = 5;      % Number of hidden units.
ncentres = 3;     % Number of mixture components.
dim_target = 1;   % Dimension of target space
mdntype = '0';    % Currently unused: reserved for future use
alpha = 100;      % Inverse variance for weight initialisation
                  % Make variance small for good starting point

% Create and initialize network weight vector.
net = mdn(nin, nhidden, ncentres, dim_target, mdntype);
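% The MDN combines an MLP with nhidden hidden units and a Gaussian mixture
% model over the target: the network outputs parameterise the mixing
% coefficients, centres and (spherical) variances of ncentres components,
% conditioned on the input.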
init_options = zeros(1, 18);
init_options(1) = -1;    % Suppress all messages
init_options(14) = 10;   % 10 iterations of K means in gmminit
net = mdninit(net, alpha, t, init_options);
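% mdninit draws the MLP weights from a zero-mean Gaussian with inverse
% variance alpha, and uses the target data t (via gmminit, which runs
% K-means) to set the output-unit biases so that the initial mixture
% roughly matches the unconditional target density.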

% Set up vector of options for the optimiser.
options = foptions;
options(1) = 1;      % This provides display of error values.
options(14) = 200;   % Number of training cycles.

clc
disp('We initialise the neural network model: an MLP whose outputs define a')
disp('Gaussian mixture model with three components and spherical variances.')
disp('This enables us to model the complete conditional density function.')
disp(' ')
disp('Next we train the model for 200 epochs using a scaled conjugate gradient')
disp('optimizer. The error function is the negative log likelihood of the')
disp('training data.')
disp(' ')
disp('Press any key to continue.')
pause

% Train using scaled conjugate gradients.
[net, options] = netopt(net, options, x, t, 'scg');
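% netopt wraps the generic neterr/netgrad functions, which for an MDN call
% mdnerr and mdngrad to evaluate the negative log likelihood and its
% gradient, and passes them to the scg optimiser.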

disp(' ')
disp('Press any key to continue.')
pause

clc
disp('We can also train a conventional MLP with sum of squares error function.')
disp('This will approximate the conditional mean, which is not always a')
disp('good representation of the data. Note that the error function is the')
disp('sum of squares error on the training data, which is why the reported')
disp('error values differ from those seen when training the MDN.')
disp(' ')
disp('We train the network with the quasi-Newton optimizer for 80 epochs.')
disp(' ')
disp('Press any key to continue.')
pause
mlp_nhidden = 8;
net2 = mlp(nin, mlp_nhidden, dim_target, 'linear');
options(14) = 80;
[net2, options] = netopt(net2, options, x, t, 'quasinew');
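% With a sum-of-squares error the MLP approximates the conditional mean
% E[t|x], which averages over the three branches in the multi-valued region
% and so falls between them.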
disp(' ')
disp('Press any key to continue.')
pause

clc
disp('Now we plot the underlying function, the MDN prediction,')
disp('represented by the mode of the conditional distribution, and the')
disp('prediction of the conventional MLP.')
disp(' ')
disp('Press any key to continue.')
pause

% Plot the original function, and the trained network function.
plotvals = [0:0.01:1]';
mixes = mdn2gmm(mdnfwd(net, plotvals));
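% mdnfwd returns the mixture parameters (priors, centres, variances) for
% each input in plotvals; mdn2gmm packs each one into a Netlab gmm
% structure so that standard functions such as gmmprob can be applied.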
axis(axis_limits);
yplot = t+0.3*sin(2*pi*t);
p2 = plot(yplot, t, '--y');

% Use the mode to represent the function
y = zeros(1, length(plotvals));
priors = zeros(length(plotvals), ncentres);
c = zeros(length(plotvals), ncentres);
widths = zeros(length(plotvals), ncentres);
for i = 1:length(plotvals)
  [m, j] = max(mixes(i).priors);
  y(i) = mixes(i).centres(j,:);
  c(i,:) = mixes(i).centres';
end
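% The plotted MDN "prediction" is the centre of the component with the
% largest mixing coefficient at each x (an approximation to the mode). An
% alternative summary, the conditional mean, could be formed as the
% prior-weighted sum of the centres, e.g. (illustrative sketch, not used
% in the plots below):
%   cond_mean = sum(reshape([mixes.priors], ncentres, [])' .* c, 2);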
p3 = plot(plotvals, y, '*r');
p4 = plot(plotvals, mlpfwd(net2, plotvals), 'g');
set(p4, 'LineWidth', 2);
legend([p1 p2 p3 p4], 'data', 'function', 'MDN mode', 'MLP mean', 4);
hold off

clc
disp('We can also plot how the mixture model parameters depend on x.')
disp('First we plot the mixture centres, then the priors and finally')
disp('the variances.')
disp(' ')
disp('Press any key to continue.')
pause
fh2 = figure;
subplot(3, 1, 1)
plot(plotvals, c)
hold on
title('Mixture centres')
legend('centre 1', 'centre 2', 'centre 3')
hold off

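% The priors and (spherical) variances are extracted from the gmm structure
% array: concatenating the fields gives one long row vector, which reshape
% turns into an ncentres-by-N matrix, transposed to give one row per x value.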
priors = reshape([mixes.priors], mixes(1).ncentres, size(mixes, 2))';
%%fh3 = figure;
subplot(3, 1, 2)
plot(plotvals, priors)
hold on
title('Mixture priors')
legend('centre 1', 'centre 2', 'centre 3')
hold off

variances = reshape([mixes.covars], mixes(1).ncentres, size(mixes, 2))';
%%fh4 = figure;
subplot(3, 1, 3)
plot(plotvals, variances)
hold on
title('Mixture variances')
legend('centre 1', 'centre 2', 'centre 3')
hold off

disp('The last figure is a contour plot of the conditional probability')
disp('density generated by the Mixture Density Network. Note how it')
disp('is well matched to the regions of high data density.')
disp(' ')
disp('Press any key to continue.')
pause
% Contour plot for MDN.
i = 0:0.01:1.0;
j = 0:0.01:1.0;

[I, J] = meshgrid(i,j);
I = I(:);
J = J(:);
li = length(i);
lj = length(j);
Z = zeros(li, lj);
for k = 1:li
  Z(:,k) = gmmprob(mixes(k), j');
end
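% Column k of Z holds the conditional density p(t | x = i(k)) evaluated at
% the grid of target values in j, using gmmprob on the mixture returned for
% that input.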
fh5 = figure;
% Set up levels by hand to make a good figure
v = [2 2.5 3 3.5 5:3:18];
contour(i, j, Z, v)
hold on
title('Contour plot of conditional density')
hold off

disp(' ')
disp('Press any key to exit.')
pause
close(fh1);
close(fh2);
%%close(fh3);
%%close(fh4);
close(fh5);
%%clear all;