function [centres, cweights, post, errlog, options] = weighted_kmeans(centres, data, weights, options)
%WEIGHTED_KMEANS Trains a k-means cluster model on weighted input vectors.
%
%   [CENTRES, CWEIGHTS, POST, ERRLOG, OPTIONS] = ...
%       weighted_kmeans(CENTRES,  DATA, WEIGHTS, OPTIONS)   or
%       weighted_kmeans(NCENTRES, DATA, WEIGHTS, OPTIONS)
%
%   Adapted from the Netlab Toolbox by Daniel Wolff.
%   This function takes a WEIGHTS vector containing one weight per data
%   point. This can be used for training with varying discretisation
%   intervals.
%
%   Description
%   Uses the batch K-means algorithm to set the centres of a cluster
%   model. The matrix DATA represents the data being clustered, with each
%   row corresponding to a vector. WEIGHTS must contain exactly one entry
%   per row of DATA (row or column orientation is accepted). The weighted
%   sum-of-squares error function is used.
%
%   The first argument is either a matrix of initial CENTRES (one row per
%   centre, same number of columns as DATA) or a scalar NCENTRES > 1.
%   The scalar form is only recognised when DATA has more than one
%   column; it forces OPTIONS(5) = 1 so that the centres are initialised
%   from randomly selected rows of DATA.
%
%   Outputs
%   CENTRES   centres after the last cycle that was run.
%   CWEIGHTS  NCENTRES-by-1 vector with the summed weight of the data
%             points assigned to each centre in the last cycle.
%   POST      logical one-of-N matrix (one row per data point) giving the
%             cluster assignment of each point in the last cycle.
%   ERRLOG    1-by-NITERS log of the error value of each cycle. It is
%             filled whenever at least four outputs are requested;
%             entries for cycles that were not run remain zero.
%   OPTIONS   the options vector with the final error in OPTIONS(8).
%
%   Options
%   OPTIONS(1) > 0 displays the error value of every cycle. If OPTIONS(1)
%   is 0 only the "maximum iterations" message is displayed. If
%   OPTIONS(1) is -1 nothing is displayed.
%
%   OPTIONS(2) is a measure of the absolute precision required for the
%   value of CENTRES at the solution. If the largest absolute change of
%   CENTRES between two successive steps is less than OPTIONS(2), then
%   this condition is satisfied.
%
%   OPTIONS(3) is a measure of the precision required of the error
%   function at the solution. If the absolute difference between the
%   error values of two successive steps is less than OPTIONS(3), then
%   this condition is satisfied. Both this and the previous condition
%   must be satisfied for termination.
%
%   OPTIONS(5) set to 1 initialises the centres from randomly selected
%   data points.
%
%   OPTIONS(14) is the maximum number of iterations; default 100.
%
%   See also
%   GMMINIT, GMMEM
%

%   Copyright (c) Ian T Nabney (1996-2001)

[ndata, data_dim] = size(data);
[ncentres, dim] = size(centres);

if dim ~= data_dim
    if dim == 1 && ncentres == 1 && centres > 1
        % ---
        % allow for number of centres specification: a scalar first
        % argument is the number of centres, which are then drawn from
        % the data
        % ---
        ncentres = centres;
        options(5) = 1;
    else
        error('Data dimension does not match dimension of centres')
    end
end

% Every data point needs exactly one weight, whichever way the centres
% were specified.
if ndata ~= numel(weights)
    error('Data dimension does not match number of weights')
end

% Work with a column vector so that the results do not depend on the
% orientation in which the caller supplied the weights.
weights = weights(:);

if (ncentres > ndata)
    error('More centres than data')
end

% Sort out the options
if (options(14))
    niters = options(14);
else
    niters = 100;
end

% Only keep an error log if the caller asked for it
store = (nargout > 3);
if store
    errlog = zeros(1, niters);
end

% Check if centres and posteriors need to be initialised from data
if (options(5) == 1)
    % Assign ncentres randomly chosen (distinct) data points as centres
    perm = randperm(ndata);
    centres = data(perm(1:ncentres), :);
end

% Matrix to make unit vectors easy to construct
id = eye(ncentres);

% Accumulated weight of the points assigned to each centre
cweights = zeros(ncentres, 1);

% Main loop of algorithm
for n = 1:niters

    % Save old centres to check for termination
    old_centres = centres;

    % Squared distance of every data point to every centre
    d2 = dist2(data, centres);

    % Assign each point to its nearest centre; taking the minimum along
    % dimension 2 yields ndata-by-1 columns that line up with WEIGHTS.
    [minvals, index] = min(d2, [], 2);
    post = logical(id(index, :));

    % Adjust the centres based on the new assignments
    for j = 1:ncentres
        members = post(:, j);
        wj = weights(members);

        % Refresh on every cycle so that a cluster which lost all of its
        % points does not keep reporting a weight from an earlier cycle.
        cweights(j) = sum(wj);

        if cweights(j) > 0
            % ---
            % Weighted mean: the vectors are weighted, summed and divided
            % by the sum of the weights instead of the number of vectors.
            % wj' * X equals sum(diag(wj) * X, 1) without building the
            % dense diagonal matrix.
            % ---
            centres(j, :) = (wj' * data(members, :)) / cweights(j);
        end
    end

    % Error value is the total squared distance from the cluster centres
    % used for the assignment, each term weighted by its vector's weight
    e = sum(minvals .* weights);
    if store
        errlog(n) = e;
    end
    if options(1) > 0
        fprintf(1, 'Cycle %4d  Error %11.6f\n', n, e);
    end

    if n > 1
        % Test for termination: both the centres and the error must have
        % settled
        if max(max(abs(centres - old_centres))) < options(2) && ...
                abs(old_e - e) < options(3)
            options(8) = e;
            return;
        end
    end
    old_e = e;
end

% If we get here, then we haven't terminated in the given number of
% iterations.
options(8) = e;
if (options(1) >= 0)
    disp(maxitmess);
end