diff core/tools/machine_learning/weighted_kmeans.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/core/tools/machine_learning/weighted_kmeans.m	Tue Feb 10 15:05:51 2015 +0000
@@ -0,0 +1,163 @@
+function [centres, cweights, post, errlog, options] = weighted_kmeans(centres, data, weights, options)
+%[centres, cweights, post, errlog, options] = weighted_kmeans(centres,data, weights, options)
+%
+% weighted_kmeans	Trains a k means cluster model on weighted input vectors
+%
+% Adapted from the Netlab Toolbox by Daniel Wolff,
+% This function takes a WEIGHTS vector, containing weights for the 
+% different data points. This can be used for training with varying 
+% discretisation intervals.
+%
+%	Description
+%	 CENTRES = weighted_kmeans(NCENTRES, DATA, WEIGHTS, OPTIONS) or
+%    CENTRES = weighted_kmeans(CENTRES, DATA, WEIGHTS, OPTIONS) uses the batch K-means
+%	algorithm to set the centres of a cluster model. The matrix DATA
+%	represents the data which is being clustered, with each row
+%	corresponding to a vector. The sum of squares error function is used.
+%	The point at which a local minimum is achieved is returned as
+%	CENTRES.  The error value at that point is returned in OPTIONS(8).
+% 
+%
+%	POST and ERRLOG
+%	also return the cluster number (in a one-of-N encoding) for each
+%	data point in POST and a log of the error values after each cycle in
+%	ERRLOG.    The optional parameters have the following
+%	interpretations.
+%
+%	OPTIONS(1) is set to 1 to display error values; also logs error
+%	values in the return argument ERRLOG. If OPTIONS(1) is set to 0, then
+%	only warning messages are displayed.  If OPTIONS(1) is -1, then
+%	nothing is displayed.
+%
+%	OPTIONS(2) is a measure of the absolute precision required for the
+%	value of CENTRES at the solution.  If the absolute difference between
+%	the values of CENTRES between two successive steps is less than
+%	OPTIONS(2), then this condition is satisfied.
+%
+%	OPTIONS(3) is a measure of the precision required of the error
+%	function at the solution.  If the absolute difference between the
+%	error functions between two successive steps is less than OPTIONS(3),
+%	then this condition is satisfied. Both this and the previous
+%	condition must be satisfied for termination.
+%
+%	OPTIONS(14) is the maximum number of iterations; default 100.
+%
+%	See also
+%	GMMINIT, GMMEM
+%
+
+%	Copyright (c) Ian T Nabney (1996-2001)
+
+[ndata, data_dim] = size(data);
+[ncentres, dim] = size(centres);
+
+if dim ~= data_dim
+    if dim == 1 && ncentres == 1 && centres > 1
+        
+        if ndata == numel(weights)
+        
+            % ---
+            % allow for number of centres specification
+            % ---
+            dim = data_dim;
+            ncentres = centres;
+
+            options(5) = 1;
+        else
+            error('Data dimension does not match number of weights')
+        end
+        
+    else
+        error('Data dimension does not match dimension of centres')
+    end
+end
+
+if (ncentres > ndata)
+  error('More centres than data')
+end
+
+% Sort out the options
+if (options(14))
+  niters = options(14);
+else
+  niters = 100;
+end
+
+store = 0;
+if (nargout > 3)
+  store = 1;
+  errlog = zeros(1, niters);
+end
+
+% Check if centres and posteriors need to be initialised from data
+if (options(5) == 1)
+  % Do the initialisation
+  perm = randperm(ndata);
+  perm = perm(1:ncentres);
+
+  % Assign first ncentres (permuted) data points as centres
+  centres = data(perm, :);
+end
+% Matrix to make unit vectors easy to construct
+id = eye(ncentres);
+
+% save accumulated weight for a center
+cweights = zeros(ncentres, 1);
+
+% Main loop of algorithm
+for n = 1:niters
+
+  % Save old centres to check for termination
+  old_centres = centres;
+  
+  % Calculate posteriors based on existing centres
+  d2 = dist2(data, centres);
+  % Assign each point to nearest centre
+  [minvals, index] = min(d2', [], 1);
+  post = logical(id(index,:));
+
+  % num_points = sum(post, 1);
+  % Adjust the centres based on new posteriors
+  for j = 1:ncentres
+    if (sum(weights(post(:,j))) > 0)
+        % ---
+        % NOTE: this is edited to include the weights.
+        % Instead of summing the vectors directly, the vectors are weighted
+        % and then the result is divided by the sum of the weights instead
+        % of the number of vectors for this class
+        % ---
+      cweights(j) = sum(weights(post(:,j)));
+      
+      centres(j,:) = sum(diag(weights(post(:,j))) * data(post(:,j),:), 1)...
+          /cweights(j);
+    end
+  end
+  
+  % Error value is total squared distance from cluster centres
+  % edit: weighted by the vectors weight
+  e = sum(minvals .* weights);
+  if store
+    errlog(n) = e;
+  end
+  if options(1) > 0
+    fprintf(1, 'Cycle %4d  Error %11.6f\n', n, e);
+  end
+
+  if n > 1
+    % Test for termination
+    if max(max(abs(centres - old_centres))) < options(2) & ...
+        abs(old_e - e) < options(3)
+      options(8) = e;
+      return;
+    end
+  end
+  old_e = e;
+end
+
+% If we get here, then we haven't terminated in the given number of 
+% iterations.
+options(8) = e;
+if (options(1) >= 0)
+  disp(maxitmess);
+end
+