wolffd@0
|
1 function mix = netlabgmminit(mix, x, options)
|
wolffd@0
|
2 %GMMINIT Initialises Gaussian mixture model from data
|
wolffd@0
|
3 %(Renamed NETLABGMMINIT in MIRtoolbox to avoid conflict with statistics
|
wolffd@0
|
4 %toolbox)
|
wolffd@0
|
5 % Description
|
wolffd@0
|
6 % MIX = GMMINIT(MIX, X, OPTIONS) uses a dataset X to initialise the
|
wolffd@0
|
7 % parameters of a Gaussian mixture model defined by the data structure
|
wolffd@0
|
8 % MIX. The k-means algorithm is used to determine the centres. The
|
wolffd@0
|
9 % priors are computed from the proportion of examples belonging to each
|
wolffd@0
|
10 % cluster. The covariance matrices are calculated as the sample
|
wolffd@0
|
11 % covariance of the points associated with (i.e. closest to) the
|
wolffd@0
|
12 % corresponding centres. For a mixture of PPCA model, the PPCA
|
wolffd@0
|
13 % decomposition is calculated for the points closest to a given centre.
|
wolffd@0
|
14 % This initialisation can be used as the starting point for training
|
wolffd@0
|
15 % the model using the EM algorithm.
|
wolffd@0
|
16 %
|
wolffd@0
|
17 % See also
|
wolffd@0
|
18 % GMM
|
wolffd@0
|
19 %
|
wolffd@0
|
20
|
wolffd@0
|
21 % Copyright (c) Ian T Nabney (1996-2001)
|
wolffd@0
|
22
|
wolffd@0
|
23 [ndata, xdim] = size(x);
|
wolffd@0
|
24
|
wolffd@0
|
25 % Check that inputs are consistent
|
wolffd@0
|
26 errstring = consist(mix, 'gmm', x);
|
wolffd@0
|
27 if ~isempty(errstring)
|
wolffd@0
|
28 error(errstring);
|
wolffd@0
|
29 end
|
wolffd@0
|
30
|
wolffd@0
|
31 % Arbitrary width used if variance collapses to zero: make it 'large' so
|
wolffd@0
|
32 % that centre is responsible for a reasonable number of points.
|
wolffd@0
|
33 GMM_WIDTH = 1.0;
|
wolffd@0
|
34
|
wolffd@0
|
35 % Use kmeans algorithm to set centres
|
wolffd@0
|
36 options(5) = 1;
|
wolffd@0
|
37 [mix.centres, options, post] = netlabkmeans(mix.centres, x, options);
|
wolffd@0
|
38
|
wolffd@0
|
39 % Set priors depending on number of points in each cluster
|
wolffd@0
|
40 cluster_sizes = max(sum(post, 1), 1); % Make sure that no prior is zero
|
wolffd@0
|
41 mix.priors = cluster_sizes/sum(cluster_sizes); % Normalise priors
|
wolffd@0
|
42
|
wolffd@0
|
43 switch mix.covar_type
|
wolffd@0
|
44 case 'spherical'
|
wolffd@0
|
45 if mix.ncentres > 1
|
wolffd@0
|
46 % Determine widths as distance to nearest centre
|
wolffd@0
|
47 % (or a constant if this is zero)
|
wolffd@0
|
48 cdist = dist2(mix.centres, mix.centres);
|
wolffd@0
|
49 cdist = cdist + diag(ones(mix.ncentres, 1)*realmax);
|
wolffd@0
|
50 mix.covars = min(cdist);
|
wolffd@0
|
51 mix.covars = mix.covars + GMM_WIDTH*(mix.covars < eps);
|
wolffd@0
|
52 else
|
wolffd@0
|
53 % Just use variance of all data points averaged over all
|
wolffd@0
|
54 % dimensions
|
wolffd@0
|
55 mix.covars = mean(diag(cov(x)));
|
wolffd@0
|
56 end
|
wolffd@0
|
57 case 'diag'
|
wolffd@0
|
58 for j = 1:mix.ncentres
|
wolffd@0
|
59 % Pick out data points belonging to this centre
|
wolffd@0
|
60 c = x(find(post(:, j)),:);
|
wolffd@0
|
61 diffs = c - (ones(size(c, 1), 1) * mix.centres(j, :));
|
wolffd@0
|
62 mix.covars(j, :) = sum((diffs.*diffs), 1)/size(c, 1);
|
wolffd@0
|
63 % Replace small entries by GMM_WIDTH value
|
wolffd@0
|
64 mix.covars(j, :) = mix.covars(j, :) + GMM_WIDTH.*(mix.covars(j, :)<eps);
|
wolffd@0
|
65 end
|
wolffd@0
|
66 case 'full'
|
wolffd@0
|
67 for j = 1:mix.ncentres
|
wolffd@0
|
68 % Pick out data points belonging to this centre
|
wolffd@0
|
69 c = x(find(post(:, j)),:);
|
wolffd@0
|
70 diffs = c - (ones(size(c, 1), 1) * mix.centres(j, :));
|
wolffd@0
|
71 mix.covars(:,:,j) = (diffs'*diffs)/(size(c, 1));
|
wolffd@0
|
72 % Add GMM_WIDTH*Identity to rank-deficient covariance matrices
|
wolffd@0
|
73 if rank(mix.covars(:,:,j)) < mix.nin
|
wolffd@0
|
74 mix.covars(:,:,j) = mix.covars(:,:,j) + GMM_WIDTH.*eye(mix.nin);
|
wolffd@0
|
75 end
|
wolffd@0
|
76 end
|
wolffd@0
|
77 case 'ppca'
|
wolffd@0
|
78 for j = 1:mix.ncentres
|
wolffd@0
|
79 % Pick out data points belonging to this centre
|
wolffd@0
|
80 c = x(find(post(:,j)),:);
|
wolffd@0
|
81 diffs = c - (ones(size(c, 1), 1) * mix.centres(j, :));
|
wolffd@0
|
82 [tempcovars, tempU, templambda] = ...
|
wolffd@0
|
83 ppca((diffs'*diffs)/size(c, 1), mix.ppca_dim);
|
wolffd@0
|
84 if length(templambda) ~= mix.ppca_dim
|
wolffd@0
|
85 error('Unable to extract enough components');
|
wolffd@0
|
86 else
|
wolffd@0
|
87 mix.covars(j) = tempcovars;
|
wolffd@0
|
88 mix.U(:, :, j) = tempU;
|
wolffd@0
|
89 mix.lambda(j, :) = templambda;
|
wolffd@0
|
90 end
|
wolffd@0
|
91 end
|
wolffd@0
|
92 otherwise
|
wolffd@0
|
93 error(['Unknown covariance type ', mix.covar_type]);
|
wolffd@0
|
94 end
|
wolffd@0
|
95
|