Mercurial > hg > camir-aes2014
comparison toolboxes/MIRtoolbox1.3.2/somtoolbox/knn.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:e9a9cd732c1e |
---|---|
function [C,P]=knn(d, Cp, K)

%KNN K-Nearest Neighbor classifier using an arbitrary distance matrix
%
%  [C,P]=knn(d, Cp, [K])
%
%  Input and output arguments ([]'s are optional): 
%   d     (matrix) of size NxP: This is a precalculated dissimilarity (distance matrix).
%           P is the number of prototype vectors and N is the number of data vectors.
%           That is, d(i,j) is the distance between data item i and prototype j.
%   Cp    (vector) of size Px1 that contains integer class labels. Cp(j) is the class of 
%           jth prototype.
%   [K]   (scalar) the maximum K in K-NN classifier, default is 1.
%           Must be a positive integer not larger than P.
%   C     (matrix) of size NxK: integers indicating the class 
%           decision for data items according to the K-NN rule for each K.
%           C(i,K) is the classification for data item i using the K-NN rule
%   P     (matrix) of size NxkxK, where k is the number of distinct classes
%           in Cp: the relative amount of prototypes of each class among
%           the K closest prototypes for each classifiee.
%           That is, P(i,j,K) is the relative amount of prototypes of class j
%           among K nearest prototypes for data item i. Classes are indexed
%           in the order returned by UNIQUE(Cp).
%
% If there is a tie between representatives of two or more classes
% among the K closest neighbors to the classifiee, the class is selected randomly 
% among these candidates.
%
% IMPORTANT  If K>1 this function uses 'sort' which is considerably slower than 
%            'min' which is used for K=1. If K>1 the knn always calculates 
%            results for all K-NN models from 1-NN up to K-NN.   
%
% EXAMPLE 1 
%
% sP;                           % a SOM Toolbox data struct containing labeled prototype vectors
% [Cp,label]=som_label2num(sP); % get integer class labels for prototype vectors                 
% sD;                           % a SOM Toolbox data struct containing vectors to be classified
% d=som_eucdist2(sD,sP);        % calculate euclidean distance matrix
% class=knn(d,Cp,10);           % classify using 1,2,...,10-rules
% class(:,5);                   % includes results for 5NN 
% label(class(:,5))             % original class labels for 5NN
%
% EXAMPLE 2 (leave-one-out-crossvalidate KNN for selection of proper K)
%
% P;                          % a data matrix of prototype vectors (rows)
% Cp;                         % column vector of integer class labels for vectors in P 
% d=som_eucdist2(P,P);        % calculate euclidean distance matrix PxP
% d(eye(size(d))==1)=NaN;     % set self-dissimilarity to NaN:
%                             % this drops the prototype itself away from its neighborhood 
%                             % leave-one-out-crossvalidation (LOOCV)
% class=knn(d,Cp,size(P,1));  % classify using all possible K
%                             % calculate and plot LOOC-validated errors for all K
% failratep = ...
%  100*sum((class~=repmat(Cp,1,size(P,1))))./size(P,1); plot(1:size(P,1),failratep)
%
% See also SOM_LABEL2NUM, SOM_EUCDIST2, PDIST.
%

% Contributed to SOM Toolbox 2.0, October 29th, 2000 by Johan Himberg
% Copyright (c) by Johan Himberg
% http://www.cis.hut.fi/projects/somtoolbox/

% Version 2.0beta Johan 291000

%% Init %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

% Check K
if nargin<3 || isempty(K),
  K=1;
end

if ~(isnumeric(K) && isscalar(K)),
  error('Value for K must be a scalar');
end

% Check that dist is a matrix
if ~(isnumeric(d) && ndims(d)==2),
  error('Distance matrix not valid.')
end

[N_data N_proto]=size(d);

% K is used as a neighbour count and as an index range (1:K), so it must
% be a positive integer and cannot exceed the number of prototypes.
% (K~=fix(K) is also true for NaN.)
if K<1 || K~=fix(K),
  error('Value for K must be a positive integer');
end
if K>N_proto,
  error('Value for K must not exceed the number of prototypes (columns of d).');
end

% Check class label vector: must be numerical, N_proto x 1, and of integers
if ~(isnumeric(Cp) && isequal(size(Cp),[N_proto 1])),
  error(['Class vector is invalid: has to be a N-of-prototypes x 1' ...
         ' vector of integers']);
elseif any(fix(Cp)~=Cp)
  % Elementwise test: summing fix(Cp)-Cp would let fractional parts of
  % opposite sign (e.g. 1.5 and -0.5) cancel out and pass unnoticed.
  error('Class labels in vector ''Cp'' must be integers.');
end

% Find all class labels
ClassIndex=unique(Cp);
N_class=length(ClassIndex); % number of different classes

%%%% Classification %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

if K==1, % sort distances only if K>1

  % 1NN
  % Select the closest prototype (min skips NaN distances, which is what
  % EXAMPLE 2 relies on)
  [dummy,proto_index]=min(d,[],2);
  C=Cp(proto_index);

else

  % Sort the prototypes for each classifiee according to distance.
  % Column i of d' holds the distances of data item i; sort along dim 1
  % explicitly. NaN distances are placed last in an ascending sort.
  [dummy,proto_index]=sort(d',1);

  %% Select up to K closest prototypes
  % reshape keeps knn_class K x N_data even when N_data==1
  proto_index=proto_index(1:K,:);
  knn_class=reshape(Cp(proto_index),K,N_data);

  % classcounter(k,i,j): number of prototypes of class ClassIndex(j)
  % among the k nearest prototypes of data item i
  classcounter=zeros(K,N_data,N_class);
  for j=1:N_class,
    classcounter(:,:,j)=cumsum(knn_class==ClassIndex(j),1);
  end

  %% Vote between classes of K neighbors
  [winner,vote_index]=max(classcounter,[],3);

  %%% Handle ties

  % Set index to classes that got as many votes as the winner
  equal_to_winner=(repmat(winner,[1 1 N_class])==classcounter);

  % (k,i) positions where more than one class shares the top vote
  [tie_indexi,tie_indexj]=find(sum(equal_to_winner,3)>1);

  % Go through tie cases and reset vote_index randomly to one
  % of them
  for t=1:length(tie_indexi),
    tie_class_index=find(squeeze(equal_to_winner(tie_indexi(t),tie_indexj(t),:)));
    fortuna=randperm(length(tie_class_index));
    vote_index(tie_indexi(t),tie_indexj(t))=tie_class_index(fortuna(1));
  end

  % vote_index is K x N_data; reshape guards the N_data==1 case before
  % transposing to the documented N x K layout
  C=reshape(ClassIndex(vote_index),K,N_data)';
end

%% Build output %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

% Relative amount of classes in K neighbors for each classifiee.
% Only built when the caller actually requests the second output.

if nargout>1,
  if K==1,
    % 1NN: indicator of the winning class for each data item
    P=zeros(N_data,N_class);
    [dummy,class_col]=ismember(C,ClassIndex);
    P(sub2ind(size(P),(1:N_data)',class_col(:)))=1;
  else
    % permute (unlike shiftdim) keeps the class dimension when N_class==1,
    % giving N_data x N_class x K; divide the k-th slice by k
    P=permute(classcounter,[2 3 1])./repmat(reshape(1:K,[1 1 K]),[N_data N_class 1]);
  end
end
158 |