wolffd@0
|
1 %DEMGMM1 Demonstrate EM for Gaussian mixtures.
|
wolffd@0
|
2 %
|
wolffd@0
|
3 % Description
|
wolffd@0
|
4 % This script demonstrates the use of the EM algorithm to fit a mixture
|
wolffd@0
|
5 % of Gaussians to a set of data using maximum likelihood. A colour
|
wolffd@0
|
6 % coding scheme is used to illustrate the evaluation of the posterior
|
wolffd@0
|
7 % probabilities in the E-step of the EM algorithm.
|
wolffd@0
|
8 %
|
wolffd@0
|
9 % See also
|
wolffd@0
|
10 % DEMGMM2, DEMGMM3, DEMGMM4, GMM, GMMEM, GMMPOST
|
wolffd@0
|
11 %
|
wolffd@0
|
12
|
wolffd@0
|
13 % Copyright (c) Ian T Nabney (1996-2001)
|
wolffd@0
|
14
|
% Introduce the demonstration, then wait for a key before plotting the data.
clc;
fprintf('%s\n', ...
  'This demonstration illustrates the use of the EM (expectation-', ...
  'maximization) algorithm for fitting of a mixture of Gaussians to a', ...
  'data set by maximum likelihood.', ...
  ' ', ...
  'The data set consists of 40 data points in a 2-dimensional', ...
  'space, generated by sampling from a mixture of 2 Gaussian', ...
  'distributions.', ...
  ' ', ...
  'Press any key to see a plot of the data.');
pause;
% Generate the data set: ndata points sampled from a known two-component
% spherical Gaussian mixture in 2-D (centres (0.3,0.3) and (0.7,0.7),
% variance 0.01). The legacy generators are seeded first so the sampled
% data are reproducible from run to run.
randn('state', 0);
rand('state', 0);
gmix = gmm(2, 2, 'spherical');
ndat1 = 20;
ndat2 = 20;
ndata = ndat1 + ndat2;
gmix.centres = [0.3, 0.3; 0.7, 0.7];
gmix.covars = [0.01, 0.01];
x = gmmsamp(gmix, ndata);

% Show the raw data as large green dots on the unit square; ht is the
% title-like text handle that later steps relabel.
h = figure;
hd = plot(x(:, 1), x(:, 2), '.g', 'markersize', 30);
hold('on');
axis([0, 1, 0, 1]);
axis('square');
box('on');
ht = text(0.5, 1.05, 'Data', 'horizontalalignment', 'center');
fprintf('%s\n', ' ', 'Press any key to continue.');
pause;
clc;
% Explain the mixture model that is about to be created, how its
% components are drawn, and that its initialisation is deliberately poor.

disp('We next create and initialize a mixture model consisting of a mixture')
disp('of 2 Gaussians having ''spherical'' covariance matrices, using the')
disp('function GMM. The Gaussian components can be displayed on the same')
disp('plot as the data by drawing a contour of constant probability density')
disp('for each component having radius equal to the corresponding standard')
disp('deviation. Component 1 is coloured red and component 2 is coloured')
disp('blue.')
disp(' ')
disp('Note that a particularly poor choice of initial parameters has been')
disp('made in order to illustrate more effectively the operation of the')
disp('EM algorithm.')
disp(' ')
disp('Press any key to see the initial configuration of the mixture model.')
pause;
% Set up mixture model: ncentres Gaussian components with spherical
% covariances in input_dim dimensions.
ncentres = 2; input_dim = 2;
mix = gmm(input_dim, ncentres, 'spherical');

% Initialise the mixture model. The centres are deliberately placed at a
% poor starting point so that the EM iterations are easy to follow.
mix.centres = [0.2 0.8; 0.8, 0.2];
mix.covars = [0.01 0.01];

