diff toolboxes/FullBNT-1.0.7/netlab3.3/gmminit.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/toolboxes/FullBNT-1.0.7/netlab3.3/gmminit.m	Tue Feb 10 15:05:51 2015 +0000
@@ -0,0 +1,94 @@
+function mix = gmminit(mix, x, options)
+%GMMINIT Initialises Gaussian mixture model from data
+%
+%	Description
+%	MIX = GMMINIT(MIX, X, OPTIONS) uses a dataset X to initialise the
+%	parameters of a Gaussian mixture model defined by the data structure
+%	MIX.  The k-means algorithm is used to determine the centres. The
+%	priors are computed from the proportion of examples belonging to each
+%	cluster. The covariance matrices are calculated as the sample
+%	covariance of the points associated with (i.e. closest to) the
+%	corresponding centres. For a mixture of PPCA model, the PPCA
+%	decomposition is calculated for the points closest to a given centre.
+%	This initialisation can be used as the starting point for training
+%	the model using the EM algorithm.
+%
+%	See also
+%	GMM
+%
+
+%	Copyright (c) Ian T Nabney (1996-2001)
+
+[ndata, xdim] = size(x);
+
+% Check that inputs are consistent
+errstring = consist(mix, 'gmm', x);
+if ~isempty(errstring)
+  error(errstring);
+end
+
+% Arbitrary width used if variance collapses to zero: make it 'large' so
+% that centre is responsible for a reasonable number of points.
+GMM_WIDTH = 1.0;
+
+% Use kmeans algorithm to set centres
+options(5) = 1;	
+[mix.centres, options, post] = kmeansNetlab(mix.centres, x, options);
+
+% Set priors depending on number of points in each cluster
+cluster_sizes = max(sum(post, 1), 1);  % Make sure that no prior is zero
+mix.priors = cluster_sizes/sum(cluster_sizes); % Normalise priors
+
+switch mix.covar_type
+case 'spherical'
+   if mix.ncentres > 1
+      % Determine widths as distance to nearest centre 
+      % (or a constant if this is zero)
+      cdist = dist2(mix.centres, mix.centres);
+      cdist = cdist + diag(ones(mix.ncentres, 1)*realmax);
+      mix.covars = min(cdist);
+      mix.covars = mix.covars + GMM_WIDTH*(mix.covars < eps);
+   else
+      % Just use variance of all data points averaged over all
+      % dimensions
+      mix.covars = mean(diag(cov(x)));
+   end
+  case 'diag'
+    for j = 1:mix.ncentres
+      % Pick out data points belonging to this centre
+      c = x(find(post(:, j)),:);
+      diffs = c - (ones(size(c, 1), 1) * mix.centres(j, :));
+      mix.covars(j, :) = sum((diffs.*diffs), 1)/size(c, 1);
+      % Replace small entries by GMM_WIDTH value
+      mix.covars(j, :) = mix.covars(j, :) + GMM_WIDTH.*(mix.covars(j, :)<eps);
+    end
+  case 'full'
+    for j = 1:mix.ncentres
+      % Pick out data points belonging to this centre
+      c = x(find(post(:, j)),:);
+      diffs = c - (ones(size(c, 1), 1) * mix.centres(j, :));
+      mix.covars(:,:,j) = (diffs'*diffs)/(size(c, 1));
+      % Add GMM_WIDTH*Identity to rank-deficient covariance matrices
+      if rank(mix.covars(:,:,j)) < mix.nin
+	mix.covars(:,:,j) = mix.covars(:,:,j) + GMM_WIDTH.*eye(mix.nin);
+      end
+    end
+  case 'ppca'
+    for j = 1:mix.ncentres
+      % Pick out data points belonging to this centre
+      c = x(find(post(:,j)),:);
+      diffs = c - (ones(size(c, 1), 1) * mix.centres(j, :));
+      [tempcovars, tempU, templambda] = ...
+	ppca((diffs'*diffs)/size(c, 1), mix.ppca_dim);
+      if length(templambda) ~= mix.ppca_dim
+	error('Unable to extract enough components');
+      else 
+        mix.covars(j) = tempcovars;
+        mix.U(:, :, j) = tempU;
+        mix.lambda(j, :) = templambda;
+      end
+    end
+  otherwise
+    error(['Unknown covariance type ', mix.covar_type]);
+end
+