annotate toolboxes/FullBNT-1.0.7/netlab3.3/gmminit.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
rev   line source
function mix = gmminit(mix, x, options)
%GMMINIT Initialises Gaussian mixture model from data
%
%	Description
%	MIX = GMMINIT(MIX, X, OPTIONS) uses a dataset X to initialise the
%	parameters of a Gaussian mixture model defined by the data structure
%	MIX. The k-means algorithm is used to determine the centres. The
%	priors are computed from the proportion of examples belonging to each
%	cluster. The covariance matrices are calculated as the sample
%	covariance of the points associated with (i.e. closest to) the
%	corresponding centres. For a mixture of PPCA model, the PPCA
%	decomposition is calculated for the points closest to a given centre.
%	This initialisation can be used as the starting point for training
%	the model using the EM algorithm.
%
%	X holds one example per row and is checked against MIX with CONSIST;
%	an inconsistency raises an error. OPTIONS is forwarded to KMEANSNETLAB
%	after OPTIONS(5) is set to 1 (in Netlab's KMEANS this requests that the
%	centres be initialised from randomly chosen data points -- confirm for
%	KMEANSNETLAB).
%
%	Degenerate cases: a cluster that receives no points still gets a
%	strictly positive prior (it is counted as one point), and its
%	covariance falls back to the GMM_WIDTH default rather than NaN. Any
%	spherical width or diagonal variance below EPS, and any rank-deficient
%	full covariance, is widened by GMM_WIDTH.
%
%	See also
%	GMM
%

%	Copyright (c) Ian T Nabney (1996-2001)

xdim = size(x, 2);

% Check that inputs are consistent
errstring = consist(mix, 'gmm', x);
if ~isempty(errstring)
  error(errstring);
end

% Arbitrary width used if variance collapses to zero: make it 'large' so
% that centre is responsible for a reasonable number of points.
GMM_WIDTH = 1.0;

% Use kmeans algorithm to set centres. POST is the assignment matrix
% returned by KMEANSNETLAB; it is summed over dimension 1 and indexed by
% column, so presumably NDATA x NCENTRES with 0/1 entries -- confirm.
options(5) = 1;
[mix.centres, options, post] = kmeansNetlab(mix.centres, x, options);

% Set priors depending on number of points in each cluster. Flooring each
% count at 1 makes sure that no prior is zero.
cluster_sizes = max(sum(post, 1), 1);
mix.priors = cluster_sizes/sum(cluster_sizes);   % Normalise priors

switch mix.covar_type
  case 'spherical'
    if mix.ncentres > 1
      % Width = squared distance (DIST2) to the nearest other centre. Adding
      % realmax on the diagonal excludes each centre's zero self-distance.
      cdist = dist2(mix.centres, mix.centres);
      cdist = cdist + diag(ones(mix.ncentres, 1)*realmax);
      mix.covars = min(cdist);
    else
      % Variance of all data points averaged over all dimensions. VAR along
      % dimension 1 matches MEAN(DIAG(COV(X))) for several rows, and gives 0
      % for a single row (COV would treat a lone row as a sample of scalars).
      mix.covars = mean(var(x, 0, 1));
    end
    % Replace collapsed widths (e.g. coincident centres) by GMM_WIDTH
    mix.covars = mix.covars + GMM_WIDTH*(mix.covars < eps);
  case 'diag'
    for j = 1:mix.ncentres
      [diffs, nj] = gmminit_deviations(x, post(:, j) ~= 0, mix.centres(j, :));
      % MAX(NJ, 1) avoids 0/0 for an empty cluster: its variances become 0
      % and are then replaced by GMM_WIDTH below.
      mix.covars(j, :) = sum((diffs.*diffs), 1)/max(nj, 1);
      % Replace small entries by GMM_WIDTH value
      mix.covars(j, :) = mix.covars(j, :) + GMM_WIDTH.*(mix.covars(j, :) < eps);
    end
  case 'full'
    for j = 1:mix.ncentres
      [diffs, nj] = gmminit_deviations(x, post(:, j) ~= 0, mix.centres(j, :));
      % MAX(NJ, 1) avoids 0/0 (a NaN matrix would make RANK fail): an empty
      % cluster gives a zero matrix, which is rank-deficient and so is
      % widened by GMM_WIDTH*Identity below.
      mix.covars(:, :, j) = (diffs'*diffs)/max(nj, 1);
      % Add GMM_WIDTH*Identity to rank-deficient covariance matrices
      if rank(mix.covars(:, :, j)) < mix.nin
        mix.covars(:, :, j) = mix.covars(:, :, j) + GMM_WIDTH.*eye(mix.nin);
      end
    end
  case 'ppca'
    for j = 1:mix.ncentres
      [diffs, nj] = gmminit_deviations(x, post(:, j) ~= 0, mix.centres(j, :));
      if nj > 0
        local_cov = (diffs'*diffs)/nj;
      else
        % No points in this cluster: use an isotropic GMM_WIDTH covariance
        % instead of the NaN matrix that 0/0 would produce.
        local_cov = GMM_WIDTH.*eye(xdim);
      end
      [tempcovars, tempU, templambda] = ppca(local_cov, mix.ppca_dim);
      if length(templambda) ~= mix.ppca_dim
        error('Unable to extract enough components');
      else
        mix.covars(j) = tempcovars;
        mix.U(:, :, j) = tempU;
        mix.lambda(j, :) = templambda;
      end
    end
  otherwise
    error(['Unknown covariance type ', mix.covar_type]);
end


function [diffs, n] = gmminit_deviations(x, in_cluster, centre)
%GMMINIT_DEVIATIONS Deviations of the selected rows of X from CENTRE.
%	[DIFFS, N] = GMMINIT_DEVIATIONS(X, IN_CLUSTER, CENTRE) selects the rows
%	of X flagged by the logical column IN_CLUSTER. It returns those rows
%	minus the row vector CENTRE as DIFFS (N x size(X, 2)), together with the
%	number of selected rows N. N may be 0, in which case DIFFS is empty.
c = x(in_cluster, :);
n = size(c, 1);
diffs = c - ones(n, 1)*centre;
wolffd@0 94