comparison toolboxes/FullBNT-1.0.7/netlab3.3/demgmm3.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:e9a9cd732c1e
1 %DEMGMM3 Demonstrate density modelling with a Gaussian mixture model.
2 %
3 % Description
4 % The problem consists of modelling data generated by a mixture of
5 % three Gaussians in 2 dimensions with a mixture model using diagonal
6 % covariance matrices. The priors are 0.3, 0.5 and 0.2; the centres
7 % are (2, 3.5), (0, 0) and (0,2); the covariances are all axis aligned
% (0.16, 0.64), (1, 0.25) and the identity matrix. The first figure
9 % contains a scatter plot of the data.
10 %
11 % A Gaussian mixture model with three components is trained using EM.
12 % The parameter vector is printed before training and after training.
13 % The user should press any key to continue at these points. The
% parameter vector consists of priors (the first column), and centres (given
15 % as (x, y) pairs as the next two columns). The diagonal entries of
16 % the covariance matrices are printed separately.
17 %
% The density function is then shown as a 3 dimensional surface plot,
% drawn in the first figure in place of the scatter plot, while a second
% figure shows the data together with the axes of the 1-standard deviation
% ellipses for the three components of the mixture model.
21 %
22 % See also
23 % GMM, GMMINIT, GMMEM, GMMPROB, GMMUNPAK
24 %
25
26 % Copyright (c) Ian T Nabney (1996-2001)
27
% Generate the data
ndata = 500;

% Fix the seeds for reproducible results
randn('state', 42);
rand('state', 42);
data = randn(ndata, 2);
prior = [0.3 0.5 0.2];
% Mixture model swaps clusters 1 and 3, so the generator parameters kept
% for display (priors, centres, covariance diagonals) are listed in that
% swapped order.
datap = [0.2 0.5 0.3];
datac = [0 2; 0 0; 2 3.5];
datacov = [1 1;1 0.25; 0.4*0.4 0.8*0.8];

% Split the standard normal samples into one chunk per cluster. The
% boundaries are rounded so the indices are always integers (products such
% as prior(k)*ndata are not exact in floating point for general priors),
% and the last cluster takes whatever rows remain so the three chunks
% always add up to ndata.
n1 = round(prior(1)*ndata);
n2 = round((prior(1) + prior(2))*ndata) - n1;
data1 = data(1:n1, :);
data2 = data(n1+1:n1+n2, :);
data3 = data(n1+n2+1:ndata, :);

% First cluster has axis aligned variance (0.16, 0.64) and centre (2, 3.5)
data1(:, 1) = data1(:, 1)*0.4 + 2.0;
data1(:, 2) = data1(:, 2)*0.8 + 3.5;

% Second cluster has axis aligned variance (1, 0.25) and centre (0, 0)
data2(:, 2) = data2(:, 2)*0.5;

% Third cluster is at (0,2) with identity matrix for covariance. The offset
% is replicated to the actual number of rows in data3.
data3 = data3 + repmat([0 2], size(data3, 1), 1);

% Put the dataset together again
data = [data1; data2; data3];
56
% Describe the demonstration and the generated clusters, then wait for the
% user before anything is plotted.
clc
intro_lines = { ...
  'This demonstration illustrates the use of a Gaussian mixture model'; ...
  'with diagonal covariance matrices to approximate the unconditional'; ...
  'probability density of data in a two-dimensional space.'; ...
  'We begin by generating the data from a mixture of three Gaussians'; ...
  'with axis aligned covariance structure and plotting it.'; ...
  ' '; ...
  'The first cluster has centre (0, 2).'; ...
  'The second cluster has centre (0, 0).'; ...
  'The third cluster has centre (2, 3.5).'; ...
  ' '; ...
  'Press any key to continue'};
for k = 1:numel(intro_lines)
  disp(intro_lines{k})
end
pause
70
% Scatter plot of the generated data in a new figure, with a box drawn
% around the axes. The figure handle is kept in fh1.
fh1 = figure;
ax = gca;
plot(ax, data(:, 1), data(:, 2), 'o')
set(ax, 'Box', 'on')
74
% Set up mixture model: three components in two input dimensions, with
% diagonal covariance matrices.
input_dim = 2;
ncentres = 3;
mix = gmm(input_dim, ncentres, 'diag');

