annotate toolboxes/FullBNT-1.0.7/bnt/inference/static/@var_elim_inf_engine/marginal_nodes.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
rev   line source
function [marginal, loglik] = marginal_nodes(engine, query, add_ev)
% MARGINAL_NODES Compute the marginal on the specified query nodes (var_elim)
% [marginal, loglik] = marginal_nodes(engine, query, add_ev)
%
% Computes the joint marginal over the QUERY nodes by bucket (variable)
% elimination, using the evidence stored in engine.evidence.
%
% Inputs:
%   engine - the var_elim inference engine (engine.evidence is used)
%   query  - non-empty vector of node ids (row or column)
%   add_ev - optional, default 0; if non-zero, the marginal is passed
%            through add_evidence_to_gmarginal with the engine's evidence
%
% Outputs:
%   marginal - pot_to_marginal of the normalized potential on QUERY
%   loglik   - log normalization constant returned by normalize_pot on the
%              query potential, plus the log mass of any bucket that was
%              summed out completely (e.g. a network component that is not
%              connected to the query nodes)

if nargin < 3, add_ev = 0; end

assert(length(query)>=1);
% Force a row vector: [query sum_over] below is a horizontal concatenation
% and would fail for a multi-element column vector.
query = query(:)';

evidence = engine.evidence;

bnet = bnet_from_engine(engine);
n = length(bnet.dag);

onodes = find(~isemptycell(evidence));
pot_type = determine_pot_type(bnet, onodes);

% Fold the evidence into the CPTs - this could be done in 'enter_evidence'
CPT = cell(1,n);
for i=1:n
  fam = family(bnet.dag, i);
  CPT{i} = convert_to_pot(bnet.CPD{bnet.equiv_class(i)}, pot_type, fam(:), evidence);
end

sum_over = mysetdiff(1:n, query);
order = [query sum_over]; % no attempt to optimize this

% Initialize the buckets with the product of the CPTs assigned to them.
% Each potential goes into the bucket given by bucket_num, i.e. the latest
% position in 'order' among its domain variables.
nbuckets = length(order);
B = cell(1,nbuckets);
for b=1:nbuckets
  B{b} = mk_initial_pot(pot_type, [], [], [], []);
end
for i=1:n
  b = bucket_num(domain_pot(CPT{i}), order);
  B{b} = multiply_pots(B{b}, CPT{i});
end

% Do the marginalization: eliminate the non-query variables from the last
% position in 'order' back towards the query variables.
dropped_loglik = 0;
sum_over = sum_over(length(sum_over):-1:1); % reverse
for i=sum_over(:)'
  % summing over variable i which occurs in bucket j
  j = bucket_num(i, order);
  dom = domain_pot(B{j});
  rest = mysetdiff(dom, i);
  if ~isempty(rest)
    temp = marginalize_pot(B{j}, rest);
    b = bucket_num(domain_pot(temp), order);
    %fprintf('summing over bucket %d (var %d), putting result into bucket %d\n', j, i, b);
    B{b} = multiply_pots(B{b}, temp);
  elseif ~isempty(dom)
    % minka: marginalize_pot is not called when nothing would remain.
    % The bucket then depends on i alone, so summing it out gives a
    % constant; keep its log mass so that loglik is not silently wrong.
    [dummy, ll] = normalize_pot(B{j});
    dropped_loglik = dropped_loglik + ll;
  end
end

% Combine all the remaining (query) buckets into one
result = B{1};
for i=2:length(query)
  if ~isempty(domain_pot(B{i}))
    result = multiply_pots(result, B{i});
  end
end
[result, loglik] = normalize_pot(result);
loglik = loglik + dropped_loglik;

marginal = pot_to_marginal(result);
% minka: from jtree_inf_engine
if add_ev
  %marginal = add_ev_to_dmarginal(marginal, engine.evidence, bnet.node_sizes);
  marginal = add_evidence_to_gmarginal(marginal, engine.evidence, bnet.node_sizes, bnet.cnodes);
end
wolffd@0 73
wolffd@0 74 %%%%%%%%%
wolffd@0 75
function b = bucket_num(domain, order)
% BUCKET_NUM Index of the bucket that owns a set of variables.
% b = bucket_num(domain, order) looks up every variable of DOMAIN in ORDER
% (via find_equiv_posns) and returns the largest of those positions.

posns = find_equiv_posns(domain, order);
b = max(posns);