toolboxes/FullBNT-1.0.7/netlab3.3/knnfwd.m @ 0:e9a9cd732c1e (tip)

first hg version after svn

author:   wolffd
date:     Tue, 10 Feb 2015 15:05:51 +0000
parents:  (none)
children: (none)
comparison: -1:000000000000 → 0:e9a9cd732c1e
function [y, l] = knnfwd(net, x)
%KNNFWD Forward propagation through a K-nearest-neighbour classifier.
%
%	Description
%	[Y, L] = KNNFWD(NET, X) takes a matrix X of input vectors (one vector
%	per row) and uses the K-nearest-neighbour rule on the training data
%	contained in NET to produce a matrix Y of outputs and a matrix L of
%	classification labels. The nearest neighbours are determined using
%	Euclidean distance. The IJth entry of Y counts the number of
%	occurrences that an example from class J is among the K closest
%	training examples to example I from X. The matrix L contains the
%	predicted class labels as an index 1..N, not as 1-of-N coding.
%
%	See also
%	KMEANS, KNN
%

%	Copyright (c) Ian T Nabney (1996-2001)

errstring = consist(net, 'knn', x);
if ~isempty(errstring)
  error(errstring);
end

ntest = size(x, 1);                   % Number of input vectors.
nclass = size(net.tr_targets, 2);     % Number of classes.

% Compute matrix of squared distances between input vectors from the training
% and test sets. The matrix distsq has dimensions (ntrain, ntest).
distsq = dist2(net.tr_in, x);

% Now sort the distances. This generates a matrix kind of the same
% dimensions as distsq, in which each column gives the indices of the
% elements in the corresponding column of distsq in ascending order.
[vals, kind] = sort(distsq);
y = zeros(ntest, nclass);

for k = 1:net.k
  % We now look at the predictions made by the kth nearest neighbours alone,
  % represented as a 1-of-N coded matrix, and accumulate the predictions
  % so far.
  y = y + net.tr_targets(kind(k,:), :);
end

if nargout == 2
  % Convert this set of outputs to labels, randomly breaking ties.
  [temp, l] = max((y + 0.1*rand(size(y))), [], 2);
end
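For readers without Netlab to hand, the voting scheme above can be sketched in NumPy. This is an illustrative translation, not part of the toolbox: `dist2` is replaced by an explicit broadcast of squared Euclidean distances, labels come back 0-based rather than MATLAB's 1-based indices, and the function and parameter names are this sketch's own.

```python
import numpy as np

def knnfwd(tr_in, tr_targets, k, x, rng=None):
    """K-nearest-neighbour forward propagation (NumPy sketch).

    tr_in:      (ntrain, nin) training inputs
    tr_targets: (ntrain, nclass) 1-of-N coded training labels
    k:          number of neighbours to vote
    x:          (ntest, nin) test inputs
    Returns y, an (ntest, nclass) matrix of vote counts, and l, the
    (ntest,) predicted class indices (0-based here).
    """
    rng = np.random.default_rng(0) if rng is None else rng
    # Squared Euclidean distances, shape (ntrain, ntest) -- the
    # counterpart of Netlab's dist2(net.tr_in, x).
    d2 = ((tr_in[:, None, :] - x[None, :, :]) ** 2).sum(axis=2)
    # Sort each column: row i of `kind` holds, per test point, the index
    # of its (i+1)-th nearest training example.
    kind = np.argsort(d2, axis=0)
    # Accumulate 1-of-N votes from the k nearest neighbours.
    y = np.zeros((x.shape[0], tr_targets.shape[1]))
    for i in range(k):
        y += tr_targets[kind[i, :], :]
    # Break ties randomly, as the MATLAB code does with 0.1*rand:
    # the noise is smaller than one vote, so it only decides ties.
    l = np.argmax(y + 0.1 * rng.random(y.shape), axis=1)
    return y, l
```

Note the tie-breaking trick: because the added noise is strictly less than 1, it can never overturn a genuine vote margin, only break exact ties at random.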