comparison toolboxes/FullBNT-1.0.7/netlab3.3/demgmm4.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
equal deleted inserted replaced
-1:000000000000 0:e9a9cd732c1e
1 %DEMGMM4 Demonstrate density modelling with a Gaussian mixture model.
2 %
3 % Description
4 % The problem consists of modelling data generated by a mixture of
5 % three Gaussians in 2 dimensions with a mixture model using full
6 % covariance matrices. The priors are 0.3, 0.5 and 0.2; the centres
7 % are (2, 3.5), (0, 0) and (0,2); the variances are (0.16, 0.64) axis
8 % aligned, (0.25, 1) rotated by 30 degrees and the identity matrix. The
9 % first figure contains a scatter plot of the data.
10 %
11 % A Gaussian mixture model with three components is trained using EM.
12 % The parameter vector is printed before training and after training.
13 % The user should press any key to continue at these points. The
14 % parameter vector consists of the priors (the first column) and the centres
15 % (given as (x, y) pairs in the next two columns). The covariance matrices
16 % are printed separately.
17 %
18 % The first figure is then reused for a 3 dimensional view of the density
19 % function, while a second figure shows the data together with the axes of
20 % the 1-standard deviation ellipses for the three components of the mixture model.
21 %
22 % See also
24 %
26 % Copyright (c) Ian T Nabney (1996-2001)
% Generate the data: ndata points drawn from a three-component mixture of
% 2-D Gaussians (priors 0.3, 0.5, 0.2).

ndata = 500;

% Fix the seeds for reproducible results
randn('state', 42);
rand('state', 42);
data = randn(ndata, 2);
prior = [0.3 0.5 0.2];
% The generator description below (datap, datac, datacov) lists the
% clusters in reverse order of generation: row/slice 1 describes the third
% generated cluster (centre (0, 2)) and row/slice 3 the first generated
% cluster (centre (2, 3.5)), i.e. clusters 1 and 3 are swapped.
datap = [0.2 0.5 0.3];
datac = [0 2; 0 0; 2 3.5];
datacov = repmat(eye(2), [1 1 3]);

% Cluster boundaries. round() guards against prior*ndata not being an exact
% integer in floating point (e.g. 0.3*100), which would otherwise give
% non-integer subscripts if ndata or prior were changed.
nend1 = round(prior(1)*ndata);
nend2 = round((prior(1) + prior(2))*ndata);
data1 = data(1:nend1, :);
data2 = data(nend1+1:nend2, :);
data3 = data(nend2+1:ndata, :);

% First cluster has axis aligned variance and centre (2, 3.5)
data1(:, 1) = data1(:, 1)*0.4 + 2.0;
data1(:, 2) = data1(:, 2)*0.8 + 3.5;
datacov(:, :, 3) = [0.4*0.4 0; 0 0.8*0.8];

% Second cluster has variance axes rotated by 30 degrees and centre (0, 0).
% Rows are transformed as x -> x*rotn, so the covariance diag(0.25, 1)
% becomes rotn' * diag(0.25, 1) * rotn.
rotn = [cos(pi/6) -sin(pi/6); sin(pi/6) cos(pi/6)];
data2(:, 1) = data2(:, 1)*0.5;
data2 = data2*rotn;
datacov(:, :, 2) = rotn' * [0.25 0; 0 1] * rotn;

% Third cluster has unit (identity) covariance and centre (0, 2)
data3 = data3 + repmat([0 2], size(data3, 1), 1);

% Put the dataset together again
data = [data1; data2; data3];
% Introduce the demo, describe the three generating clusters, then show a
% scatter plot of the generated data.
clc
disp('This demonstration illustrates the use of a Gaussian mixture model')
disp('with full covariance matrices to approximate the unconditional ')
disp('probability density of data in a two-dimensional space.')
disp('We begin by generating the data from a mixture of three Gaussians and')
disp('plotting it.')
disp(' ')
% Clusters are described in generation order, matching the data generator:
% axis aligned at (2, 3.5), rotated at (0, 0), unit variance at (0, 2).
disp('The first cluster has axis aligned variance and centre (2, 3.5).')
disp('The second cluster has variance axes rotated by 30 degrees')
disp('and centre (0, 0). The third cluster has unit variance and centre')
disp('(0, 2).')
disp(' ')
disp('Press any key to continue.')
pause

% Scatter plot of the raw data. The handle fh1 is kept so that the later
% covariance figure can be positioned relative to it and closed at the end.
fh1 = figure;
plot(data(:, 1), data(:, 2), 'o')
set(gca, 'Box', 'on')
% Set up mixture model: three components with full covariance matrices in
% a two-dimensional input space.
input_dim = 2;
ncentres = 3;
mix = gmm(input_dim, ncentres, 'full');

% Initialise the model parameters from the data with a short k-means run
% (options(14) = 5: just 5 iterations of k-means in initialisation).
options = foptions;
options(14) = 5;
mix = gmminit(mix, data, options);

