wolffd@0
|
1 %DEMGMM1 Demonstrate density modelling with a Gaussian mixture model.
|
wolffd@0
|
2 %
|
wolffd@0
|
3 % Description
|
wolffd@0
|
4 % The problem consists of modelling data generated by a mixture of
|
wolffd@0
|
5 % three Gaussians in 2 dimensions. The priors are 0.3, 0.5 and 0.2;
|
wolffd@0
|
6 % the centres are (2, 3.5), (0, 0) and (0,2); the variances are 0.2,
|
wolffd@0
|
7 % 0.5 and 1.0. The first figure contains a scatter plot of the data.
|
wolffd@0
|
8 %
|
wolffd@0
|
9 % A Gaussian mixture model with three components is trained using EM.
|
wolffd@0
|
10 % The parameter vector is printed before training and after training.
|
wolffd@0
|
11 % The user should press any key to continue at these points. The
|
wolffd@0
|
12 % parameter vector consists of priors (the first column), centres (given as
|
wolffd@0
|
13 % (x, y) pairs as the next two columns), and variances (the last
|
wolffd@0
|
14 % column).
|
wolffd@0
|
15 %
|
wolffd@0
|
16 % The second figure is a 3 dimensional view of the density function,
|
wolffd@0
|
17 % while the third shows the 1-standard deviation circles for the three
|
wolffd@0
|
18 % components of the mixture model.
|
wolffd@0
|
19 %
|
wolffd@0
|
20 % See also
|
wolffd@0
|
21 % GMM, GMMINIT, GMMEM, GMMPROB, GMMUNPAK
|
wolffd@0
|
22 %
|
wolffd@0
|
23
|
wolffd@0
|
24 % Copyright (c) Ian T Nabney (1996-2001)
|
wolffd@0
|
25
|
wolffd@0
|
% ---- Data generation ----------------------------------------------------
% Seed both the uniform and the normal generators with a fixed value so
% that every run of the demo works on the same sample.
rand('state', 42);
randn('state', 42);

% Draw the sample; dem2ddat also returns the centres, priors and standard
% deviations of the generating mixture, which are printed later on.
ndata = 500;
[data, datac, datap, datasd] = dem2ddat(ndata);

% Introductory text, shown on a cleared command window.
clc
intro = { ...
    'This demonstration illustrates the use of a Gaussian mixture model', ...
    'to approximate the unconditional probability density of data in', ...
    'a two-dimensional space. We begin by generating the data from', ...
    'a mixture of three Gaussians and plotting it.', ...
    ' ', ...
    'Press any key to continue'};
for k = 1:length(intro)
    disp(intro{k})
end
pause

% Scatter plot of the raw data. The figure handle is kept because the
% window is positioned against and closed later in the script.
fh1 = figure;
plot(data(:, 1), data(:, 2), 'o')
set(gca, 'Box', 'on')
|
wolffd@0
|
% ---- Mixture model set-up -----------------------------------------------
% Three spherical components in a two-dimensional input space.
input_dim = 2;
ncentres = 3;
mix = gmm(input_dim, ncentres, 'spherical');

% Initialise the model parameters from the data, using only 5 iterations
% of k-means (options(14)) for the initialisation.
options = foptions;
options(14) = 5;
mix = gmminit(mix, data, options);

% Show the generating parameters next to the initialised model. Each table
% has one row per component: prior, centre (x, y), variance.
clc
disp('The data is drawn from a mixture with parameters')
disp(' Priors Centres Variances')
true_params = [datap' datac (datasd.^2)'];   % variances = sd squared
disp(true_params)
disp(' ')
disp('The mixture model has three components and spherical covariance')
disp('matrices. The model parameters after initialisation using the')
disp('k-means algorithm are as follows')
disp(' Priors Centres Variances')
init_params = [mix.priors' mix.centres mix.covars'];
disp(init_params)
disp('Press any key to continue')
pause
|
wolffd@0
|
69
|
wolffd@0
|
% ---- EM training --------------------------------------------------------
% Options vector for the EM trainer:
%   options(1)  = 1   print out error values
%   options(14) = 10  maximum number of iterations
options = zeros(1, 18);
options([1 14]) = [1 10];

