%DEMKNN1 Demonstrate nearest neighbour classifier.
%
%	Description
%	The problem consists of data in a two-dimensional space.  The data is
%	drawn from three spherical Gaussian distributions with priors 0.3,
%	0.5 and 0.2; centres (2, 3.5), (0, 0) and (0,2); and standard
%	deviations 0.2, 0.5 and 1.0. The first figure contains a scatter plot
%	of the data.  The data is the same as in DEMGMM1.
%
%	The second figure shows the data labelled with the corresponding
%	class given by the classifier.
%
%	See also
%	DEM2DDAT, DEMGMM1, KNN
%

%	Copyright (c) Ian T Nabney (1996-2001)

clc
disp('This program demonstrates the use of the K nearest neighbour algorithm.')
disp(' ')
disp('Press any key to continue.')
pause

% Generate the test data.
ndata = 250;
% Fix the state of both generators so every run of the demo produces the
% same data set and the same shuffle.
randn('state', 42);
rand('state', 42);

% data holds the sample points (one per row); c holds the mixture centres.
[data, c] = dem2ddat(ndata);

% Randomise data order
data = data(randperm(ndata), :);

clc
disp('We generate the data in two-dimensional space from a mixture of')
disp('three spherical Gaussians. The centres are shown as black crosses')
disp('in the plot.')
disp(' ')
disp('Press any key to continue.')
pause

% First figure: raw data with the centres overlaid.
fh1 = figure;
plot(data(:, 1), data(:, 2), 'o')
set(gca, 'Box', 'on')
hold on
title('Data')
hp1 = plot(c(:, 1), c(:, 2), 'k+');
% Increase size of crosses
set(hp1, 'MarkerSize', 8);
set(hp1, 'LineWidth', 2);
hold off

clc
disp('We next use the centres as training exemplars for the K nearest')
disp('neighbour algorithm.')
disp(' ')
disp('Press any key to continue.')
pause

% Use centres as training data: one training point per class, with the
% class of the i-th centre encoded as the i-th row of an identity matrix.
train_labels = [1, 0, 0; 0, 1, 0; 0, 0, 1];

% Label the test data up to kmax neighbours
kmax = 1;
net = knn(2, 3, kmax, c, train_labels);
% l is used below as the class index (1, 2 or 3) assigned to each data row.
[y, l] = knnfwd(net, data);

clc
disp('We now plot each data point coloured according to its classification.')
disp(' ')
disp('Press any key to continue.')
pause

% Second figure: data coloured by predicted class.
fh2 = figure;
hold on
colors = ['b.'; 'r.'; 'g.'];
for i = 1:3
  thisX = data(l == i, 1);
  thisY = data(l == i, 2);
  hp(i) = plot(thisX, thisY, colors(i, :));
  set(hp(i), 'MarkerSize', 12);
end
set(gca, 'Box', 'on');
% Use the named location instead of the legacy numeric position code 2
% (upper left), and pass the handles explicitly so the legend entries
% always refer to the three class plots.
legend(hp, 'Class 1', 'Class 2', 'Class 3', 'Location', 'NorthWest')

% Overlay the training points (the centres) on top of the classified data.
hp2 = plot(c(:, 1), c(:, 2), 'k+');
% Increase size of crosses
set(hp2, 'MarkerSize', 8);
set(hp2, 'LineWidth', 2);

title('Training data and data labels')
hold off

disp('The demonstration is now complete: press any key to exit.')
pause

% The user may already have closed a figure by hand; only close the ones
% that still exist so the clean-up below is always reached.
if ishandle(fh1)
  close(fh1);
end
if ishandle(fh2)
  close(fh2);
end

% Remove only the variables created by this demo rather than wiping the
% whole workspace of whoever ran it.
clear ndata data c fh1 fh2 hp hp1 hp2 train_labels kmax net y l
clear colors i thisX thisY