%DEMKNN1 Demonstrate nearest neighbour classifier.
%
%	Description
%	The problem consists of data in a two-dimensional space.  The data is
%	drawn from three spherical Gaussian distributions with priors 0.3,
%	0.5 and 0.2; centres (2, 3.5), (0, 0) and (0, 2); and standard
%	deviations 0.2, 0.5 and 1.0.  The first figure contains a scatter plot
%	of the data, with the three centres marked as black crosses.  The data
%	is the same as in DEMGMM1.
%
%	The three centres are then used as the training exemplars (one per
%	class) of a K nearest neighbour classifier with K = 1.  The second
%	figure shows the data labelled with the corresponding class given by
%	the classifier.
%
%	See also
%	DEM2DDAT, DEMGMM1, KNN, KNNFWD
%

%	Copyright (c) Ian T Nabney (1996-2001)

clc
disp('This program demonstrates the use of the K nearest neighbour algorithm.')
disp(' ')
disp('Press any key to continue.')
pause

% Generate the test data.  The legacy 'state' seeding syntax is kept on
% purpose so the generated data stays reproducible and, per the header,
% identical to the data used by DEMGMM1.
ndata = 250;
randn('state', 42);
rand('state', 42);

[data, c] = dem2ddat(ndata);

% Randomise data order
data = data(randperm(ndata), :);

clc
disp('We generate the data in two-dimensional space from a mixture of')
disp('three spherical Gaussians. The centres are shown as black crosses')
disp('in the plot.')
disp(' ')
disp('Press any key to continue.')
pause

fh1 = figure;
plot(data(:, 1), data(:, 2), 'o')
set(gca, 'Box', 'on')
hold on
title('Data')
% Terminated with ';' so the line handle is not echoed to the command window.
hp1 = plot(c(:, 1), c(:, 2), 'k+');
% Increase size of crosses
set(hp1, 'MarkerSize', 8);
set(hp1, 'LineWidth', 2);
hold off

clc
disp('We next use the centres as training exemplars for the K nearest')
disp('neighbour algorithm.')
disp(' ')
disp('Press any key to continue.')
pause

% Use centres as training data: row i of train_labels has a 1 in column i,
% so row i of c is the single exemplar of class i.
train_labels = eye(3);

% Label the test data up to kmax neighbours
kmax = 1;
net = knn(2, 3, kmax, c, train_labels);
% l is used below as the class index (1, 2 or 3) of each row of data.
[y, l] = knnfwd(net, data);

clc
disp('We now plot each data point coloured according to its classification.')
disp(' ')
disp('Press any key to continue.')
pause

% Plot the result
fh2 = figure;
colors = ['b.'; 'r.'; 'g.'];
% Start from an empty handle list so a stale 'hp' left in the workspace
% cannot contribute extra handles to the legend below.
hp = [];
for i = 1:3
  thisX = data(l == i, 1);
  thisY = data(l == i, 2);
  hp(i) = plot(thisX, thisY, colors(i, :));
  set(hp(i), 'MarkerSize', 12);
  if i == 1
    hold on
  end
end
set(gca, 'Box', 'on');
hold on
hp2 = plot(c(:, 1), c(:, 2), 'k+');
% Increase size of crosses
set(hp2, 'MarkerSize', 8);
set(hp2, 'LineWidth', 2);

% Build the legend from the three class handles only, after all plotting,
% so the centre crosses never get an entry.  The old numeric location code
% 2 (upper left) is written as the equivalent named location, because
% current MATLAB no longer accepts numeric location codes.
legend(hp, 'Class 1', 'Class 2', 'Class 3', 'Location', 'NorthWest')

title('Training data and data labels')
hold off

disp('The demonstration is now complete: press any key to exit.')
pause
close(fh1);
close(fh2);
% Remove only the variables this demo created.  'clear all' would also wipe
% every other variable in the caller's workspace, plus globals and loaded
% functions.
clear ndata data c fh1 fh2 hp1 hp2 hp train_labels kmax net y l colors i thisX thisY