%DEMGMM1 Demonstrate EM for Gaussian mixtures.
%
% Description
% This script demonstrates the use of the EM algorithm to fit a mixture
% of Gaussians to a set of data using maximum likelihood. A colour
% coding scheme is used to illustrate the evaluation of the posterior
% probabilities in the E-step of the EM algorithm. The successive model
% configurations are also written as frames to the AVI movie
% movies/gmm1.avi.
%
% See also
% DEMGMM2, DEMGMM3, DEMGMM4, GMM, GMMEM, GMMPOST
%

% Copyright (c) Ian T Nabney (1996-2001)

% Open the AVI movie that records each stage of the demo at 1 frame/s.
% The output directory is not created automatically, and opening a file
% inside a missing directory fails, so create it first if necessary.
% NOTE(review): avifile is not available in recent MATLAB releases
% (VideoWriter replaces it); porting means changing every addframe/close
% call in this script as well.
if ~exist('movies', 'dir')
  mkdir('movies');
end
mov = avifile('movies/gmm1.avi','fps',1 );

% Draw a reproducible synthetic data set from a known two-component,
% two-dimensional spherical Gaussian mixture (both generators seeded).
randn('state', 0);
rand('state', 0);

ndat1 = 20;
ndat2 = 20;
ndata = ndat1 + ndat2;

gmix = gmm(2, 2, 'spherical');
gmix.covars = [0.01 0.01];
gmix.centres = [0.3 0.3; 0.7 0.7];
x = gmmsamp(gmix, ndata);

% Show the raw samples as large green dots on the unit square, with a
% title text object (ht) that later stages relabel.
h = figure;
hd = plot(x(:, 1), x(:, 2), '.g', 'markersize', 30);
hold('on');
axis([0 1 0 1]);
axis('square');
box('on');
ht = text(0.5, 1.05, 'Data', 'horizontalalignment', 'center');

% Build the model to be fitted: two spherical components in two
% dimensions, started at centres different from the generating ones.
input_dim = 2;
ncentres = 2;
mix = gmm(input_dim, ncentres, 'spherical');
mix.centres = [0.2 0.8; 0.8 0.2];
mix.covars = [0.01 0.01];

% Draw each component as a circle of radius sqrt(covariance) about its
% centre: red for component 1, blue for component 2.
ncirc = 30;
theta = linspace(0, 2*pi, ncirc);
xs = cos(theta);
ys = sin(theta);
xvals = repmat(mix.centres(:, 1), 1, ncirc) + sqrt(mix.covars(:))*xs;
yvals = repmat(mix.centres(:, 2), 1, ncirc) + sqrt(mix.covars(:))*ys;
hc(1) = line(xvals(1, :), yvals(1, :), 'color', 'r');
hc(2) = line(xvals(2, :), yvals(2, :), 'color', 'b');
set(ht, 'string', 'Initial Configuration');

% Raise the figure and record this state as two consecutive movie frames.
figure(h);
mov = addframe(mov, getframe(gcf));
mov = addframe(mov, getframe(gcf));

% Initial E-step: compute posterior responsibilities and recolour each
% point as an RGB mix of red (component 1) and blue (component 2); the
% green channel stays zero.
set(ht, 'string', 'E-step');
post = gmmpost(mix, x);
dcols = zeros(ndata, 3);
dcols(:, 1) = post(:, 1);
dcols(:, 3) = post(:, 2);

% Replace the existing markers with one individually coloured marker per point.
delete(hd);
for i = 1:ndata
  hd(i) = plot(x(i, 1), x(i, 2), 'color', dcols(i, :), 'marker', '.', ...
               'markersize', 30);
end

% First M-step: configure gmmem for exactly one silent EM cycle, then
% redraw the component circles at the re-estimated parameters.
set(ht, 'string', 'M-step');
options = foptions;
options(1) = -1;   % switch off all messages, warnings included
options(14) = 1;   % a single iteration per gmmem call
mix = gmmem(mix, x, options);

delete(hc);
xvals = repmat(mix.centres(:, 1), 1, ncirc) + sqrt(mix.covars(:))*xs;
yvals = repmat(mix.centres(:, 2), 1, ncirc) + sqrt(mix.covars(:))*ys;
hc(1) = line(xvals(1, :), yvals(1, :), 'color', 'r');
hc(2) = line(xvals(2, :), yvals(2, :), 'color', 'b');

% Raise the figure and record the updated model as two movie frames.
figure(h);
mov = addframe(mov, getframe(gcf));
mov = addframe(mov, getframe(gcf));

% Loop over EM iterations. Each cycle shows the E-step (points recoloured
% by posterior responsibility) and then the M-step (one gmmem iteration,
% circles redrawn), printing the error value gmmem returns in options(8).
% Only the post-M-step state of each cycle is written to the movie.
numiters = 9;
for n = 1 : numiters

  set(ht, 'string', 'E-step');
  post = gmmpost(mix, x);
  dcols = [post(:,1), zeros(ndata, 1), post(:,2)];
  delete(hd);
  for i = 1 : ndata
    hd(i) = plot(x(i, 1), x(i, 2), 'color', dcols(i,:), ...
      'marker', '.', 'markersize', 30);
  end
  % (An on-screen pause after the E-step was disabled here.)

  set(ht, 'string', 'M-step');
  [mix, options] = gmmem(mix, x, options);
  fprintf(1, 'Cycle %4d Error %11.6f\n', n, options(8));
  delete(hc);
  xvals = mix.centres(:, 1)*ones(1,ncirc) + sqrt(mix.covars')*xs;
  yvals = mix.centres(:, 2)*ones(1,ncirc) + sqrt(mix.covars')*ys;
  hc(1)=line(xvals(1,:), yvals(1,:), 'color', 'r');
  hc(2)=line(xvals(2,:), yvals(2,:), 'color', 'b');
  pause(1)

  % During the pause the user may focus another figure, so gcf is not
  % guaranteed to be h here. Raise h and capture it by handle, as the
  % frames recorded before the loop do.
  figure(h);
  mov = addframe(mov, getframe(h));
end

% Close the AVI file opened with avifile at the start of the script.
mov = close(mov);
|