wolffd@0
|
function [time, engine] = cmp_inference_dbn(bnet, engine, T, varargin)
% CMP_INFERENCE_DBN Compare several inference engines on a DBN
% function [time, engine] = cmp_inference_dbn(bnet, engine, T, ...)
%
% engine{i} is the i'th inference engine.
% time(e) = elapsed time for doing inference with engine e
%
% An evidence sequence of length T is sampled from bnet, the sampled values
% of the observed nodes are entered into every engine, and the engines
% listed in 'exact' are compared against the first of them, exact(1).
% On return, engine{i} is the i'th engine after evidence has been entered.
%
% The list below gives optional arguments [default value in brackets].
% They are passed as name/value pairs.
%
% exact - specifies which engines do exact inference [ 1:length(engine) ]
% singletons_only - if 1, we only call marginal_nodes, else this and marginal_family [0]
% check_ll - 1 means we check that the log-likelihoods are correct [1]
% observed - nodes whose sampled values are entered as evidence [ bnet.observed ]
%
% An error is raised if an exact engine disagrees with the reference engine
% on the log-likelihood (when check_ll is set) or on any marginal.

% set default params
exact = 1:length(engine);
singletons_only = 0;
check_ll = 1;
onodes = bnet.observed;

args = varargin;
nargs = length(args);
% Each option name must be followed by a value; otherwise args{i+1} below
% would fail with an uninformative indexing error.
if rem(nargs, 2) ~= 0
  error('optional arguments must be given as name/value pairs')
end
for i=1:2:nargs
  if ~ischar(args{i})
    error('optional argument names must be strings')
  end
  switch args{i},
   case 'exact', exact = args{i+1};
   case 'singletons_only', singletons_only = args{i+1};
   case 'check_ll', check_ll = args{i+1};
   case 'observed', onodes = args{i+1};
   otherwise,
    error(['unrecognized argument ' args{i}])
  end
end

E = length(engine);
if isempty(exact)
  % exact(1) is used as the reference below, so at least one is required.
  error('exact must specify at least one engine to use as the reference')
end
ref = exact(1); % reference

% Sample a full sequence, then keep only the observed nodes as evidence;
% every other cell stays empty.
ss = length(bnet.intra);
ev = sample_dbn(bnet, 'length', T);
evidence = cell(ss,T);
evidence(onodes,:) = ev(onodes, :);

% Preallocate instead of growing the result vectors inside the loop.
time = zeros(1,E);
ll = zeros(1,E);
for i=1:E
  tic;
  [engine{i}, ll(i)] = enter_evidence(engine{i}, evidence);
  time(i)=toc;
  fprintf('engine %d took %6.4f seconds\n', i, time(i));
end

% Every exact engine other than the reference must reproduce its log-likelihood.
cmp = mysetdiff(exact, ref);
if check_ll
  for i=cmp(:)'
    if ~approxeq(ll(ref), ll(i))
      error(['engine ' num2str(i) ' has wrong ll'])
    end
  end
end
% Deliberately unsuppressed: echoes the log-likelihood of every engine.
ll

hnodes = mysetdiff(1:ss, onodes);

% Family marginals are checked first (unless disabled), then singleton marginals.
if ~singletons_only
  get_marginals(engine, hnodes, exact, 0, T);
end
get_marginals(engine, hnodes, exact, 1, T);
|
wolffd@0
|
64
|
wolffd@0
|
65 %%%%%%%%%%
|
wolffd@0
|
66
|
wolffd@0
|
function get_marginals(engine, hnodes, exact, singletons, T)
% GET_MARGINALS Query every engine for every node and slice, and compare the exact engines
% function get_marginals(engine, hnodes, exact, singletons, T)
%
% engine     - cell array of inference engines; the network is taken from engine{1}
% hnodes     - accepted for the caller's convenience but currently unused:
%              the loop below visits all nodes 1..N, not just these
% exact      - indices of the engines to compare; exact(1) is the reference
% singletons - if true, marginal_nodes is called, otherwise marginal_family
% T          - number of time slices to query
%
% Marginals are computed for all engines, but only those listed in 'exact'
% are compared against the reference. An error is raised on the first mismatch.

bnet = bnet_from_engine(engine{1});
N = length(bnet.intra);
% Lookup vector: cnodes_bitv(n) is 1 iff n is listed in bnet.cnodes.
cnodes_bitv = zeros(1,N);
cnodes_bitv(bnet.cnodes) = 1;
ref = exact(1); % reference
cmp = exact(2:end);
cmp = cmp(:)'; % row vector, so that "for e=cmp" visits one index at a time
E = length(engine);
m = cell(1,E);

for t=1:T
  for n=1:N
    %for n=hnodes(:)'
    for e=1:E
      if singletons
        m{e} = marginal_nodes(engine{e}, n, t);
      else
        m{e} = marginal_family(engine{e}, n, t);
      end
    end
    for e=cmp
      assert(isequal(m{e}.domain, m{ref}.domain));
      % Compare mu/Sigma only when n is in bnet.cnodes and both marginals
      % carry a 'mu' field; otherwise fall back to comparing the T field.
      if cnodes_bitv(n) && isfield(m{e}, 'mu') && isfield(m{ref}, 'mu')
        wrong = ~approxeq(m{ref}.mu, m{e}.mu) | ~approxeq(m{ref}.Sigma, m{e}.Sigma);
      else
        wrong = ~approxeq(m{ref}.T(:), m{e}.T(:));
      end
      if wrong
        % fam=1 means the mismatch was found in a family marginal.
        error('engine %d is wrong; n=%d, t=%d, fam=%d', e, n, t, ~singletons)
      end
    end
  end
end
|