Mercurial > hg > camir-aes2014
comparison toolboxes/FullBNT-1.0.7/netlab3.3/gmminit.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:e9a9cd732c1e |
---|---|
function mix = gmminit(mix, x, options)
%GMMINIT Initialises Gaussian mixture model from data
%
%   Description
%   MIX = GMMINIT(MIX, X, OPTIONS) uses a dataset X to initialise the
%   parameters of a Gaussian mixture model defined by the data structure
%   MIX. The k-means algorithm is used to determine the centres. The
%   priors are computed from the proportion of examples belonging to each
%   cluster. The covariance matrices are calculated as the sample
%   covariance of the points associated with (i.e. closest to) the
%   corresponding centres. For a mixture of PPCA model, the PPCA
%   decomposition is calculated for the points closest to a given centre.
%   This initialisation can be used as the starting point for training
%   the model using the EM algorithm.
%
%   OPTIONS is passed on to KMEANSNETLAB with OPTIONS(5) set to 1; the
%   caller's OPTIONS vector is not modified.
%
%   Variances/covariances that collapse to (numerically) zero are replaced
%   by an arbitrary width GMM_WIDTH. This includes the covariance of a
%   centre that k-means leaves with no data points, which is therefore
%   well defined instead of NaN:
%     'spherical', 'diag' : variances below EPS are increased by GMM_WIDTH
%     'full'              : GMM_WIDTH*I is added to rank-deficient matrices
%     'ppca'              : an empty cluster uses GMM_WIDTH*I as the
%                           covariance passed to PPCA
%
%   See also
%   GMM
%

%   Copyright (c) Ian T Nabney (1996-2001)

% Check that inputs are consistent
errstring = consist(mix, 'gmm', x);
if ~isempty(errstring)
  error(errstring);
end

% Arbitrary width used if variance collapses to zero: make it 'large' so
% that centre is responsible for a reasonable number of points.
GMM_WIDTH = 1.0;

% Use kmeans algorithm to set centres
options(5) = 1;
[mix.centres, options, post] = kmeansNetlab(mix.centres, x, options);

% Set priors depending on number of points in each cluster
cluster_sizes = max(sum(post, 1), 1);   % Make sure that no prior is zero
mix.priors = cluster_sizes/sum(cluster_sizes);   % Normalise priors

switch mix.covar_type
  case 'spherical'
    if mix.ncentres > 1
      % Determine widths as distance to nearest other centre; the
      % diagonal is pushed to realmax so a centre never picks itself.
      cdist = dist2(mix.centres, mix.centres);
      cdist = cdist + diag(ones(mix.ncentres, 1)*realmax);
      mix.covars = min(cdist);
    else
      % Just use variance of all data points averaged over all
      % dimensions. var(x, 0, 1) is column-wise even when X has a single
      % row, whereas cov(x) would treat that row as one variable.
      mix.covars = mean(var(x, 0, 1));
    end
    % Use a constant width wherever the variance has collapsed to zero
    mix.covars = mix.covars + GMM_WIDTH*(mix.covars < eps);
  case 'diag'
    for j = 1:mix.ncentres
      % Deviations of the points belonging to this centre
      diffs = cluster_diffs(x, post(:, j), mix.centres(j, :));
      % Divide by at least 1: an empty cluster gives zero (not NaN)
      % variances, which the GMM_WIDTH replacement below then handles.
      mix.covars(j, :) = sum((diffs.*diffs), 1)/max(size(diffs, 1), 1);
      % Replace small entries by GMM_WIDTH value
      mix.covars(j, :) = mix.covars(j, :) + GMM_WIDTH.*(mix.covars(j, :)<eps);
    end
  case 'full'
    for j = 1:mix.ncentres
      % Deviations of the points belonging to this centre
      diffs = cluster_diffs(x, post(:, j), mix.centres(j, :));
      % For an empty cluster diffs'*diffs is an all-zero nin-by-nin
      % matrix; dividing by at least 1 keeps it finite so rank() works.
      mix.covars(:,:,j) = (diffs'*diffs)/max(size(diffs, 1), 1);
      % Add GMM_WIDTH*Identity to rank-deficient covariance matrices
      if rank(mix.covars(:,:,j)) < mix.nin
        mix.covars(:,:,j) = mix.covars(:,:,j) + GMM_WIDTH.*eye(mix.nin);
      end
    end
  case 'ppca'
    for j = 1:mix.ncentres
      % Deviations of the points belonging to this centre
      diffs = cluster_diffs(x, post(:, j), mix.centres(j, :));
      if isempty(diffs)
        % No points: use an isotropic covariance of width GMM_WIDTH
        % rather than handing a NaN matrix to the decomposition.
        cluster_cov = GMM_WIDTH.*eye(mix.nin);
      else
        cluster_cov = (diffs'*diffs)/size(diffs, 1);
      end
      [tempcovars, tempU, templambda] = ppca(cluster_cov, mix.ppca_dim);
      if length(templambda) ~= mix.ppca_dim
        error('Unable to extract enough components');
      else
        mix.covars(j) = tempcovars;
        mix.U(:, :, j) = tempU;
        mix.lambda(j, :) = templambda;
      end
    end
  otherwise
    error(['Unknown covariance type ', mix.covar_type]);
end


function diffs = cluster_diffs(x, member, centre)
%CLUSTER_DIFFS Deviations from CENTRE of the rows of X selected by MEMBER.
%   Rows of X whose entry in the column vector MEMBER is nonzero are
%   selected and the row vector CENTRE is subtracted from each of them.
%   Returns a 0-by-size(X,2) matrix when no rows are selected.
c = x(member ~= 0, :);
diffs = c - ones(size(c, 1), 1)*centre;
94 |