comparison toolboxes/FullBNT-1.0.7/bnt/examples/dynamic/cmp_inference_dbn.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:e9a9cd732c1e
function [time, engine] = cmp_inference_dbn(bnet, engine, T, varargin)
% CMP_INFERENCE_DBN Compare several inference engines on a DBN
% function [time, engine] = cmp_inference_dbn(bnet, engine, T, ...)
%
% engine{i} is the i'th inference engine.
% T is the length of the evidence sequence that is sampled from bnet.
% time(e) = elapsed time for doing inference with engine e
% On return, engine{i} is the i'th engine with the evidence entered.
%
% The list below gives optional arguments [default value in brackets].
% They must be passed as name/value pairs.
%
% exact - specifies which engines do exact inference [ 1:length(engine) ]
%         exact(1) is used as the reference the others are compared to.
% singletons_only - if 1, we only call marginal_nodes, else this and marginal_family [0]
% check_ll - 1 means we check that the log-likelihoods are correct [1]
% observed - rows of the sampled evidence that are kept as evidence [ bnet.observed ]
%
% An error is raised if an exact engine disagrees with the reference
% engine on the log-likelihood (when check_ll is set) or on any marginal.

% set default params
exact = 1:length(engine);
singletons_only = 0;
check_ll = 1;
onodes = bnet.observed;

args = varargin;
nargs = length(args);
% Reject a dangling option name up front; otherwise args{i+1} below
% fails with an uninformative index-out-of-range error.
if mod(nargs, 2) ~= 0
  error('optional arguments must be given as name/value pairs')
end
for i=1:2:nargs
  switch args{i},
    case 'exact', exact = args{i+1};
    case 'singletons_only', singletons_only = args{i+1};
    case 'check_ll', check_ll = args{i+1};
    case 'observed', onodes = args{i+1};
    otherwise,
      error(['unrecognized argument ' args{i}])
  end
end

% The first exact engine is the reference, so there must be at least one.
if isempty(exact)
  error('exact must name at least one engine to use as the reference')
end

E = length(engine);
ref = exact(1); % reference

% Sample a full sequence, then keep only the rows listed in onodes;
% all other cells of the ss-by-T evidence array stay empty.
ss = length(bnet.intra);
ev = sample_dbn(bnet, 'length', T);
evidence = cell(ss,T);
evidence(onodes,:) = ev(onodes, :);

% Preallocate so the result vectors do not grow inside the timing loop.
time = zeros(1, E);
ll = zeros(1, E);
for i=1:E
  tic;
  [engine{i}, ll(i)] = enter_evidence(engine{i}, evidence);
  time(i)=toc;
  fprintf('engine %d took %6.4f seconds\n', i, time(i));
end

% Compare the log-likelihood of every other exact engine to the reference.
cmp = mysetdiff(exact, ref);
if check_ll
  for i=cmp(:)'
    if ~approxeq(ll(ref), ll(i))
      error(['engine ' num2str(i) ' has wrong ll'])
    end
  end
end
% Deliberately left without a semicolon so the log-likelihoods are displayed.
ll

hnodes = mysetdiff(1:ss, onodes);

% Family marginals are checked first (unless disabled), then singletons.
if ~singletons_only
  get_marginals(engine, hnodes, exact, 0, T);
end
get_marginals(engine, hnodes, exact, 1, T);
64
65 %%%%%%%%%%
66
function get_marginals(engine, hnodes, exact, singletons, T)
% GET_MARGINALS Check that the exact engines agree on every marginal
%
% For each time step t=1:T and each node n, query every engine for the
% marginal on n (singletons=1, via marginal_nodes) or on n's family
% (singletons=0, via marginal_family). Each engine in exact(2:end) is
% then compared to the reference engine exact(1); the first mismatch
% raises an error.
%
% hnodes is accepted but not used: all nodes 1:N are checked.

net = bnet_from_engine(engine{1});
num_nodes = length(net.intra);

% Logical mask marking the nodes listed in net.cnodes.
is_cnode = false(1, num_nodes);
is_cnode(net.cnodes) = true;

num_engines = length(engine);
base = exact(1);        % reference engine
others = exact(2:end);  % engines compared against the reference
marg = cell(1, num_engines);

for t = 1:T
  for n = 1:num_nodes
    % Query every engine, not just the exact ones.
    for k = 1:num_engines
      if singletons
        marg{k} = marginal_nodes(engine{k}, n, t);
      else
        marg{k} = marginal_family(engine{k}, n, t);
      end
    end

    for k = others(:)'
      expected = marg{base};
      actual = marg{k};
      assert(isequal(actual.domain, expected.domain));

      % Compare mean and covariance when both marginals carry a 'mu'
      % field and n is listed in cnodes; otherwise compare the T tables.
      if is_cnode(n) && isfield(actual, 'mu') && isfield(expected, 'mu')
        mu_ok = approxeq(expected.mu, actual.mu);
        sigma_ok = approxeq(expected.Sigma, actual.Sigma);
        wrong = ~(mu_ok & sigma_ok);
      else
        wrong = ~approxeq(expected.T(:), actual.T(:));
      end

      if wrong
        error(sprintf('engine %d is wrong; n=%d, t=%d, fam=%d', k, n, t, ~singletons))
      end
    end
  end
end