% Print out the initialised model. Each printed row is a prior followed by
% the centre coordinates (assumes netlab's layout: priors is a vector with
% one entry per component, centres is ncentres-by-input_dim).
clc
fprintf('%s\n', ...
  'The mixture model has three components and full covariance', ...
  'matrices. The model parameters after initialisation using the', ...
  'k-means algorithm are as follows', ...
  ' Priors Centres');
disp([mix.priors(:) mix.centres])
disp('Covariance matrices are')
disp(mix.covars)
disp('Press any key to continue.')
pause
% Set up vector of options for the EM trainer: options(1) = 1 prints error
% values, options(14) = 50 sets the number of iterations.
options = zeros(1, 18);
options([1 14]) = [1 50];

fprintf('%s\n', ...
  'We now train the model using the EM algorithm for 50 iterations.', ...
  ' ', ...
  'Press any key to continue.');
pause
[mix, options, errlog] = gmmem(mix, data, options);

% Print out the trained model next to the parameters of the generator.
fprintf('%s\n', ...
  ' ', ...
  'The trained model has priors and centres:', ...
  ' Priors Centres');
disp([mix.priors(:) mix.centres])
fprintf('%s\n', ...
  'The data generator has priors and centres', ...
  ' Priors Centres');
disp([datap(:) datac])

disp('Model covariance matrices are')
disp(mix.covars(:, :, 1))
disp(mix.covars(:, :, 2))
disp(mix.covars(:, :, 3))

disp('Data generator covariance matrices are')
disp(datacov(:, :, 1))
disp(datacov(:, :, 2))
disp(datacov(:, :, 3))

fprintf('%s\n', ...
  'Note the close correspondence between these parameters and those', ...
  'of the distribution used to generate the data. The match for', ...
  'covariance matrices is not that close, but would be improved with', ...
  'more iterations of the training algorithm.', ...
  ' ', ...
  'Press any key to continue.');
pause
clc
disp('We now plot the density given by the mixture model as a surface plot.')
disp(' ')
disp('Press any key to continue.')
pause

% Plot the result: evaluate the mixture density on a regular grid covering
% the data and display it as a mesh. No new figure is opened here, so the
% mesh is drawn into the current figure (presumably fh1, replacing the
% scatter plot).
x = -4.0:0.2:5.0;
y = -4.0:0.2:5.0;
[X, Y] = meshgrid(x, y);
% Named gridpts rather than grid so the builtin GRID function is not shadowed.
gridpts = [X(:) Y(:)];
Z = gmmprob(mix, gridpts);
% meshgrid returns length(y)-by-length(x) arrays; reshape the densities back
% to that layout so Z(r, c) is the density at (x(c), y(r)), as mesh expects,
% even when x and y differ in length.
Z = reshape(Z, size(X));
mesh(x, y, Z);
title('Surface plot of probability density')
drawnow
clc
disp('The final plot shows the centres and widths, given by one standard')
disp('deviation, of the three components of the mixture model. The axes')
disp('of the ellipses of constant density are shown.')
disp(' ')
disp('Press any key to continue.')
pause

% Try to calculate a sensible position for the second figure, below the first
fig1_pos = get(fh1, 'Position');
fig2_pos = fig1_pos;
fig2_pos(2) = fig2_pos(2) - fig1_pos(4) - 30;
fh2 = figure('Position', fig2_pos);

% Re-plot the data and overlay, for each component, the principal axes and
% the one-standard-deviation ellipse of its covariance matrix.
plot(data(:, 1), data(:, 2), 'bo');
axis equal;
hold on
title('Plot of data and covariances')
% linspace includes both 0 and 2*pi, so each ellipse outline is closed
% (0:0.02:2*pi stops at 6.28 and leaves a small gap).
theta = linspace(0, 2*pi, 315);
for i = 1:ncentres
  [v, d] = eig(mix.covars(:, :, i));
  sd = sqrt(diag(d))';   % standard deviation along each eigenvector
  for j = 1:2
    % Ensure that eigenvector has unit length
    v(:, j) = v(:, j)/norm(v(:, j));
    start = mix.centres(i, :) - sd(j)*(v(:, j)');
    endpt = mix.centres(i, :) + sd(j)*(v(:, j)');
    line([start(1) endpt(1)], [start(2) endpt(2)], 'Color', 'k', 'LineWidth', 3)
  end
  % Ellipse of one standard deviation: scale a circle by the standard
  % deviations, rotate it onto the eigenvector axes, then shift to the centre.
  ex = sd(1)*cos(theta);
  ey = sd(2)*sin(theta);
  ellipse = (v*[ex; ey])';
  ellipse = ellipse + ones(length(theta), 1)*mix.centres(i, :);
  plot(ellipse(:, 1), ellipse(:, 2), 'r-');
end
hold off

disp('Note how the data cluster positions and widths are captured by')
disp('the mixture model.')
disp(' ')
disp('Press any key to end.')
pause

close(fh1);
close(fh2);
clear all;