disp('We now train the model using the EM algorithm for 10 iterations')
disp(' ')
disp('Press any key to continue')
pause
[mix, options, errlog] = gmmem(mix, data, options);

% Print the trained model, followed by the generating parameters so the
% two tables can be compared directly.
disp(' ')
disp('The trained model has parameters ')
disp(' Priors Centres Variances')
trained_params = [mix.priors' mix.centres mix.covars'];
disp(trained_params)
disp('Note the close correspondence between these parameters and those')
disp('of the distribution used to generate the data, which are repeated here.')
disp(' Priors Centres Variances')
true_params = [datap' datac (datasd.^2)'];
disp(true_params)
disp(' ')
disp('Press any key to continue')
pause
|
wolffd@0
|
93
|
wolffd@0
|
% ---- Surface plot of the fitted density ---------------------------------
clc
disp('We now plot the density given by the mixture model as a surface plot')
disp(' ')
disp('Press any key to continue')
pause

% Evaluate the mixture density on a regular grid over [-4, 5] x [-4, 5].
x = -4.0:0.2:5.0;
y = -4.0:0.2:5.0;
[X, Y] = meshgrid(x, y);

% One grid point per row, as passed to gmmprob. The variable is named
% gridpts rather than grid so that it does not shadow the built-in GRID
% function for the rest of the script.
gridpts = [X(:) Y(:)];
Z = gmmprob(mix, gridpts);

% meshgrid(x, y) returns length(y)-by-length(x) arrays, so reshape to the
% size of X instead of (length(x), length(y)). The two agree only while
% the x and y ranges have the same number of points.
Z = reshape(Z, size(X));

% No new figure is opened here, so mesh draws into the current axes
% (presumably the window created for the scatter plot -- confirm if the
% figure layout matters).
mesh(x, y, Z);
title('Surface plot of probability density')
hold off
|
wolffd@0
|
112
|
wolffd@0
|
% ---- Data, centres and one-standard-deviation circles -------------------
clc
disp('The final plot shows the centres and widths, given by one standard')
disp('deviation, of the three components of the mixture model.')
disp(' ')
disp('Press any key to continue.')
pause

% Try to place the new figure window below the first one: reuse the first
% figure's position and lower its second element by the fourth (the
% window's height).
fig1_pos = get(fh1, 'Position');
fig2_pos = fig1_pos;
fig2_pos(2) = fig1_pos(2) - fig1_pos(4);
fh2 = figure;
set(fh2, 'Position', fig2_pos)

% Data as blue circles, fitted centres as heavy green crosses.
hp1 = plot(data(:, 1), data(:, 2), 'bo');
axis equal
hold on
hp2 = plot(mix.centres(:, 1), mix.centres(:, 2), 'g+');
set(hp2, 'MarkerSize', 10, 'LineWidth', 3);

title('Plot of data and mixture centres')

% Draw a red circle of radius sqrt(variance) around each component centre.
angles = 0:pi/30:2*pi;
for k = 1:mix.ncentres
    radius = sqrt(mix.covars(k));
    x_circle = mix.centres(k, 1) + radius*cos(angles);
    y_circle = mix.centres(k, 2) + radius*sin(angles);
    plot(x_circle, y_circle, 'r')
end
hold off

disp('Note how the data cluster positions and widths are captured by')
disp('the mixture model.')
disp(' ')
disp('Press any key to end.')
pause

% Tidy up: close both demo windows and clear the workspace.
% NOTE(review): 'clear all' also removes any variables the user had in the
% base workspace before running the demo.
close(fh1);
close(fh2);
clear all;
|
wolffd@0
|
152
|