% Initialise the model parameters from the data, starting from the default
% foptions vector but limiting the k-means initialisation to 5 iterations
% (options(14)).
options = foptions;
options(14) = 5;
mix = gmminit(mix, data, options);
84
% Print out the model after k-means initialisation: a table with one row per
% component (prior followed by centre), then the covariance diagonals.
init_table = [mix.priors' mix.centres];
disp('The mixture model has three components and diagonal covariance')
disp('matrices. The model parameters after initialisation using the')
disp('k-means algorithm are as follows')
disp(' Priors Centres')
disp(init_table)
disp('Covariance diagonals are')
disp(mix.covars)
disp('Press any key to continue.')
pause
95
% Configure the EM trainer: start from an all-zero options vector, then turn
% on printing of error values (options(1)) and fix the number of iterations
% (options(14)).
options = zeros(1, 18);
options([1 14]) = [1 20];

disp('We now train the model using the EM algorithm for 20 iterations.')
disp(' ')
disp('Press any key to continue.')
pause

[mix, options, errlog] = gmmem(mix, data, options);
107
% Show the trained parameters next to those of the data generator so the two
% can be compared. Each table row is a prior followed by a centre.
trained_params = [mix.priors' mix.centres];
generator_params = [datap' datac];
disp(' ')
disp('The trained model has priors and centres:')
disp(' Priors Centres')
disp(trained_params)
disp('The data generator has priors and centres')
disp(' Priors Centres')
disp(generator_params)
disp('Model covariance diagonals are')
disp(mix.covars)
disp('Data generator covariance diagonals are')
disp(datacov)
disp('Note the close correspondence between these parameters and those')
disp('of the distribution used to generate the data.')
disp(' ')
disp('Press any key to continue.')
pause
125
clc
disp('We now plot the density given by the mixture model as a surface plot.')
disp(' ')
disp('Press any key to continue.')
pause

% Plot the result: evaluate the mixture density on a regular grid and draw
% it as a mesh in the current figure.
x = -4.0:0.2:5.0;
y = -4.0:0.2:5.0;
[X, Y] = meshgrid(x, y);
% meshgrid returns length(y)-by-length(x) matrices, which is also the shape
% mesh(x, y, Z) expects for Z. Reshaping to size(X) keeps this correct even
% when the x and y grids have different lengths. (The point list is not
% called "grid" to avoid shadowing the built-in grid function.)
gridpts = [X(:) Y(:)];
Z = reshape(gmmprob(mix, gridpts), size(X));
mesh(x, y, Z);
title('Surface plot of probability density')
drawnow
146
clc
disp('The final plot shows the centres and widths, given by one standard')
disp('deviation, of the three components of the mixture model. The axes')
disp('of the ellipses of constant density are shown.')
disp(' ')
disp('Press any key to continue.')
pause

% Try to calculate a sensible position for the second figure, below the
% first. The user may have closed the first figure during one of the
% pauses, in which case querying its position would fail, so fall back to a
% figure at the default position.
if ishandle(fh1)
  fig2_pos = get(fh1, 'Position');
  fig2_pos(2) = fig2_pos(2) - fig2_pos(4);
  fh2 = figure('Position', fig2_pos);
else
  fh2 = figure;
end

plot(data(:, 1), data(:, 2), 'bo');
hold on
axis('equal');
title('Plot of data and covariances')
% Angles for the one standard deviation ellipses. linspace includes the
% endpoint 2*pi, so each ellipse is drawn as a closed curve; a 0.02 step
% stops just short of 2*pi and leaves a small gap.
theta = linspace(0, 2*pi, 315);
for i = 1:ncentres
  ctr = mix.centres(i, :);
  % With diagonal covariances the ellipses are axis aligned, and their
  % semi-axes are the per-dimension standard deviations.
  sd = sqrt(mix.covars(i, :));
  % Axes of the ellipse: one standard deviation either side of the centre,
  % first along x, then along y.
  line([ctr(1)-sd(1) ctr(1)+sd(1)], [ctr(2) ctr(2)], 'Color', 'k', 'LineWidth', 3)
  line([ctr(1) ctr(1)], [ctr(2)-sd(2) ctr(2)+sd(2)], 'Color', 'k', 'LineWidth', 3)
  % Plot ellipse of one standard deviation
  plot(sd(1)*cos(theta) + ctr(1), sd(2)*sin(theta) + ctr(2), 'r-');
end
hold off
182
disp('Note how the data cluster positions and widths are captured by')
disp('the mixture model.')
disp(' ')
disp('Press any key to end.')
pause

% Close the demo figures. Either may already have been closed by the user
% during one of the pauses, and close() raises an error for a handle that no
% longer refers to a figure, so check each handle first.
if ishandle(fh1)
  close(fh1);
end
if ishandle(fh2)
  close(fh2);
end
clear all;
192