comparison toolboxes/FullBNT-1.0.7/netlab3.3/demglm2.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:e9a9cd732c1e
%DEMGLM2 Demonstrate simple classification using a generalized linear model.
%
%	Description
%	 The problem consists of a two dimensional input matrix DATA and a
%	vector of classifications T.  The data is generated from three
%	Gaussian clusters, and a generalized linear model with softmax output
%	is trained using iterative reweighted least squares. A plot of the
%	data together with regions shaded by the classification given by the
%	network is generated.
%
%	This is a script: it runs in the caller's workspace, opens one figure,
%	pauses for a key press between stages, and on exit removes only the
%	variables it created.
%
%	See also
%	DEMGLM1, GLM, GLMTRAIN
%

%	Copyright (c) Ian T Nabney (1996-2001)


% Generate data from three classes in 2d
input_dim = 2;

% Fix seeds for reproducible results
randn('state', 42);
rand('state', 42);

ndata = 100;
% Generate mixture of three Gaussians in two dimensional space
mix = gmm(2, 3, 'spherical');
mix.priors = [0.4 0.3 0.3];             % Cluster priors
mix.centres = [2, 2; 0.0, 0.0; 1, -1];  % Cluster centres (one row per cluster)
mix.covars = [0.5 1.0 0.6];             % One value per cluster (spherical model)

[data, label] = gmmsamp(mix, ndata);
% Convert integer labels (1..3) to 1-of-3 target rows by indexing the
% identity matrix.
id = eye(3);
targets = id(label,:);

% Introduce the demo, then plot the generated data

clc
disp('This demonstration illustrates the use of a generalized linear model')
disp('to classify data from three classes in a two-dimensional space.  We')
disp('begin by generating and plotting the data.')
disp(' ')
disp('Press any key to continue.')
pause

fh1 = figure;
plot(data(label==1,1), data(label==1,2), 'bo');
hold on   % keep the data visible while the other classes and regions are added
axis([-4 5 -4 5]);
set(gca, 'Box', 'on')
plot(data(label==2,1), data(label==2,2), 'rx')
plot(data(label==3, 1), data(label==3, 2), 'go')
title('Data')

clc
disp('Now we fit a model consisting of a softmax function of')
disp('a linear combination of the input variables.')
disp(' ')
disp('The model is trained using the IRLS algorithm for up to 10 iterations')
disp(' ')
disp('Press any key to continue.')
pause

net = glm(input_dim, size(targets, 2), 'softmax');
options = foptions;
options(1) = 1;     % NOTE(review): presumably turns on progress display - confirm in GLMTRAIN help
options(14) = 10;   % Iteration limit announced in the message above
net = glmtrain(net, options, data, targets);

disp(' ')
disp('We now plot the decision regions given by this model.')
disp(' ')
disp('Press any key to continue.')
pause

% Evaluate the trained model on a regular grid covering the plot axes.
x = -4.0:0.2:5.0;
y = -4.0:0.2:5.0;
[X, Y] = meshgrid(x,y);
X = X(:);
Y = Y(:);
% Named gridpts/predclass (not grid/class) so the MATLAB built-in functions
% GRID and CLASS are not shadowed in the workspace.
gridpts = [X Y];
Z = glmfwd(net, gridpts);
% Index of the largest output in each row = predicted class of that grid
% point.  Taking the maximum along dimension 2 avoids a transpose round trip
% and still yields one index per row when Z has a single row.
[unused, predclass] = max(Z, [], 2);
colors = ['b.'; 'r.'; 'g.'];   % Same colour per class as the data markers
for i = 1:3
  thisX = X(predclass == i);
  thisY = Y(predclass == i);
  h = plot(thisX, thisY, colors(i,:));
  set(h, 'MarkerSize', 8);
end
title('Plot of Decision regions')

hold off

clc
disp('Note that the boundaries of decision regions are straight lines.')
disp(' ')
disp('Press any key to end.')
pause
close(fh1);
% Remove only the variables created by this demo.  CLEAR ALL would also wipe
% the caller's own variables, globals, cached functions and breakpoints.
clear input_dim ndata mix data label id targets fh1 net options
clear x y X Y gridpts Z unused predclass colors i thisX thisY h