function [centres, cweights, post, errlog, options] = weighted_kmeans(centres, data, weights, options)
%[centres, cweights, post, errlog, options] = weighted_kmeans(centres, data, weights, options)
%
% weighted_kmeans Trains a k-means cluster model on weighted input vectors
%
% Adapted from the Netlab Toolbox by Daniel Wolff.
% This function takes a WEIGHTS vector, containing weights for the
% different data points. This can be used for training with varying
% discretisation intervals.
%
% Description
% CENTRES = weighted_kmeans(NCENTRES, DATA, WEIGHTS, OPTIONS) or
% CENTRES = weighted_kmeans(CENTRES, DATA, WEIGHTS, OPTIONS) uses the batch K-means
% algorithm to set the centres of a cluster model. The matrix DATA
% represents the data which is being clustered, with each row
% corresponding to a vector. The sum of squares error function, with
% each data point's squared distance multiplied by its weight, is used.
% The point at which a local minimum is achieved is returned as
% CENTRES. The error value at that point is returned in OPTIONS(8).
%
% CWEIGHTS returns the accumulated weight of the data points assigned
% to each centre.
%
% POST returns the cluster number (in a one-of-N encoding) for each
% data point, and ERRLOG returns a log of the error values after each
% cycle. The optional parameters have the following interpretations.
%
% OPTIONS(1) is set to 1 to display error values; also logs error
% values in the return argument ERRLOG. If OPTIONS(1) is set to 0, then
% only warning messages are displayed. If OPTIONS(1) is -1, then
% nothing is displayed.
%
% OPTIONS(2) is a measure of the absolute precision required for the
% value of CENTRES at the solution.
% If the absolute difference between the values of CENTRES between two
% successive steps is less than OPTIONS(2), then this condition is
% satisfied.
%
% OPTIONS(3) is a measure of the precision required of the error
% function at the solution. If the absolute difference between the
% error functions between two successive steps is less than OPTIONS(3),
% then this condition is satisfied. Both this and the previous
% condition must be satisfied for termination.
%
% OPTIONS(14) is the maximum number of iterations; default 100.
%
% See also
% GMMINIT, GMMEM
%

% Copyright (c) Ian T Nabney (1996-2001)

[ndata, data_dim] = size(data);
[ncentres, dim] = size(centres);

if dim ~= data_dim
    if dim == 1 && ncentres == 1 && centres > 1

        if ndata == numel(weights)

            % ---
            % allow for number of centres specification
            % ---
            dim = data_dim;
            ncentres = centres;

            options(5) = 1;
        else
            error('Number of data points does not match number of weights')
        end

    else
        error('Data dimension does not match dimension of centres')
    end
end

if (ncentres > ndata)
    error('More centres than data')
end

% Sort out the options
if (options(14))
    niters = options(14);
else
    niters = 100;
end

store = 0;
if (nargout > 3)
    store = 1;
    errlog = zeros(1, niters);
end

% Check if centres and posteriors need to be initialised from data
if (options(5) == 1)
    % Do the initialisation
    perm = randperm(ndata);
    perm = perm(1:ncentres);

    % Assign first ncentres (permuted) data points as centres
    centres = data(perm, :);
end
% Matrix to make unit vectors easy to construct
id = eye(ncentres);

% Accumulated weight of the points assigned to each centre
cweights = zeros(ncentres, 1);

% Main loop of algorithm
for n = 1:niters

    % Save old centres to check for termination
    old_centres = centres;

    % Calculate posteriors based on existing centres
    d2 = dist2(data, centres);
    % Assign each point to nearest centre
    [minvals, index] = min(d2', [], 1);
    post = logical(id(index,:));

    % num_points = sum(post, 1);
    % Adjust the centres based on new posteriors
    for j = 1:ncentres
        if (sum(weights(post(:,j))) > 0)
            % ---
            % NOTE: this is edited to include the weights.
            % Instead of summing the vectors directly, the vectors are weighted
            % and then the result is divided by the sum of the weights instead
            % of the number of vectors for this class
            % ---
            cweights(j) = sum(weights(post(:,j)));

            centres(j,:) = sum(diag(weights(post(:,j))) * data(post(:,j),:), 1)...
                /cweights(j);
        end
    end

    % Error value is total squared distance from cluster centres,
    % weighted by each vector's weight. The (:) forces both operands to
    % column vectors so that WEIGHTS may be a row or a column vector.
    e = sum(minvals(:) .* weights(:));
    if store
        errlog(n) = e;
    end
    if options(1) > 0
        fprintf(1, 'Cycle %4d  Error %11.6f\n', n, e);
    end

    if n > 1
        % Test for termination
        if max(max(abs(centres - old_centres))) < options(2) && ...
                abs(old_e - e) < options(3)
            options(8) = e;
            return;
        end
    end
    old_e = e;
end

% If we get here, then we haven't terminated in the given number of
% iterations.
options(8) = e;
if (options(1) >= 0)
    disp(maxitmess);
end
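
% ---
% Usage sketch (an illustrative addition, not part of the original
% Netlab-derived code). It assumes the Netlab functions dist2 and
% maxitmess are on the path and that WEIGHTS is a row vector with one
% non-negative weight per row of DATA. The variable names and the
% tolerance values below are hypothetical choices for illustration only.
%
%   data    = [randn(50, 2); randn(50, 2) + 4];  % 100 points in 2-D
%   weights = ones(1, size(data, 1));            % equal weights here
%   options = zeros(1, 14);                      % 14 is the highest index read
%   options(1)  = 1;                             % display error each cycle
%   options(2)  = 1e-4;                          % tolerance on centre movement
%   options(3)  = 1e-4;                          % tolerance on error change
%   options(14) = 50;                            % at most 50 iterations
%
%   % Passing a scalar as the first argument requests that many centres,
%   % which are then initialised from randomly chosen data points.
%   [centres, cweights, post, errlog, options] = ...
%       weighted_kmeans(2, data, weights, options);
%
%   final_error = options(8);                    % error at returned centres
%
% Note that leaving OPTIONS(2) and OPTIONS(3) at zero means the
% termination test can never pass, so the loop runs for the full
% number of iterations.
% ---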