% Plot the initial model: each component is drawn as a circle of radius
% sqrt(covars(j)) (one standard deviation) about centre j. Row j of
% xvals/yvals holds the ncirc points of circle j.
ncirc = 30; theta = linspace(0, 2*pi, ncirc);
xs = cos(theta); ys = sin(theta);
xvals = mix.centres(:, 1)*ones(1,ncirc) + sqrt(mix.covars')*xs;
yvals = mix.centres(:, 2)*ones(1,ncirc) + sqrt(mix.covars')*ys;
% Make the demo figure current BEFORE drawing, so the circles go into it
% even if another figure was selected while the demo was paused.
figure(h);
hc(1)=line(xvals(1,:), yvals(1,:), 'color', 'r');
hc(2)=line(xvals(2,:), yvals(2,:), 'color', 'b');
set(ht, 'string', 'Initial Configuration');
disp(' ')
disp('Press any key to continue');
pause; clc;
% Describe the E-step and the red/blue colour coding of responsibilities,
% then wait for a key before performing the first E-step.
fprintf('%s\n', ...
  'Now we adapt the parameters of the mixture model iteratively using the', ...
  'EM algorithm. Each cycle of the EM algorithm consists of an E-step', ...
  'followed by an M-step. We start with the E-step, which involves the', ...
  'evaluation of the posterior probabilities (responsibilities) which the', ...
  'two components have for each of the data points.', ...
  ' ', ...
  'Since we have labelled the two components using the colours red and', ...
  'blue, a convenient way to indicate the value of a posterior', ...
  'probability for a given data point is to colour the point using a', ...
  'scale ranging from pure red (corresponding to a posterior probability', ...
  'of 1.0 for the red component and 0.0 for the blue component) through', ...
  'to pure blue.', ...
  ' ', ...
  'Press any key to see the result of applying the first E-step.');
pause;
% Initial E-step: evaluate the posterior probabilities of each component
% for every data point under the current model, then recolour the points.
set(ht, 'string', 'E-step');
post = gmmpost(mix, x);
% RGB colour per point: red channel = posterior of component 1, blue
% channel = posterior of component 2, green channel always zero.
dcols = [post(:,1), zeros(ndata, 1), post(:,2)];
% Make the demo figure current BEFORE drawing, so the recoloured points go
% into it even if another figure was selected while the demo was paused.
figure(h);
delete(hd);
for i = 1 : ndata
  hd(i) = plot(x(i, 1), x(i, 2), 'color', dcols(i,:), ...
    'marker', '.', 'markersize', 30);
end

disp(' ');
disp('Press any key to continue')
pause; clc;
% Describe the M-step (responsibility-weighted re-estimation of the
% parameters), then wait for a key before performing it.
fprintf('%s\n', ...
  'Next we perform the corresponding M-step. This involves replacing the', ...
  'centres of the component Gaussians by the corresponding weighted means', ...
  'of the data. Thus the centre of the red component is replaced by the', ...
  'mean of the data set, in which each data point is weighted according to', ...
  'the amount of red ink (corresponding to the responsibility of', ...
  'component 1 for explaining that data point). The variances and mixing', ...
  'proportions of the two components are similarly re-estimated.', ...
  ' ', ...
  'Press any key to see the result of applying the first M-step.');
pause;
% M-step: run GMMEM for a single iteration with all messages switched off,
% then redraw each component as a one-standard-deviation circle using the
% updated centres and variances.
set(ht, 'string', 'M-step');
options = foptions;
options(14) = 1;  % A single iteration
options(1) = -1;  % Switch off all messages, including warning
mix = gmmem(mix, x, options);
% Make the demo figure current BEFORE drawing, so the new circles go into
% it even if another figure was selected while the demo was paused.
figure(h);
delete(hc);
xvals = mix.centres(:, 1)*ones(1,ncirc) + sqrt(mix.covars')*xs;
yvals = mix.centres(:, 2)*ones(1,ncirc) + sqrt(mix.covars')*ys;
hc(1)=line(xvals(1,:), yvals(1,:), 'color', 'r');
hc(2)=line(xvals(2,:), yvals(2,:), 'color', 'b');
disp(' ')
disp('Press any key to continue')
pause; clc;
% Announce the animated EM cycles, wait for a key, then bring the demo
% figure to the front so the animation is visible.
fprintf('%s\n', ...
  'We can continue making alternate E and M steps until the changes in', ...
  'the log likelihood at each cycle become sufficiently small.', ...
  ' ', ...
  'Press any key to see an animation of a further 9 EM cycles.');
pause;
figure(h);
% Loop over EM iterations. Each cycle recolours the data by their current
% posteriors (E-step), then runs one GMMEM iteration with the options set
% up for the first M-step and redraws the component circles (M-step).
% The error value GMMEM returns in options(8) is printed for each cycle.
numiters = 9;
for n = 1 : numiters

  % E-step: red channel = posterior of component 1, blue = component 2.
  set(ht, 'string', 'E-step');
  post = gmmpost(mix, x);
  dcols = [post(:,1), zeros(ndata, 1), post(:,2)];
  % Target the demo figure explicitly (without raising it every second),
  % in case another figure became current during the previous pause.
  set(0, 'CurrentFigure', h);
  delete(hd);
  for i = 1 : ndata
    hd(i) = plot(x(i, 1), x(i, 2), 'color', dcols(i,:), ...
      'marker', '.', 'markersize', 30);
  end
  pause(1)

  % M-step: one EM iteration, report the error, redraw the circles.
  set(ht, 'string', 'M-step');
  [mix, options] = gmmem(mix, x, options);
  fprintf(1, 'Cycle %4d Error %11.6f\n', n, options(8));
  set(0, 'CurrentFigure', h);
  delete(hc);
  xvals = mix.centres(:, 1)*ones(1,ncirc) + sqrt(mix.covars')*xs;
  yvals = mix.centres(:, 2)*ones(1,ncirc) + sqrt(mix.covars')*ys;
  hc(1)=line(xvals(1,:), yvals(1,:), 'color', 'r');
  hc(2)=line(xvals(2,:), yvals(2,:), 'color', 'b');
  pause(1)

end
% Finish: wait for a key, close the demo figure and remove the variables
% this script created. Only the demo's own variables are cleared (instead
% of CLEAR ALL), so the caller's other workspace variables, globals and
% loaded functions are left untouched.

disp(' ');
disp('Press any key to end.')
pause; clc;
% The user may already have closed the figure by hand; close(h) on a
% deleted handle would raise an error.
if ishandle(h)
  close(h);
end
clear('gmix', 'ndat1', 'ndat2', 'ndata', 'x', 'h', 'hd', 'ht', ...
  'ncentres', 'input_dim', 'mix', 'ncirc', 'theta', 'xs', 'ys', ...
  'xvals', 'yvals', 'hc', 'post', 'dcols', 'i', 'options', ...
  'numiters', 'n');