%DEMGTM1 Demonstrate EM for GTM.
%
%   Description
%   This script demonstrates the use of the EM algorithm to fit a
%   one-dimensional GTM to a two-dimensional set of data using maximum
%   likelihood. The location and spread of the Gaussian kernels in the
%   data space are shown during training.
%
%   See also
%   DEMGTM2, GTM, GTMEM, GTMPOST
%

%   Copyright (c) Ian T Nabney (1996-2001)

% Demonstrates the GTM with a 2D target space and a 1D latent space.
%
%   This script generates a simple data set in 2 dimensions,
%   with an intrinsic dimensionality of 1, and trains a GTM
%   with a 1-dimensional latent variable to model this data
%   set, visually illustrating the training process.
%
%   Synopsis: demgtm1

% Generate and plot a 2D data set

data_min = 0.15;
data_max = 3.05;
T = (data_min:0.05:data_max)';
T = [T (T + 1.25*sin(2*T))];
fh1 = figure;
plot(T(:,1), T(:,2), 'ro');
axis([data_min-0.05 data_max+0.05 data_min-0.05 data_max+0.05]);
clc;
disp('This demonstration shows in detail how the EM algorithm works')
disp('for training a GTM with a one-dimensional latent space.')
disp(' ')
fprintf([...
'The figure shows data generated by feeding a 1D uniform distribution\n', ...
'(on the X-axis) through a non-linear function (y = x + 1.25*sin(2*x)).\n', ...
'\nPress any key to continue ...\n\n']);
pause;

% Generate 20 points on the unit circle, used to draw a circle around
% each mixture centre
src = linspace(0, 2*pi, 20)';
unitC = [sin(src) cos(src)];

% Generate and plot (along with the data) an initial GTM model

clc;
num_latent_points = 20;
num_rbf_centres = 5;

net = gtm(1, num_latent_points, 2, num_rbf_centres, 'gaussian');

options = zeros(1, 18);
options(7) = 1;
net = gtminit(net, options, T, 'regular', num_latent_points, ...
  num_rbf_centres);

mix = gtmfwd(net);
% Replot the figure
hold off;
plot(mix.centres(:,1), mix.centres(:,2), 'g');
hold on;
for i = 1:num_latent_points
  % Circle of radius 2 standard deviations around the i-th centre
  c = 2*unitC*sqrt(mix.covars(1)) + ...
    repmat(mix.centres(i,:), size(unitC,1), 1);
  fill(c(:,1), c(:,2), [0.8 1 0.8]);
end
plot(T(:,1), T(:,2), 'ro');
plot(mix.centres(:,1), mix.centres(:,2), 'g+');
plot(mix.centres(:,1), mix.centres(:,2), 'g');
axis([data_min-0.05 data_max+0.05 data_min-0.05 data_max+0.05]);
title('Initial configuration');
drawnow;
disp(' ')
fprintf([...
'The figure shows the starting point for the GTM, before training.\n', ...
'A discrete latent variable distribution of %d points in 1 dimension\n', ...
'is mapped onto the first principal component of the target data by an\n', ...
'RBF network with %d basis functions. Each of the %d points defines the\n', ...
'centre of a Gaussian in a Gaussian mixture, marked by the green ''+''-signs.\n', ...
'The mixture components all have equal variance, illustrated by the filled\n', ...
'circle around each ''+''-sign, the radii corresponding to 2 standard\n', ...
'deviations. The ''+''-signs are connected by a line according to their\n', ...
'ordering in latent space.\n\n', ...
'Press any key to begin training ...\n\n'], num_latent_points, ...
  num_rbf_centres, num_latent_points);
pause;

figure(fh1);
%%%% Train the GTM and plot it (along with the data) as training proceeds %%%%
options = foptions;
options(1) = -1;   % Turn off all warning messages
options(14) = 1;   % One EM cycle per call to GTMEM, so the plot can be
                   % updated after every iteration
for j = 1:15
  [net, options] = gtmem(net, T, options);
  hold off;
  mix = gtmfwd(net);
  plot(mix.centres(:,1), mix.centres(:,2), 'g');
  hold on;
  for i = 1:num_latent_points
    % Circle of radius 2 standard deviations around the i-th centre
    c = 2*unitC*sqrt(mix.covars(1)) + ...
      repmat(mix.centres(i,:), size(unitC,1), 1);
    fill(c(:,1), c(:,2), [0.8 1.0 0.8]);
  end
  plot(T(:,1), T(:,2), 'ro');
  plot(mix.centres(:,1), mix.centres(:,2), 'g+');
  plot(mix.centres(:,1), mix.centres(:,2), 'g');
  axis([0 3.5 0 3.5]);
  if (j == 1)
    title('After 1 iteration of training.');
  else
    title(['After ', int2str(j), ' iterations of training.']);
  end
  drawnow;
  if (j == 4)
    fprintf([...
'The GTM initially adapts quickly: after only 4 iterations of\n', ...
'training, a rough fit is attained.\n\n', ...
'Press any key to continue training ...\n\n']);
    pause;
    figure(fh1);
  elseif (j == 8)
    fprintf([...
'After another 4 iterations of training: from now on, further\n', ...
'training only makes small changes to the mapping which, together\n', ...
'with reductions in the variance of the Gaussian mixture, improve\n', ...
'the fit in terms of likelihood.\n\n', ...
'Press any key to continue training ...\n\n']);
    pause;
    figure(fh1);
  else
    pause(1);
  end
end

clc;
fprintf([...
'After 15 iterations of training the GTM can be regarded as converged.\n', ...
'It has been adapted to fit the target data distribution as well\n', ...
'as possible, given prior smoothness constraints on the mapping. It\n', ...
'captures the fact that the probability density is higher at the two\n', ...
'bends of the curve, and lower towards its end points.\n\n']);
disp(' ');
disp('Press any key to exit.');
pause;

close(fh1);
clear all;
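
% Notes on the EM updates behind the training loop above.
%
% A hedged summary of the standard GTM EM updates, in an assumed
% notation: K latent points, N data points of dimension D, and M basis
% functions. The symbols Phi, W, R, G and alpha below are illustrative
% names only, not necessarily the variable names used inside GTMEM,
% which may also differ in detail (e.g. in how weight decay is applied).
%
%   Y = Phi*W    K-by-D matrix of mixture centres (mix.centres), where
%                Phi is the K-by-M matrix of basis function activations
%                and W is the M-by-D weight matrix
%   1/beta       common variance of the mixture components (mix.covars)
%
% E-step: responsibility of latent point k for data point n,
%   R(k,n) = exp(-beta/2*||Y(k,:) - T(n,:)||^2) / ...
%            sum_k' exp(-beta/2*||Y(k',:) - T(n,:)||^2)
%
% M-step: with G = diag(sum(R, 2)) and a weight-decay coefficient alpha,
%   (Phi'*G*Phi + (alpha/beta)*eye(M)) * W = Phi'*R*T
%   1/beta = sum over k,n of R(k,n)*||Y(k,:) - T(n,:)||^2 / (N*D),
%            with Y recomputed from the new W
%
% As the centres move closer to the data, the variance 1/beta decreases,
% which is why the circles of radius 2*sqrt(mix.covars(1)) drawn in the
% loop get smaller as training proceeds.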