core/tools/machine_learning/weighted_kmeans.m @ 0:e9a9cd732c1e (tip)

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
function [centres, cweights, post, errlog, options] = weighted_kmeans(centres, data, weights, options)
%[centres, cweights, post, errlog, options] = weighted_kmeans(centres, data, weights, options)
%
% weighted_kmeans Trains a k-means cluster model on weighted input vectors
%
% Adapted from the Netlab toolbox by Daniel Wolff.
% This function takes a WEIGHTS vector containing one weight per data
% point. This can be used for training with varying discretisation
% intervals.
%
% Description
% CENTRES = weighted_kmeans(NCENTRES, DATA, WEIGHTS, OPTIONS) or
% CENTRES = weighted_kmeans(CENTRES, DATA, WEIGHTS, OPTIONS) uses the
% batch K-means algorithm to set the centres of a cluster model. The
% matrix DATA represents the data which is being clustered, with each
% row corresponding to a vector. The sum of squares error function is
% used. The point at which a local minimum is achieved is returned as
% CENTRES. The error value at that point is returned in OPTIONS(8).
%
% [CENTRES, CWEIGHTS, POST, ERRLOG] = weighted_kmeans(...) also returns
% the accumulated weight of the points assigned to each centre in
% CWEIGHTS, the cluster membership of each data point (in a one-of-N
% encoding) in POST, and a log of the error values after each cycle in
% ERRLOG. The optional parameters have the following interpretations.
%
% OPTIONS(1) is set to 1 to display error values; also logs error
% values in the return argument ERRLOG. If OPTIONS(1) is set to 0, then
% only warning messages are displayed. If OPTIONS(1) is -1, then
% nothing is displayed.
%
% OPTIONS(2) is a measure of the absolute precision required for the
% value of CENTRES at the solution. If the absolute difference between
% the values of CENTRES between two successive steps is less than
% OPTIONS(2), then this condition is satisfied.
%
% OPTIONS(3) is a measure of the precision required of the error
% function at the solution. If the absolute difference between the
% error functions between two successive steps is less than OPTIONS(3),
% then this condition is satisfied. Both this and the previous
% condition must be satisfied for termination.
%
% OPTIONS(14) is the maximum number of iterations; default 100.
%
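% Example
% An illustrative call (not from the original file; variable names and
% values are made up, and the Netlab helpers DIST2 and MAXITMESS must be
% on the path):
%
%   x = rand(200, 2);               % 200 two-dimensional data points
%   w = ones(1, 200);               % one weight per point (row vector)
%   opts = zeros(1, 18);            % Netlab-style options vector
%   opts(1) = 1;                    % print the error after each cycle
%   opts(2) = 1e-4; opts(3) = 1e-4; % termination tolerances
%   opts(14) = 50;                  % at most 50 iterations
%   [c, cw, post] = weighted_kmeans(5, x, w, opts);
%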
% See also
% GMMINIT, GMMEM
%

% Copyright (c) Ian T Nabney (1996-2001)

[ndata, data_dim] = size(data);
[ncentres, dim] = size(centres);

if dim ~= data_dim
    if dim == 1 && ncentres == 1 && centres > 1
        
        if ndata == numel(weights)
            
            % ---
            % allow for specifying the number of centres instead of an
            % initial centre matrix
            % ---
            dim = data_dim;
            ncentres = centres;
            
            options(5) = 1;
        else
            error('Number of data points does not match number of weights')
        end
        
    else
        error('Data dimension does not match dimension of centres')
    end
end

if (ncentres > ndata)
    error('More centres than data')
end

% Sort out the options
if (options(14))
    niters = options(14);
else
    niters = 100;
end

store = 0;
if (nargout > 3)
    store = 1;
    errlog = zeros(1, niters);
end

% Check if centres and posteriors need to be initialised from data
if (options(5) == 1)
    % Do the initialisation
    perm = randperm(ndata);
    perm = perm(1:ncentres);
    
    % Assign first ncentres (permuted) data points as centres
    centres = data(perm, :);
end
% Matrix to make unit vectors easy to construct
id = eye(ncentres);

% save the accumulated weight for each centre
cweights = zeros(ncentres, 1);

% Main loop of algorithm
for n = 1:niters
    
    % Save old centres to check for termination
    old_centres = centres;
    
    % Calculate posteriors based on existing centres
    d2 = dist2(data, centres);
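    % (dist2 is the standard Netlab helper; d2(i, j) holds the squared
    % Euclidean distance from data point i to centre j)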
    % Assign each point to nearest centre
    [minvals, index] = min(d2', [], 1);
    post = logical(id(index,:));
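    % (post(i, k) is true iff data point i is currently assigned to centre k)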
    
    % num_points = sum(post, 1);
    % Adjust the centres based on new posteriors
    for j = 1:ncentres
        if (sum(weights(post(:,j))) > 0)
            % ---
            % NOTE: this is edited to include the weights.
            % Instead of summing the assigned vectors directly, each
            % vector is multiplied by its weight, and the sum is divided
            % by the total weight of the cluster rather than by the
            % number of vectors assigned to it.
            % ---
            cweights(j) = sum(weights(post(:,j)));
            
            centres(j,:) = sum(diag(weights(post(:,j))) * data(post(:,j),:), 1) ...
                / cweights(j);
        end
    end
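    % ---
    % A vectorised equivalent of the loop above (illustrative sketch, not
    % part of the original code; unlike the loop, it assumes every centre
    % receives a positive total weight):
    %
    %   W = bsxfun(@times, double(post), weights(:)); % ndata x ncentres
    %   cweights = sum(W, 1)';
    %   centres = bsxfun(@rdivide, W' * data, cweights);
    % ---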
    
    % Error value is the total squared distance from the cluster centres,
    % weighted by each vector's weight
    e = sum(minvals .* weights);
    if store
        errlog(n) = e;
    end
    if options(1) > 0
        fprintf(1, 'Cycle %4d Error %11.6f\n', n, e);
    end
    
    if n > 1
        % Test for termination
        if max(max(abs(centres - old_centres))) < options(2) && ...
                abs(old_e - e) < options(3)
            options(8) = e;
            return;
        end
    end
    old_e = e;
end

% If we get here, then we haven't terminated in the given number of
% iterations.
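% (maxitmess below is the Netlab helper that returns the standard
% maximum-iterations warning string)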
options(8) = e;
if (options(1) >= 0)
    disp(maxitmess);
end