Mercurial > hg > camir-aes2014
diff toolboxes/MIRtoolbox1.3.2/somtoolbox/som_kmeans.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/toolboxes/MIRtoolbox1.3.2/somtoolbox/som_kmeans.m Tue Feb 10 15:05:51 2015 +0000 @@ -0,0 +1,145 @@ +function [codes,clusters,err] = som_kmeans(method, D, k, epochs, verbose) + +% SOM_KMEANS K-means algorithm. +% +% [codes,clusters,err] = som_kmeans(method, D, k, [epochs], [verbose]) +% +% Input and output arguments ([]'s are optional): +% method (string) k-means algorithm type: 'batch' or 'seq' +% D (matrix) data matrix +% (struct) data or map struct +% k (scalar) number of centroids +% [epochs] (scalar) number of training epochs +% [verbose] (scalar) if <> 0 display additonal information +% +% codes (matrix) codebook vectors +% clusters (vector) cluster number for each sample +% err (scalar) total quantization error for the data set +% +% See also KMEANS_CLUSTERS, SOM_MAKE, SOM_BATCHTRAIN, SOM_SEQTRAIN. + +%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% +% Function has been renamed by Kimmo Raivio, because matlab65 also have +% kmeans function 1.10.02 +%% input arguments + +if isstruct(D), + switch D.type, + case 'som_map', data = D.codebook; + case 'som_data', data = D.data; + end +else + data = D; +end +[l dim] = size(data); + +if nargin < 4 | isempty(epochs) | isnan(epochs), epochs = 100; end +if nargin < 5, verbose = 0; end + +%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% +%% action + +rand('state', sum(100*clock)); % init rand generator + +lr = 0.5; % learning rate for sequential k-means +temp = randperm(l); +centroids = data(temp(1:k),:); +res = zeros(k,l); +clusters = zeros(1, l); + +if dim==1, + [codes,clusters,err] = scalar_kmeans(data,k,epochs); + return; +end + +switch method + case 'seq', + len = epochs * l; + l_rate = linspace(lr,0,len); + order = randperm(l); + for iter = 1:len + x = D(order(rem(iter,l)+1),:); + dx = x(ones(k,1),:) - centroids; + [dist nearest] = min(sum(dx.^2,2)); + centroids(nearest,:) = centroids(nearest,:) + l_rate(iter)*dx(nearest,:); + end + [dummy clusters] = min(((ones(k, 1) * sum((data.^2)', 1))' + ... + ones(l, 1) * sum((centroids.^2)',1) - ... + 2.*(data*(centroids')))'); + + case 'batch', + iter = 0; + old_clusters = zeros(k, 1); + while iter<epochs + + [dummy clusters] = min(((ones(k, 1) * sum((data.^2)', 1))' + ... + ones(l, 1) * sum((centroids.^2)',1) - ... + 2.*(data*(centroids')))'); + + for i = 1:k + f = find(clusters==i); + s = length(f); + if s, centroids(i,:) = sum(data(f,:)) / s; end + end + + if iter + if sum(old_clusters==clusters)==0 + if verbose, fprintf(1, 'Convergence in %d iterations\n', iter); end + break; + end + end + + old_clusters = clusters; + iter = iter + 1; + end + + [dummy clusters] = min(((ones(k, 1) * sum((data.^2)', 1))' + ... + ones(l, 1) * sum((centroids.^2)',1) - ... + 2.*(data*(centroids')))'); + otherwise, + fprintf(2, 'Unknown method\n'); +end + +err = 0; +for i = 1:k + f = find(clusters==i); + s = length(f); + if s, err = err + sum(sum((data(f,:)-ones(s,1)*centroids(i,:)).^2,2)); end +end + +codes = centroids; +return; + +%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% + +function [y,bm,qe] = scalar_kmeans(x,k,maxepochs) + + nans = ~isfinite(x); + x(nans) = []; + n = length(x); + mi = min(x); ma = max(x) + y = linspace(mi,ma,k)'; + bm = ones(n,1); + bmold = zeros(n,1); + i = 0; + while ~all(bm==bmold) & i<maxepochs, + bmold = bm; + [c bm] = histc(x,[-Inf; (y(2:end)+y(1:end-1))/2; Inf]); + y = full(sum(sparse(bm,1:n,x,k,n),2)); + zh = (c(1:end-1)==0); + y(~zh) = y(~zh)./c(~zh); + inds = find(zh)'; + for j=inds, if j==1, y(j) = mi; else y(j) = y(j-1) + eps; end, end + i=i+1; + end + if i==maxepochs, [c bm] = histc(x,[-Inf; (y(2:end)+y(1:end-1))/2; Inf]); end + if nargout>2, qe = sum(abs(x-y(bm)))/n; end + if any(nans), + notnan = find(~nans); n = length(nans); + y = full(sparse(notnan,1,y ,n,1)); y(nans) = NaN; + bm = full(sparse(notnan,1,bm,n,1)); bm(nans) = NaN; + if nargout>2, qe = full(sparse(notnan,1,qe,n,1)); qe(nans) = NaN; end + end + + return; +