comparison toolboxes/FullBNT-1.0.7/bnt/inference/static/@var_elim_inf_engine/marginal_nodes.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:e9a9cd732c1e
function [marginal, loglik] = marginal_nodes(engine, query, add_ev)
% MARGINAL_NODES Compute the marginal on the specified query nodes (var_elim)
% [marginal, loglik] = marginal_nodes(engine, query)
% [marginal, loglik] = marginal_nodes(engine, query, add_ev)
%
% Inputs:
%   engine - var_elim inference engine; its stored evidence (engine.evidence)
%            is folded into the CPTs here.
%   query  - non-empty vector of node numbers (row or column) whose joint
%            marginal is wanted.
%   add_ev - optional flag (default 0). If nonzero, the marginal is passed
%            through add_evidence_to_gmarginal before being returned.
%
% Outputs:
%   marginal - pot_to_marginal of the normalized potential over the query nodes.
%   loglik   - sum of the log normalizing constants (as returned by
%              normalize_pot) of the final query potential and of every
%              bucket that collapsed to a constant during elimination.
%
% The elimination order is simply the query nodes followed by all other
% nodes; no attempt is made to optimize it.

if nargin < 3, add_ev = 0; end

assert(length(query)>=1);
% Force a row vector so that [query sum_over] below concatenates correctly
% even when the caller passes a column vector.
query = query(:)';

evidence = engine.evidence;

bnet = bnet_from_engine(engine);
ns = bnet.node_sizes;
n = length(bnet.dag);

onodes = find(~isemptycell(evidence));
pot_type = determine_pot_type(bnet, onodes);

% Fold the evidence into the CPTs - this could be done in 'enter_evidence'
CPT = cell(1,n);
for i=1:n
  fam = family(bnet.dag, i);
  CPT{i} = convert_to_pot(bnet.CPD{bnet.equiv_class(i)}, pot_type, fam(:), evidence);
end

sum_over = mysetdiff(1:n, query);
order = [query sum_over]; % no attempt to optimize this

% Initialize the buckets with the product of the CPTs assigned to them.
% A potential is placed in the bucket of its latest variable in 'order'.
B = cell(1,n+1);
for b=1:n+1
  B{b} = mk_initial_pot(pot_type, [], [], [], []);
end
for i=1:n
  b = bucket_num(domain_pot(CPT{i}), order);
  B{b} = multiply_pots(B{b}, CPT{i});
end

% Do the marginalization, eliminating the non-query nodes from the back of
% 'order' towards the front.
dropped_loglik = 0;
sum_over = sum_over(length(sum_over):-1:1); % reverse
for i=sum_over(:)'
  % summing over variable i which occurs in bucket j
  j = bucket_num(i, order);
  rest = mysetdiff(domain_pot(B{j}), i);
  if isempty(rest)
    % Summing out i leaves only a constant: nothing else in this bucket is
    % still to be eliminated or queried (e.g. a component disconnected from
    % the query). Keep its log mass so loglik still accounts for it.
    if ~isempty(domain_pot(B{j}))
      [junk, ll] = normalize_pot(B{j});
      dropped_loglik = dropped_loglik + ll;
    end
  else
    temp = marginalize_pot(B{j}, rest);
    b = bucket_num(domain_pot(temp), order);
    %fprintf('summing over bucket %d (var %d), putting result into bucket %d\n', j, i, b);
    B{b} = multiply_pots(B{b}, temp);
  end
end

% Combine the remaining (query) buckets into one potential
result = B{1};
for i=2:length(query)
  if ~isempty(domain_pot(B{i}))
    result = multiply_pots(result, B{i});
  end
end
[result, loglik] = normalize_pot(result);
loglik = loglik + dropped_loglik;

marginal = pot_to_marginal(result);
% minka: from jtree_inf_engine
if add_ev
  %marginal = add_ev_to_dmarginal(marginal, engine.evidence, bnet.node_sizes);
  marginal = add_evidence_to_gmarginal(marginal, evidence, ns, bnet.cnodes);
end
%%%%%%%%%

function b = bucket_num(domain, order)
% BUCKET_NUM Bucket index for a set of variables.
% b = bucket_num(domain, order) returns the largest position within 'order'
% at which any element of 'domain' occurs, or [] if no element of 'domain'
% occurs in 'order'.

[found, where] = ismember(domain, order);
b = max(where(found));