Mercurial > hg > camir-aes2014
comparison toolboxes/FullBNT-1.0.7/bnt/examples/static/softmax1.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:e9a9cd732c1e |
---|---|
% Check that softmax works with a simple classification demo.
% Based on netlab's demglm2.
% Model: X -> Q, where X is a continuous input node and Q is a
% discrete node with a softmax CPD.

rand('state', 0);
randn('state', 0);

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Part 1: check inference.
% The BNT posterior over Q given X should match netlab's glm forward pass.

input_dim = 2;
num_classes = 3;
IRLS_iter = 3;   % iterations of iteratively-reweighted least squares

% Reference model: netlab generalized linear model with softmax outputs.
net = glm(input_dim, num_classes, 'softmax');

% Two-node DAG: node 1 (continuous, observed input) -> node 2 (discrete softmax).
dag = zeros(2);
dag(1,2) = 1;
discrete_nodes = [2];
bnet = mk_bnet(dag, [input_dim num_classes], 'discrete', discrete_nodes, 'observed', 1);
bnet.CPD{1} = root_CPD(bnet, 1);
clamped = 0;   % allow the softmax parameters to be learned later
% Initialise the softmax CPD with the same weights/bias as the netlab model.
bnet.CPD{2} = softmax_CPD(bnet, 2, net.w1, net.b1, clamped, IRLS_iter);

engine = jtree_inf_engine(bnet);

x = rand(1, input_dim);
q = glmfwd(net, x);   % netlab class posteriors for input x

[engine, ll] = enter_evidence(engine, {x, []});
m = marginal_nodes(engine, 2);
assert(approxeq(m.T(:), q(:)));


%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Part 2: check learning.
% We use EM, but in fact there is no hidden data, so EM reduces to a
% single M step, which calls IRLS on the softmax node.

% Generate data from three classes in 2d.
input_dim = 2;
num_classes = 3;

% Fix seeds for reproducible results.
randn('state', 42);
rand('state', 42);

ndata = 10;
% Generate a mixture of three Gaussians in two dimensional space.
data = randn(ndata, input_dim);
targets = zeros(ndata, 3);

% Priors for the clusters (must sum to 1; ndata*prior must be integral
% for the index arithmetic below).
prior(1) = 0.4;
prior(2) = 0.3;
prior(3) = 0.3;

% Cluster centres.
c = [2.0, 2.0; 0.0, 0.0; 1, -1];

ndata1 = prior(1)*ndata;               % last index of cluster 1
ndata2 = (prior(1) + prior(2))*ndata;  % last index of cluster 2

% Put first cluster at (2, 2), shrunk to std 0.5.
data(1:ndata1, 1) = data(1:ndata1, 1) * 0.5 + c(1,1);
data(1:ndata1, 2) = data(1:ndata1, 2) * 0.5 + c(1,2);
targets(1:ndata1, 1) = 1;

% Leave second cluster at (0,0): rows (ndata1+1):ndata2 keep their
% raw randn draws, so no transformation is needed.
targets((ndata1+1):ndata2, 2) = 1;

% Put third cluster at (1, -1), shrunk to std 0.6.
data((ndata2+1):ndata, 1) = data((ndata2+1):ndata, 1) * 0.6 + c(3,1);
data((ndata2+1):ndata, 2) = data((ndata2+1):ndata, 2) * 0.6 + c(3,2);
targets((ndata2+1):ndata, 3) = 1;


if 0
  % Debugging aid: train on the single case used in the inference check.
  ndata = 1;
  data = x;
  targets = [1 0 0];
end

% Train the reference netlab model with the same IRLS budget.
options = foptions;
options(1) = -1;           % verbose
options(14) = IRLS_iter;   % max number of IRLS iterations
[net2, options2] = glmtrain(net, options, data, targets);
net2.ll = options2(8);     % final log likelihood; type 'help foptions' for details

% Convert (data, targets) into BNT case format: row 1 = input vector,
% row 2 = class label (index of the 1 in the targets row).
cases = cell(2, ndata);
for l=1:ndata
  q = find(targets(l,:)==1);
  x = data(l,:);
  cases{1,l} = x(:);
  cases{2,l} = q;
end

% With complete observability one EM iteration would suffice; we allow 2
% so learn_params_em can detect convergence and terminate.
max_iter = 2;
[bnet2, ll2] = learn_params_em(engine, cases, max_iter);

% Compare learned softmax parameters against netlab's.
w = get_field(bnet2.CPD{2},'weights');
b = get_field(bnet2.CPD{2},'offset')';

w
net2.w1

b
net2.b1

% assert(approxeq(net2.ll, ll2)); % glmtrain returns ll after final M step, learn_params before