Mercurial > hg > camir-aes2014
diff toolboxes/MIRtoolbox1.3.2/MIRToolbox/netlabgmminit.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
line wrap: on
line diff
function mix = netlabgmminit(mix, x, options)
%GMMINIT Initialises Gaussian mixture model from data
%(Renamed NETLABGMMINIT in MIRtoolbox to avoid conflict with statistics
%toolbox)
%
%	Description
%	MIX = GMMINIT(MIX, X, OPTIONS) uses a dataset X to initialise the
%	parameters of a Gaussian mixture model defined by the data structure
%	MIX.  The k-means algorithm is used to determine the centres. The
%	priors are computed from the proportion of examples belonging to each
%	cluster. The covariance matrices are calculated as the sample
%	covariance of the points associated with (i.e. closest to) the
%	corresponding centres. For a mixture of PPCA model, the PPCA
%	decomposition is calculated for the points closest to a given centre.
%	This initialisation can be used as the starting point for training
%	the model using the EM algorithm.
%
%	Inputs
%	MIX      mixture model structure; this function reads the fields
%	         CENTRES, NCENTRES, COVAR_TYPE, NIN and (for 'ppca') PPCA_DIM.
%	X        data matrix, one example per row.
%	OPTIONS  options vector passed on to NETLABKMEANS; OPTIONS(5) is
%	         overwritten with 1 before the call.
%
%	Output
%	MIX      the same structure with CENTRES, PRIORS and COVARS (and, for
%	         'ppca', U and LAMBDA) set.
%
%	Errors
%	Raises an error if CONSIST reports an inconsistency, if PPCA returns
%	a number of components different from MIX.PPCA_DIM, or if
%	MIX.COVAR_TYPE is not one of 'spherical', 'diag', 'full', 'ppca'.
%
%	Empty clusters
%	If a centre has no points assigned to it, its scatter is divided by 1
%	instead of 0, so the sample covariance is zero rather than NaN. For
%	'diag' and 'full' the zero covariance is then replaced/augmented by
%	GMM_WIDTH through the existing small-variance handling. For 'ppca'
%	the zero matrix is handed to PPCA (NOTE(review): PPCA's behaviour on
%	an all-zero covariance is not visible here - confirm).
%
%	See also
%	GMM
%

%	Copyright (c) Ian T Nabney (1996-2001)

% Check that inputs are consistent
errstring = consist(mix, 'gmm', x);
if ~isempty(errstring)
  error(errstring);
end

% Arbitrary width used if variance collapses to zero: make it 'large' so
% that centre is responsible for a reasonable number of points.
GMM_WIDTH = 1.0;

% Use kmeans algorithm to set centres
options(5) = 1;
[mix.centres, options, post] = netlabkmeans(mix.centres, x, options);

% Set priors depending on number of points in each cluster
cluster_sizes = max(sum(post, 1), 1);  % Make sure that no prior is zero
mix.priors = cluster_sizes/sum(cluster_sizes); % Normalise priors

switch mix.covar_type
  case 'spherical'
    if mix.ncentres > 1
      % Determine widths as distance to nearest centre
      % (or a constant if this is zero)
      cdist = dist2(mix.centres, mix.centres);
      % Put realmax on the diagonal so a centre is never its own nearest
      % neighbour
      cdist = cdist + diag(ones(mix.ncentres, 1)*realmax);
      mix.covars = min(cdist);
      mix.covars = mix.covars + GMM_WIDTH*(mix.covars < eps);
    else
      % Just use variance of all data points averaged over all
      % dimensions
      mix.covars = mean(diag(cov(x)));
    end

  case 'diag'
    for j = 1:mix.ncentres
      [diffs, npoints] = cluster_diffs(x, post(:, j), mix.centres(j, :));
      mix.covars(j, :) = sum((diffs.*diffs), 1)/npoints;
      % Replace small entries by GMM_WIDTH value
      mix.covars(j, :) = mix.covars(j, :) + ...
        GMM_WIDTH.*(mix.covars(j, :) < eps);
    end

  case 'full'
    for j = 1:mix.ncentres
      [diffs, npoints] = cluster_diffs(x, post(:, j), mix.centres(j, :));
      mix.covars(:, :, j) = (diffs'*diffs)/npoints;
      % Add GMM_WIDTH*Identity to rank-deficient covariance matrices
      if rank(mix.covars(:, :, j)) < mix.nin
        mix.covars(:, :, j) = mix.covars(:, :, j) + GMM_WIDTH.*eye(mix.nin);
      end
    end

  case 'ppca'
    for j = 1:mix.ncentres
      [diffs, npoints] = cluster_diffs(x, post(:, j), mix.centres(j, :));
      [tempcovars, tempU, templambda] = ...
        ppca((diffs'*diffs)/npoints, mix.ppca_dim);
      if length(templambda) ~= mix.ppca_dim
        error('Unable to extract enough components');
      else
        mix.covars(j) = tempcovars;
        mix.U(:, :, j) = tempU;
        mix.lambda(j, :) = templambda;
      end
    end

  otherwise
    error(['Unknown covariance type ', mix.covar_type]);
end


function [diffs, npoints] = cluster_diffs(x, postcol, centre)
%CLUSTER_DIFFS Deviations of one cluster's points from its centre.
%	DIFFS holds one row per row of X whose entry in POSTCOL is non-zero,
%	with CENTRE subtracted. NPOINTS is the number of such rows, floored
%	at 1 so that callers dividing by it get 0 (not 0/0 = NaN) for a
%	cluster with no points.

% Pick out data points belonging to this centre
members = x(postcol ~= 0, :);
diffs = members - (ones(size(members, 1), 1) * centre);
npoints = max(size(members, 1), 1);