Mercurial > hg > camir-aes2014
comparison toolboxes/MIRtoolbox1.3.2/MIRToolbox/netlabgmminit.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:e9a9cd732c1e |
---|---|
function mix = netlabgmminit(mix, x, options)
%GMMINIT Initialises Gaussian mixture model from data
%(Renamed NETLABGMMINIT in MIRtoolbox to avoid conflict with statistics
%toolbox)
%
%	Description
%	MIX = NETLABGMMINIT(MIX, X, OPTIONS) uses a dataset X to initialise
%	the parameters of a Gaussian mixture model defined by the data
%	structure MIX.  The k-means algorithm is used to determine the
%	centres.  The priors are computed from the proportion of examples
%	belonging to each cluster.  The covariance matrices are calculated as
%	the sample covariance of the points associated with (i.e. closest to)
%	the corresponding centres.  For a mixture of PPCA model, the PPCA
%	decomposition is calculated for the points closest to a given centre.
%	This initialisation can be used as the starting point for training
%	the model using the EM algorithm.
%
%	Inputs
%	MIX      mixture structure; the fields read here are CENTRES,
%	         NCENTRES, COVAR_TYPE, NIN and (for 'ppca') PPCA_DIM.
%	X        data matrix with one example per row.
%	OPTIONS  options vector forwarded to NETLABKMEANS; element 5 is
%	         overwritten with 1 before the call.
%
%	Output
%	MIX      the same structure with CENTRES, PRIORS and COVARS (and, for
%	         'ppca', U and LAMBDA) filled in.
%
%	Errors
%	Raises an error if CONSIST reports inconsistent inputs, if PPCA
%	returns fewer than MIX.PPCA_DIM components, or if MIX.COVAR_TYPE is
%	not one of 'spherical', 'diag', 'full', 'ppca'.
%
%	Note
%	A cluster that receives no points gets a zero scatter (rather than a
%	0/0 = NaN one), so the GMM_WIDTH fallbacks below apply to it as well.
%
%	See also
%	GMM
%

%	Copyright (c) Ian T Nabney (1996-2001)

% Check that inputs are consistent
errstring = consist(mix, 'gmm', x);
if ~isempty(errstring)
  error(errstring);
end

% Arbitrary width used if variance collapses to zero: make it 'large' so
% that centre is responsible for a reasonable number of points.
GMM_WIDTH = 1.0;

% Use kmeans algorithm to set centres
% NOTE(review): options(5) = 1 presumably asks k-means to initialise the
% centres from the data -- confirm against netlabkmeans.
options(5) = 1;
[mix.centres, options, post] = netlabkmeans(mix.centres, x, options);

% Set priors depending on number of points in each cluster
cluster_sizes = max(sum(post, 1), 1);  % Make sure that no prior is zero
mix.priors = cluster_sizes/sum(cluster_sizes); % Normalise priors

switch mix.covar_type
  case 'spherical'
    if mix.ncentres > 1
      % Determine widths as distance to nearest centre
      % (or a constant if this is zero)
      cdist = dist2(mix.centres, mix.centres);
      % Put realmax on the diagonal so a centre is never its own
      % nearest neighbour
      cdist = cdist + diag(ones(mix.ncentres, 1)*realmax);
      mix.covars = min(cdist);
      mix.covars = mix.covars + GMM_WIDTH*(mix.covars < eps);
    else
      % Just use variance of all data points averaged over all
      % dimensions
      mix.covars = mean(diag(cov(x)));
    end
  case 'diag'
    for j = 1:mix.ncentres
      % Pick out data points belonging to this centre
      c = x(find(post(:, j)),:);
      % Guard the denominator: an empty cluster would otherwise give
      % 0/0 = NaN, which the '< eps' test below cannot detect
      npts = max(size(c, 1), 1);
      diffs = c - (ones(size(c, 1), 1) * mix.centres(j, :));
      mix.covars(j, :) = sum((diffs.*diffs), 1)/npts;
      % Replace small entries by GMM_WIDTH value
      mix.covars(j, :) = mix.covars(j, :) + GMM_WIDTH.*(mix.covars(j, :)<eps);
    end
  case 'full'
    for j = 1:mix.ncentres
      % Pick out data points belonging to this centre
      c = x(find(post(:, j)),:);
      % Guard the denominator so an empty cluster yields a zero matrix
      % (rank 0, handled below) instead of a NaN matrix
      npts = max(size(c, 1), 1);
      diffs = c - (ones(size(c, 1), 1) * mix.centres(j, :));
      mix.covars(:,:,j) = (diffs'*diffs)/npts;
      % Add GMM_WIDTH*Identity to rank-deficient covariance matrices
      if rank(mix.covars(:,:,j)) < mix.nin
        mix.covars(:,:,j) = mix.covars(:,:,j) + GMM_WIDTH.*eye(mix.nin);
      end
    end
  case 'ppca'
    for j = 1:mix.ncentres
      % Pick out data points belonging to this centre
      c = x(find(post(:,j)),:);
      % Guard the denominator so PPCA never receives a NaN matrix
      npts = max(size(c, 1), 1);
      diffs = c - (ones(size(c, 1), 1) * mix.centres(j, :));
      [tempcovars, tempU, templambda] = ...
        ppca((diffs'*diffs)/npts, mix.ppca_dim);
      if length(templambda) ~= mix.ppca_dim
        error('Unable to extract enough components');
      else
        mix.covars(j) = tempcovars;
        mix.U(:, :, j) = tempU;
        mix.lambda(j, :) = templambda;
      end
    end
  otherwise
    error(['Unknown covariance type ', mix.covar_type]);
end
95 |