Mercurial > hg > camir-aes2014
comparison toolboxes/FullBNT-1.0.7/bnt/examples/static/fgraph/fg1.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:e9a9cd732c1e |
---|---|
% Unroll a small HMM, turn it into a factor graph, and confirm that loopy
% propagation on the factor graph reproduces the exact inference results.

% Seed both legacy generators (uniform and normal) with the same value so
% every run of this example is reproducible.
seed = 1;
rand('state', seed);
randn('state', seed);

% HMM dimensions. T is the number of unrolled slices; Q and O are
% presumably the hidden-state and observation-symbol counts (see
% mk_hmm_bnet to confirm).
T = 3;
Q = 3;
O = 3;
cts_obs = 0;      % presumably 0 = discrete observations -- TODO confirm
param_tying = 1;  % presumably ties parameters across slices -- TODO confirm
bnet = mk_hmm_bnet(T, Q, O, cts_obs, param_tying);

% Draw one sample from the network; it supplies the observed values.
data = sample_bnet(bnet);

% Round-trip through a factor graph. The resulting big_bnet has extra nodes
% (after the original ones) standing in for the factors.
fgraph = bnet_to_fgraph(bnet);
big_bnet = fgraph_to_bnet(fgraph);
% converting factor graph back does not recover the structure of the original bnet

% Iteration budget for the loopy-propagation engines.
max_iter = 2*T;

% Engines under comparison; engine 1 (jtree on the original bnet) is the
% reference that every other engine is checked against.
engine = { ...
  jtree_inf_engine(bnet), ...
  belprop_inf_engine(bnet, 'max_iter', max_iter), ...
  belprop_fg_inf_engine(fgraph, 'max_iter', max_iter), ...
  jtree_inf_engine(big_bnet)};
nengines = numel(engine);

% Positions in ENGINE that need special treatment.
big_engine = 4;     % jtree on big_bnet: uses big_evidence
fgraph_engine = 3;  % belprop on the factor graph

% Evidence for the original bnet (N = 2*T nodes): clamp the observed nodes
% to their sampled values and leave every other entry empty.
N = 2*T;
onodes = bnet.observed;
evidence = cell(1,N);
evidence(onodes) = data(onodes);
hnodes = mysetdiff(1:N, onodes);  % the nodes left unobserved

% Evidence for big_bnet: same observations, plus every node after the first
% N (the factor nodes) observed to be 1.
bigN = length(big_bnet.dag);
big_evidence = cell(1, bigN);
big_evidence(onodes) = data(onodes);
big_evidence(N+1:end) = {1};

% Enter evidence into every engine, timing each call with tic/toc. The jtree
% engine on big_bnet needs big_evidence (which also clamps its factor
% nodes); all the others take the evidence for the original bnet.
ll = zeros(1, nengines);
for i=1:nengines
  if i==big_engine
    ev = big_evidence;
  else
    ev = evidence;
  end
  tic; [engine{i}, ll(i)] = enter_evidence(engine{i}, ev); toc
end

% Ideally every ll(i) would match ll(1), but the log-likelihood values
% returned by some engines may be bogus, so the check is disabled:
%   for i=2:nengines, assert(approxeq(ll(1), ll(i))); end

% Posterior marginals on the hidden nodes, computed once per engine.
%   m{e,t}       - marginal struct from engine e on hidden node hnodes(t)
%   marg(t,e,:)  - its probability table, laid out for side-by-side display
% Every lookup goes through hnodes(t), so the displayed table and the
% assertions below refer to the same nodes.
m = cell(nengines, T);
marg = zeros(T, nengines, Q);
for t=1:T
  for e=1:nengines
    m{e,t} = marginal_nodes(engine{e}, hnodes(t));
    marg(t,e,:) = m{e,t}.T;
  end
end
marg

% Each engine must reproduce the exact marginals from engine 1.
for t=1:T
  for e=2:nengines
    assert(approxeq(m{e,t}.T, m{1,t}.T));
  end
end

% Most probable explanation from each engine. The jtree engine on big_bnet
% is queried with big_evidence, and its assignment is truncated to the
% first N nodes so it lines up with the other engines.
% ll is left as computed by enter_evidence (it is not re-zeroed here).
mpe = cell(1, nengines);
for e=1:nengines
  if e==big_engine
    mpe{e} = find_mpe(engine{e}, big_evidence);
    mpe{e} = mpe{e}(1:N); % chop off dummy nodes
  else
    mpe{e} = find_mpe(engine{e}, evidence);
  end
end

% fgraph can't compute loglikelihood for software reasons
% jtree on the big_bnet gives the wrong ll
for e=2:nengines
  %assert(approxeq(ll(1), ll(e)));
  assert(approxeq(cell2num(mpe{1}), cell2num(mpe{e})))
end