Mercurial > hg > camir-aes2014
diff core/tools/machine_learning/weighted_kmeans.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
line wrap: on
line diff
function [centres, cweights, post, errlog, options] = weighted_kmeans(centres, data, weights, options)
%WEIGHTED_KMEANS Trains a k-means cluster model on weighted input vectors.
%
%   [CENTRES, CWEIGHTS, POST, ERRLOG, OPTIONS] = ...
%       WEIGHTED_KMEANS(CENTRES, DATA, WEIGHTS, OPTIONS)
%
%   Adapted from the Netlab Toolbox by Daniel Wolff. This function takes a
%   WEIGHTS vector containing one weight per data point, which can be used
%   e.g. for training with varying discretisation intervals.
%
%   Description
%   CENTRES = WEIGHTED_KMEANS(NCENTRES, DATA, WEIGHTS, OPTIONS) or
%   CENTRES = WEIGHTED_KMEANS(CENTRES, DATA, WEIGHTS, OPTIONS) uses the
%   batch K-means algorithm to set the centres of a cluster model.  The
%   matrix DATA represents the data being clustered, with each row
%   corresponding to one vector.  The (weighted) sum-of-squares error
%   function is used.  The point at which a local minimum is achieved is
%   returned as CENTRES.  The error value at that point is returned in
%   OPTIONS(8).
%
%   Outputs:
%     CENTRES  - final cluster centres (ncentres x dim)
%     CWEIGHTS - accumulated weight assigned to each centre (ncentres x 1)
%     POST     - cluster membership for each data point, one-of-N encoded
%                (ndata x ncentres logical)
%     ERRLOG   - log of the error value after each cycle (only filled in
%                when four or more outputs are requested; trailing entries
%                stay zero if the algorithm terminates early)
%
%   The optional parameters have the following interpretations:
%     OPTIONS(1)  - set to 1 to display error values (also logs errors in
%                   ERRLOG); 0 shows only warnings; -1 shows nothing.
%     OPTIONS(2)  - absolute precision required of CENTRES: satisfied when
%                   no centre coordinate moves more than this in a step.
%     OPTIONS(3)  - precision required of the error function: satisfied
%                   when successive error values differ by less than this.
%                   Both this and the previous condition must hold for
%                   termination.
%     OPTIONS(5)  - set to 1 to (re)initialise centres from random data
%                   points (forced internally when NCENTRES is given as a
%                   scalar count instead of a centre matrix).
%     OPTIONS(14) - maximum number of iterations; default 100.
%
%   See also
%   GMMINIT, GMMEM

%   Copyright (c) Ian T Nabney (1996-2001)

[ndata, data_dim] = size(data);
[ncentres, dim] = size(centres);

if dim ~= data_dim
    if dim == 1 && ncentres == 1 && centres > 1
        % First argument was a scalar count of centres, not a centre
        % matrix: accept it and force random initialisation below.
        if ndata == numel(weights)
            dim = data_dim;
            ncentres = centres;
            options(5) = 1;
        else
            error('Data dimension does not match number of weights')
        end
    else
        error('Data dimension does not match dimension of centres')
    end
end

if (ncentres > ndata)
    error('More centres than data')
end

% Sort out the options
if (options(14))
    niters = options(14);
else
    niters = 100;
end

% Only keep an error log when the caller asked for it.
store = 0;
if (nargout > 3)
    store = 1;
    errlog = zeros(1, niters);
end

% Check if centres need to be initialised from data
if (options(5) == 1)
    % Assign ncentres randomly permuted data points as initial centres.
    perm = randperm(ndata);
    perm = perm(1:ncentres);
    centres = data(perm, :);
end

% Matrix to make unit vectors easy to construct
id = eye(ncentres);

% Accumulated weight for each centre (zero for empty clusters).
cweights = zeros(ncentres, 1);

% Main loop of algorithm
for n = 1:niters

    % Save old centres to check for termination
    old_centres = centres;

    % Calculate posteriors based on existing centres
    d2 = dist2(data, centres);
    % Assign each point to its nearest centre
    [minvals, index] = min(d2', [], 1);
    post = logical(id(index, :));

    % Adjust the centres based on the new posteriors.
    % NOTE: edited relative to Netlab to include the weights.  Instead of
    % summing the member vectors directly, each vector is scaled by its
    % weight and the result is divided by the total weight of the cluster
    % rather than by the number of member vectors.  Clusters with zero
    % total weight keep their previous centre.
    for j = 1:ncentres
        w = weights(post(:, j));
        w = w(:);                  % force column, independent of input orientation
        cw = sum(w);
        if (cw > 0)
            cweights(j) = cw;
            % w' * X is the weighted column sum: equivalent to
            % sum(diag(w) * X, 1) without forming the m-by-m diagonal matrix.
            centres(j, :) = (w' * data(post(:, j), :)) / cw;
        end
    end

    % Error value is the total squared distance from cluster centres,
    % weighted by each vector's weight.
    e = sum(minvals .* weights);
    if store
        errlog(n) = e;
    end
    if options(1) > 0
        fprintf(1, 'Cycle %4d Error %11.6f\n', n, e);
    end

    if n > 1
        % Terminate when both the centre movement and the error change
        % fall below their respective tolerances.
        if max(max(abs(centres - old_centres))) < options(2) && ...
                abs(old_e - e) < options(3)
            options(8) = e;
            return;
        end
    end
    old_e = e;
end

% If we get here, then we haven't terminated in the given number of
% iterations.
options(8) = e;
if (options(1) >= 0)
    disp(maxitmess);
end