%DEMGTM1 Demonstrate EM for GTM.
%
%	Description
%	 This script demonstrates the use of the EM algorithm to fit a one-
%	dimensional GTM to a two-dimensional set of data using maximum
%	likelihood.  The location and spread of the Gaussian kernels in the
%	data space is shown during training.
%
%	See also
%	DEMGTM2, GTM, GTMEM, GTMPOST
%

%	Copyright (c) Ian T Nabney (1996-2001)

% Demonstrates the GTM with a 2D target space and a 1D latent space.
%
% The script generates a simple data set in 2 dimensions with an
% intrinsic dimensionality of 1, trains a GTM with a 1-dimensional
% latent variable on it, and redraws the mixture after every training
% step so the fitting process can be followed visually.
%
% Synopsis: demgtm1

% ---- Demo parameters -------------------------------------------------
data_min = 0.15;
data_max = 3.05;
num_latent_points = 20;   % mixture components (one per latent point)
num_rbf_centres = 5;      % basis functions of the RBF mapping
num_circle_points = 20;   % vertices of the polygon used to draw each circle
num_iterations = 15;      % passes through the training loop below

% ---- Generate and plot a 2D data set ---------------------------------
T = (data_min:0.05:data_max)';
T = [T (T + 1.25*sin(2*T))];
fh1 = figure;
plot(T(:,1), T(:,2), 'ro');
axis([data_min-0.05 data_max+0.05 data_min-0.05 data_max+0.05]);
clc;
disp('This demonstration shows in detail how the EM algorithm works')
disp('for training a GTM with a one dimensional latent space.')
disp(' ')
fprintf([...
'The figure shows data generated by feeding a 1D uniform distribution\n', ...
'(on the X-axis) through a non-linear function (y = x + 1.25*sin(2*x))\n', ...
'\nPress any key to continue ...\n\n']);
pause;

% Unit circle outline, reused for every variance circle. linspace pins
% both end points and the number of rows, so the outline always has
% exactly num_circle_points vertices.
src = linspace(0, 2*pi, num_circle_points)';
unitC = [sin(src) cos(src)];

% ---- Generate and plot (along with the data) an initial GTM model -----
clc;

net = gtm(1, num_latent_points, 2, num_rbf_centres, 'gaussian');

options = zeros(1, 18);
options(7) = 1;
net = gtminit(net, options, T, 'regular', num_latent_points, ...
   num_rbf_centres);

mix = gtmfwd(net);
% Replot the figure
hold off;
plot(mix.centres(:,1), mix.centres(:,2), 'g');
hold on;
for i = 1:num_latent_points
  % Circle of radius 2 standard deviations around the i-th centre. The
  % offset must have one row per outline vertex (num_circle_points),
  % not one row per latent point.
  c = 2*unitC*sqrt(mix.covars(1)) + ...
      repmat(mix.centres(i,:), num_circle_points, 1);
  fill(c(:,1), c(:,2), [0.8 1 0.8]);
end
plot(T(:,1), T(:,2), 'ro');
plot(mix.centres(:,1), mix.centres(:,2), 'g+');
plot(mix.centres(:,1), mix.centres(:,2), 'g');
axis([data_min-0.05 data_max+0.05 data_min-0.05 data_max+0.05]);
% Set the title before flushing graphics so it is part of the drawn frame.
title('Initial configuration');
drawnow;
disp(' ')
fprintf([...
'The figure shows the starting point for the GTM, before the training.\n', ...
'A discrete latent variable distribution of %d points in 1 dimension \n', ...
'is mapped to the 1st principal component of the target data by an RBF\n', ...
'with %d basis functions.  Each of the %d points defines the centre of\n', ...
'a Gaussian in a Gaussian mixture, marked by the green ''+''-signs.  The\n', ...
'mixture components all have equal variance, illustrated by the filled\n', ...
'circle around each ''+''-sign, the radii corresponding to 2 standard\n', ...
'deviations.  The ''+''-signs are connected with a line according to their\n', ...
'corresponding ordering in latent space.\n\n', ...
'Press any key to begin training ...\n\n'], num_latent_points, ...
num_rbf_centres, num_latent_points);
pause;

figure(fh1);
%%%% Train the GTM and plot it (along with the data) as training proceeds %%%%
options = foptions;
options(1) = -1;  % Turn off all warning messages
% NOTE(review): presumably limits each gtmem call to a single EM
% iteration (Netlab options convention), which is what the per-pass
% "After j iterations" title assumes - confirm against gtmem.
options(14) = 1;
for j = 1:num_iterations
  [net, options] = gtmem(net, T, options);
  hold off;
  mix = gtmfwd(net);
  plot(mix.centres(:,1), mix.centres(:,2), 'g');
  hold on;
  for i = 1:num_latent_points
    % Same 2-standard-deviation circle as in the initial plot.
    c = 2*unitC*sqrt(mix.covars(1)) + ...
        repmat(mix.centres(i,:), num_circle_points, 1);
    fill(c(:,1), c(:,2), [0.8 1.0 0.8]);
  end
  plot(T(:,1), T(:,2), 'ro');
  plot(mix.centres(:,1), mix.centres(:,2), 'g+');
  plot(mix.centres(:,1), mix.centres(:,2), 'g');
  % Training frames use a fixed view that is slightly wider than the
  % data-derived limits of the initial plot.
  axis([0 3.5 0 3.5]);
  title(['After ', int2str(j),' iterations of training.']);
  drawnow;
  if (j == 4)
    fprintf([...
'The GTM initially adapts relatively quickly - already after \n', ...
'4 iterations of training, a rough fit is attained.\n\n', ...
'Press any key to continue training ...\n\n']);
    pause;
    figure(fh1);
  elseif (j == 8)
    fprintf([...
'After another 4 iterations of training:  from now on further \n', ...
'training only makes small changes to the mapping, which combined with \n', ...
'decrements of the Gaussian mixture variance, optimize the fit in \n', ...
'terms of likelihood.\n\n', ...
'Press any key to continue training ...\n\n']);
    pause;
    figure(fh1);
  else
    % Short delay so intermediate frames are visible without a key press.
    pause(1);
  end
end

clc;
fprintf([...
'After %d iterations of training the GTM can be regarded as converged. \n', ...
'It has been adapted to fit the target data distribution as well \n', ...
'as possible, given prior smoothness constraints on the mapping. It \n', ...
'captures the fact that the probability density is higher at the two \n', ...
'bends of the curve, and lower towards its end points.\n\n'], ...
num_iterations);
disp(' ');
disp('Press any key to exit.');
pause;

close(fh1);
clear all;