comparison toolboxes/FullBNT-1.0.7/netlab3.3/demgmm1.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
%DEMGMM1 Demonstrate EM for Gaussian mixtures.
%
% Description
% This script demonstrates the use of the EM algorithm to fit a mixture
% of Gaussians to a set of data using maximum likelihood. A colour
% coding scheme is used to illustrate the evaluation of the posterior
% probabilities in the E-step of the EM algorithm.
%
% See also
% DEMGMM2, DEMGMM3, DEMGMM4, GMM, GMMEM, GMMPOST
%

% Copyright (c) Ian T Nabney (1996-2001)

clc;
disp('This demonstration illustrates the use of the EM (expectation-')
disp('maximization) algorithm for fitting a mixture of Gaussians to a')
disp('data set by maximum likelihood.')
disp(' ')
disp('The data set consists of 40 data points in a 2-dimensional')
disp('space, generated by sampling from a mixture of 2 Gaussian')
disp('distributions.')
disp(' ')
disp('Press any key to see a plot of the data.')
pause;

% Generate the data
randn('state', 0); rand('state', 0);
gmix = gmm(2, 2, 'spherical');
ndat1 = 20; ndat2 = 20; ndata = ndat1+ndat2;
gmix.centres = [0.3 0.3; 0.7 0.7];
gmix.covars = [0.01 0.01];
x = gmmsamp(gmix, ndata);

h = figure;
hd = plot(x(:, 1), x(:, 2), '.g', 'markersize', 30);
hold on; axis([0 1 0 1]); axis square; set(gca, 'box', 'on');
ht = text(0.5, 1.05, 'Data', 'horizontalalignment', 'center');
disp(' ');
disp('Press any key to continue.')
pause; clc;

disp('We next create and initialize a mixture model consisting of a mixture')
disp('of 2 Gaussians having ''spherical'' covariance matrices, using the')
disp('function GMM. The Gaussian components can be displayed on the same')
disp('plot as the data by drawing a contour of constant probability density')
disp('for each component having radius equal to the corresponding standard')
disp('deviation. Component 1 is coloured red and component 2 is coloured')
disp('blue.')
disp(' ')
disp('Note that a particularly poor choice of initial parameters has been')
disp('made in order to illustrate more effectively the operation of the')
disp('EM algorithm.')
disp(' ')
disp('Press any key to see the initial configuration of the mixture model.')
pause;

% Set up mixture model
ncentres = 2; input_dim = 2;
mix = gmm(input_dim, ncentres, 'spherical');

% Initialise the mixture model (a deliberately poor starting point).
% For spherical components, mix.covars holds one variance per component.
mix.centres = [0.2 0.8; 0.8 0.2];
mix.covars = [0.01 0.01];

% Plot the initial model: a circle of radius sqrt(variance), i.e. one
% standard deviation, around each component centre
ncirc = 30; theta = linspace(0, 2*pi, ncirc);
xs = cos(theta); ys = sin(theta);
xvals = mix.centres(:, 1)*ones(1,ncirc) + sqrt(mix.covars')*xs;
yvals = mix.centres(:, 2)*ones(1,ncirc) + sqrt(mix.covars')*ys;
hc(1)=line(xvals(1,:), yvals(1,:), 'color', 'r');
hc(2)=line(xvals(2,:), yvals(2,:), 'color', 'b');
set(ht, 'string', 'Initial Configuration');
figure(h);
disp(' ')
disp('Press any key to continue.');
pause; clc;

disp('Now we adapt the parameters of the mixture model iteratively using the')
disp('EM algorithm. Each cycle of the EM algorithm consists of an E-step')
disp('followed by an M-step. We start with the E-step, which involves the')
disp('evaluation of the posterior probabilities (responsibilities) which the')
disp('two components have for each of the data points.')
disp(' ')
disp('Since we have labelled the two components using the colours red and')
disp('blue, a convenient way to indicate the value of a posterior')
disp('probability for a given data point is to colour the point using a')
disp('scale ranging from pure red (corresponding to a posterior probability')
disp('of 1.0 for the red component and 0.0 for the blue component) through')
disp('to pure blue.')
disp(' ')
disp('Press any key to see the result of applying the first E-step.')
pause;

% Initial E-step.
set(ht, 'string', 'E-step');
post = gmmpost(mix, x);
% Colour each point by its responsibilities: the red channel is the
% posterior of component 1 and the blue channel that of component 2
dcols = [post(:,1), zeros(ndata, 1), post(:,2)];
delete(hd);
for i = 1 : ndata
  hd(i) = plot(x(i, 1), x(i, 2), 'color', dcols(i,:), ...
    'marker', '.', 'markersize', 30);
end
figure(h);
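
% Illustrative check (a sketch, not part of the original NETLAB demo):
% for spherical components the responsibilities can be written out as
%   post(n,j) = priors(j)*N(x_n | centres(j,:), covars(j)*I) / sum_k(...)
% where mix.covars(j) is the variance of component j. This assumes GMMPOST
% uses exactly this formula; the names d2, covs, act, num and post_manual
% are hypothetical and are not used elsewhere in the demo.
d2 = zeros(ndata, ncentres);
for j = 1 : ncentres
  diffs = x - ones(ndata, 1)*mix.centres(j, :);
  d2(:, j) = sum(diffs.^2, 2);             % squared distance to centre j
end
covs = ones(ndata, 1)*mix.covars;          % ndata x ncentres variances
act = exp(-d2./(2*covs)) ./ ((2*pi*covs).^(input_dim/2));  % component densities
num = act .* (ones(ndata, 1)*mix.priors);  % prior times density
post_manual = num ./ (sum(num, 2)*ones(1, ncentres));      % normalise each row
fprintf(1, 'Max difference from GMMPOST: %g\n', ...
  max(abs(post_manual(:) - post(:))));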

disp(' ');
disp('Press any key to continue.')
pause; clc;

disp('Next we perform the corresponding M-step. This involves replacing the')
disp('centres of the component Gaussians by the corresponding weighted means')
disp('of the data. Thus the centre of the red component is replaced by the')
disp('mean of the data set, in which each data point is weighted according to')
disp('the amount of red ink (corresponding to the responsibility of')
disp('component 1 for explaining that data point). The variances and mixing')
disp('proportions of the two components are similarly re-estimated.')
disp(' ')
disp('Press any key to see the result of applying the first M-step.')
pause;

% M-step.
set(ht, 'string', 'M-step');
options = foptions;
options(14) = 1; % A single iteration
options(1) = -1; % Switch off all messages, including warnings
mix = gmmem(mix, x, options);
% Redraw the one-standard-deviation circles for the updated components
delete(hc);
xvals = mix.centres(:, 1)*ones(1,ncirc) + sqrt(mix.covars')*xs;
yvals = mix.centres(:, 2)*ones(1,ncirc) + sqrt(mix.covars')*ys;
hc(1)=line(xvals(1,:), yvals(1,:), 'color', 'r');
hc(2)=line(xvals(2,:), yvals(2,:), 'color', 'b');
figure(h);
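
% Illustrative check (a sketch, not part of the original NETLAB demo): the
% standard M-step updates for spherical components, computed from the
% responsibilities POST of the preceding E-step (GMMEM's single iteration
% starts from the same initial model, so it sees the same posteriors).
% Assumes GMMEM applies these textbook updates with no covariance reset
% (options(5) left at 0); nk, newpriors, newcentres and newcovars are
% hypothetical names.
nk = sum(post, 1);                                       % effective counts
newpriors = nk / ndata;                                  % mixing proportions
newcentres = (post' * x) ./ (nk' * ones(1, input_dim));  % weighted means
newcovars = zeros(1, ncentres);
for j = 1 : ncentres
  diffs = x - ones(ndata, 1)*newcentres(j, :);
  % weighted mean squared distance to the new centre, per dimension
  newcovars(j) = (post(:, j)' * sum(diffs.^2, 2)) / (input_dim * nk(j));
end
fprintf(1, 'Max differences from GMMEM: priors %g, centres %g, covars %g\n', ...
  max(abs(newpriors - mix.priors)), ...
  max(max(abs(newcentres - mix.centres))), ...
  max(abs(newcovars - mix.covars)));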
disp(' ')
disp('Press any key to continue.')
pause; clc;

disp('We can continue making alternate E and M steps until the changes in')
disp('the log likelihood at each cycle become sufficiently small.')
disp(' ')
disp('Press any key to see an animation of a further 9 EM cycles.')
pause;
figure(h);

% Loop over EM iterations.
numiters = 9;
for n = 1 : numiters

  set(ht, 'string', 'E-step');
  post = gmmpost(mix, x);
  dcols = [post(:,1), zeros(ndata, 1), post(:,2)];
  delete(hd);
  for i = 1 : ndata
    hd(i) = plot(x(i, 1), x(i, 2), 'color', dcols(i,:), ...
      'marker', '.', 'markersize', 30);
  end
  pause(1)

  set(ht, 'string', 'M-step');
  [mix, options] = gmmem(mix, x, options);
  % options(8) holds the error (negative log likelihood) after the M-step
  fprintf(1, 'Cycle %4d Error %11.6f\n', n, options(8));
  delete(hc);
  xvals = mix.centres(:, 1)*ones(1,ncirc) + sqrt(mix.covars')*xs;
  yvals = mix.centres(:, 2)*ones(1,ncirc) + sqrt(mix.covars')*ys;
  hc(1)=line(xvals(1,:), yvals(1,:), 'color', 'r');
  hc(2)=line(xvals(2,:), yvals(2,:), 'color', 'b');
  pause(1)

end
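
% Illustrative sketch (not part of the original NETLAB demo): instead of a
% fixed number of cycles, keep making single EM cycles until the error
% returned in options(8) changes by less than a tolerance. EM never
% increases this error, so the loop makes steady progress. The names tol,
% maxcycles, mixc, opts and eold are hypothetical choices for this sketch.
tol = 1e-6; maxcycles = 100;
mixc = mix; opts = options;   % work on copies so the demo's model is unchanged
eold = opts(8);
for n = 1 : maxcycles
  [mixc, opts] = gmmem(mixc, x, opts);   % opts(14) = 1: one EM cycle per call
  if abs(eold - opts(8)) < tol
    break;
  end
  eold = opts(8);
end
fprintf(1, 'Stopped after %d further cycles, error %11.6f\n', n, opts(8));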

disp(' ');
disp('Press any key to end.')
pause; clc; close(h); clear all