Mercurial > hg > camir-aes2014
comparison toolboxes/MIRtoolbox1.3.2/somtoolbox/som_kmeans.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:e9a9cd732c1e |
---|---|
function [codes,clusters,err] = som_kmeans(method, D, k, epochs, verbose)

% SOM_KMEANS K-means algorithm.
%
% [codes,clusters,err] = som_kmeans(method, D, k, [epochs], [verbose])
%
%  Input and output arguments ([]'s are optional):
%    method     (string) k-means algorithm type: 'batch' or 'seq'
%    D          (matrix) data matrix, one sample per row
%               (struct) data struct (field .data is used) or
%                        map struct (field .codebook is used)
%    k          (scalar) number of centroids
%    [epochs]   (scalar) number of training epochs, default 100
%    [verbose]  (scalar) if <> 0 display additional information
%
%    codes      (matrix) codebook vectors, one centroid per row
%    clusters   (vector) cluster number for each sample
%    err        (scalar) total quantization error for the data set
%
% Notes:
%  - One-dimensional data (a single column) is delegated to SCALAR_KMEANS
%    regardless of 'method'.
%  - Initial centroids are k randomly chosen samples, so results may vary
%    between calls.
%  - An unknown 'method' only prints a message to stderr; the randomly
%    chosen initial centroids are then returned with all-zero cluster
%    numbers and zero error.
%
% See also  KMEANS_CLUSTERS, SOM_MAKE, SOM_BATCHTRAIN, SOM_SEQTRAIN.

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Function has been renamed by Kimmo Raivio, because matlab65 also has a
% kmeans function 1.10.02
%% input arguments

if isstruct(D),
  switch D.type,
   case 'som_map',  data = D.codebook;
   case 'som_data', data = D.data;
   otherwise,
    % fail early with a clear message instead of an undefined-variable error
    error('som_kmeans: unsupported struct type ''%s''.', D.type);
  end
else
  data = D;
end
[l dim] = size(data);

if nargin < 4 || isempty(epochs) || any(isnan(epochs(:))), epochs = 100; end
if nargin < 5, verbose = 0; end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% action

% one-dimensional data has its own dedicated implementation
if dim==1,
  [codes,clusters,err] = scalar_kmeans(data,k,epochs);
  return;
end

if k > l,
  error('som_kmeans: k (%d) exceeds the number of samples (%d).', k, l);
end

rand('state', sum(100*clock)); % init rand generator

lr = 0.5;   % initial learning rate for sequential k-means
temp      = randperm(l);
centroids = data(temp(1:k),:);   % k distinct random samples as initial centroids
clusters  = zeros(1, l);

switch method
 case 'seq',
  len    = epochs * l;
  l_rate = linspace(lr,0,len);   % learning rate decays linearly to zero
  order  = randperm(l);
  for iter = 1:len
    % use the extracted data matrix (D itself may be a struct)
    x  = data(order(rem(iter,l)+1),:);
    dx = x(ones(k,1),:) - centroids;
    [dist nearest] = min(sum(dx.^2,2));
    centroids(nearest,:) = centroids(nearest,:) + l_rate(iter)*dx(nearest,:);
  end
  clusters = kmeans_assign(data, centroids);

 case 'batch',
  iter         = 0;
  old_clusters = zeros(1, l);
  while iter<epochs

    clusters = kmeans_assign(data, centroids);

    % move every non-empty centroid to the mean of its members
    for i = 1:k
      f = find(clusters==i);
      s = length(f);
      if s, centroids(i,:) = sum(data(f,:),1) / s; end
    end

    % converged when no sample changed its cluster since the last pass
    if iter > 0 && all(old_clusters==clusters)
      if verbose, fprintf(1, 'Convergence in %d iterations\n', iter); end
      break;
    end

    old_clusters = clusters;
    iter = iter + 1;
  end

  clusters = kmeans_assign(data, centroids);

 otherwise,
  fprintf(2, 'Unknown method\n');
end

% total quantization error: summed squared distance to the own centroid
err = 0;
for i = 1:k
  f = find(clusters==i);
  s = length(f);
  if s, err = err + sum(sum((data(f,:)-ones(s,1)*centroids(i,:)).^2,2)); end
end

codes = centroids;
return;

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

function clusters = kmeans_assign(data, centroids)

% KMEANS_ASSIGN Index of the nearest centroid (squared Euclidean) per sample.
%
%  Uses ||x||^2 + ||c||^2 - 2*x*c' to build the l-by-k distance matrix
%  without looping. Returns a 1-by-l row vector.

l  = size(data,1);
k  = size(centroids,1);
d2 = (ones(k, 1) * sum((data.^2)', 1))' + ...
     ones(l, 1) * sum((centroids.^2)',1) - ...
     2.*(data*(centroids'));
% explicit dimension: with k == 1 the transposed matrix is a row vector and
% a plain min() would collapse it to a single scalar
[dummy clusters] = min(d2', [], 1);
return;
112 | |
113 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% | |
114 | |
function [y,bm,qe] = scalar_kmeans(x,k,maxepochs)

% SCALAR_KMEANS K-means for one-dimensional data.
%
% [y,bm,qe] = scalar_kmeans(x,k,maxepochs)
%
%    x          (vector) data values; non-finite entries (NaN, +-Inf)
%                        are ignored during clustering
%    k          (scalar) number of centroids
%    maxepochs  (scalar) maximum number of iterations
%
%    y          (vector) k-by-1 centroid values in ascending order
%    bm         (vector) centroid index for each element of x, NaN for
%                        the non-finite elements
%    qe         (scalar) mean absolute distance between the finite
%                        samples and their centroids

nans = ~isfinite(x);
x(nans) = [];
n = length(x);

% nothing to cluster: report everything as missing
if n==0,
  y  = NaN(k,1);
  bm = NaN(size(nans));
  if nargout>2, qe = NaN; end
  return;
end

mi = min(x); ma = max(x);
y = linspace(mi,ma,k)';   % start with centroids evenly spread over the data range
bm = ones(n,1);
bmold = zeros(n,1);
i = 0;
while ~all(bm==bmold) && i<maxepochs,
  bmold = bm;
  % bin edges at the midpoints between neighbouring centroids; bm is the
  % bin (= nearest centroid) of every sample and c the per-bin count
  [c bm] = histc(x,[-Inf; (y(2:end)+y(1:end-1))/2; Inf]);
  y = full(sum(sparse(bm,1:n,x,k,n),2));   % per-centroid sum of its samples
  zh = (c(1:end-1)==0);                    % centroids that lost all samples
  y(~zh) = y(~zh)./c(~zh);
  inds = find(zh)';
  % keep empty centroids just above their left neighbour so that the
  % centroid sequence stays ordered
  for j=inds, if j==1, y(j) = mi; else y(j) = y(j-1) + eps; end, end
  i=i+1;
end
% the loop ran out of epochs: refresh the assignment for the final centroids
if i==maxepochs, [c bm] = histc(x,[-Inf; (y(2:end)+y(1:end-1))/2; Inf]); end
if nargout>2, qe = sum(abs(x-y(bm)))/n; end
if any(nans),
  % only bm is per-sample; y (k centroids) and qe (scalar) keep their size
  bmfull = NaN(size(nans));
  bmfull(~nans) = bm;
  bm = bmfull;
end

return;
145 |