%DEMGMM1 Demonstrate EM for Gaussian mixtures.
%
% Description
% This script demonstrates the use of the EM algorithm to fit a mixture
% of Gaussians to a set of data using maximum likelihood. A colour
% coding scheme is used to illustrate the evaluation of the posterior
% probabilities in the E-step of the EM algorithm.
%
% This version also records the figure into an AVI movie
% ('movies/gmm1.avi', 1 frame per second): two frames of the initial
% configuration, two frames after the first M-step, and one frame after
% each of the subsequent EM cycles.
%
% See also
% DEMGMM2, DEMGMM3, DEMGMM4, GMM, GMMEM, GMMPOST
%

% Copyright (c) Ian T Nabney (1996-2001)

% NOTE(review): AVIFILE/ADDFRAME are obsolete (superseded by VideoWriter)
% and may be unavailable in recent MATLAB releases -- confirm the target
% MATLAB version before running.
% AVIFILE does not create missing directories, so ensure 'movies' exists.
if ~exist('movies', 'dir')
  mkdir('movies');
end
mov = avifile('movies/gmm1.avi','fps',1 );

try
  % Generate the data: ndata points sampled from a known two-component
  % spherical mixture. Fixed seeds make the demo reproducible.
  randn('state', 0); rand('state', 0);
  gmix = gmm(2, 2, 'spherical');
  ndat1 = 20; ndat2 = 20; ndata = ndat1+ndat2;
  gmix.centres = [0.3 0.3; 0.7 0.7];
  gmix.covars = [0.01 0.01];
  x = gmmsamp(gmix, ndata);

  h = figure;
  hd = plot(x(:, 1), x(:, 2), '.g', 'markersize', 30);
  hold on; axis([0 1 0 1]); axis square; set(gca, 'box', 'on');
  ht = text(0.5, 1.05, 'Data', 'horizontalalignment', 'center');

  % Set up mixture model
  ncentres = 2; input_dim = 2;
  mix = gmm(input_dim, ncentres, 'spherical');

  % Initialise the mixture model with centres placed away from the
  % centres used to generate the data.
  mix.centres = [0.2 0.8; 0.8, 0.2];
  mix.covars = [0.01 0.01];

  % Plot the initial model: each component is drawn as a circle of radius
  % sqrt(covar) about its centre (component 1 red, component 2 blue).
  ncirc = 30; theta = linspace(0, 2*pi, ncirc);
  xs = cos(theta); ys = sin(theta);
  xvals = mix.centres(:, 1)*ones(1,ncirc) + sqrt(mix.covars')*xs;
  yvals = mix.centres(:, 2)*ones(1,ncirc) + sqrt(mix.covars')*ys;
  hc(1)=line(xvals(1,:), yvals(1,:), 'color', 'r');
  hc(2)=line(xvals(2,:), yvals(2,:), 'color', 'b');
  set(ht, 'string', 'Initial Configuration');
  figure(h);
  % Two identical frames hold this view for two seconds at 1 fps.
  mov = addframe(mov, getframe(gcf));
  mov = addframe(mov, getframe(gcf));

  % Initial E-step: recolour every point by its posterior probabilities
  % (red channel = component 1, blue channel = component 2).
  set(ht, 'string', 'E-step');
  post = gmmpost(mix, x);
  dcols = [post(:,1), zeros(ndata, 1), post(:,2)];
  delete(hd);
  for i = 1 : ndata
    hd(i) = plot(x(i, 1), x(i, 2), 'color', dcols(i,:), ...
      'marker', '.', 'markersize', 30);
  end

  % M-step.
  set(ht, 'string', 'M-step');
  options = foptions;
  options(14) = 1;		% A single iteration
  options(1) = -1;		% Switch off all messages, including warning
  mix = gmmem(mix, x, options);
  delete(hc);
  xvals = mix.centres(:, 1)*ones(1,ncirc) + sqrt(mix.covars')*xs;
  yvals = mix.centres(:, 2)*ones(1,ncirc) + sqrt(mix.covars')*ys;
  hc(1)=line(xvals(1,:), yvals(1,:), 'color', 'r');
  hc(2)=line(xvals(2,:), yvals(2,:), 'color', 'b');
  figure(h);
  mov = addframe(mov, getframe(gcf));
  mov = addframe(mov, getframe(gcf));

  % Loop over EM iterations: one E-step (recolour points) and one
  % single-iteration M-step (redraw circles) per cycle, one frame each.
  numiters = 9;
  for n = 1 : numiters

    set(ht, 'string', 'E-step');
    post = gmmpost(mix, x);
    dcols = [post(:,1), zeros(ndata, 1), post(:,2)];
    delete(hd);
    for i = 1 : ndata
      hd(i) = plot(x(i, 1), x(i, 2), 'color', dcols(i,:), ...
        'marker', '.', 'markersize', 30);
    end

    set(ht, 'string', 'M-step');
    [mix, options] = gmmem(mix, x, options);
    % options(8) holds the error value reported by GMMEM.
    fprintf(1, 'Cycle %4d  Error %11.6f\n', n, options(8));
    delete(hc);
    xvals = mix.centres(:, 1)*ones(1,ncirc) + sqrt(mix.covars')*xs;
    yvals = mix.centres(:, 2)*ones(1,ncirc) + sqrt(mix.covars')*ys;
    hc(1)=line(xvals(1,:), yvals(1,:), 'color', 'r');
    hc(2)=line(xvals(2,:), yvals(2,:), 'color', 'b');
    pause(1)

    % PAUSE lets the user activate another figure; re-select h so the
    % captured frame (and the next iteration's drawing) uses the demo figure.
    figure(h);
    mov = addframe(mov, getframe(gcf));
  end
catch err
  % Close the AVI file so it is not left open/corrupt when a step fails.
  mov = close(mov);
  rethrow(err);
end

mov = close(mov);