function [marginal, loglik] = marginal_nodes(engine, query, add_ev)
% MARGINAL_NODES Compute the marginal on the specified query nodes (var_elim)
% [marginal, loglik] = marginal_nodes(engine, query)
% [marginal, loglik] = marginal_nodes(engine, query, add_ev)
%
% Bucket (variable) elimination: every CPT of the engine's bnet is converted
% to a potential with engine.evidence folded in, each potential is placed in
% the bucket of its latest variable in the elimination order, and all
% non-query nodes are summed out from last to first.
%
% Inputs:
%   engine - var_elim engine; engine.evidence is a cell array whose non-empty
%            entries mark the observed nodes
%   query  - non-empty vector of node ids (row or column) to compute the
%            joint marginal over
%   add_ev - optional, default 0. If non-zero, the marginal is passed through
%            add_evidence_to_gmarginal with engine.evidence (as in
%            jtree_inf_engine)
%
% Outputs:
%   marginal - pot_to_marginal of the normalized potential on the query nodes
%   loglik   - log of the total normalizing constant of the product of the
%              evidence-reduced CPTs (presumably log P(evidence); it includes
%              constant factors split off during elimination)

if nargin < 3, add_ev = 0; end

assert(length(query)>=1);
% Force a row vector: the concatenation [query sum_over] below needs one.
query = query(:)';

evidence = engine.evidence;

bnet = bnet_from_engine(engine);
n = length(bnet.dag);

onodes = find(~isemptycell(evidence));
pot_type = determine_pot_type(bnet, onodes);

% Fold the evidence into the CPTs - this could be done in 'enter_evidence'
CPT = cell(1,n);
for i=1:n
  fam = family(bnet.dag, i);
  CPT{i} = convert_to_pot(bnet.CPD{bnet.equiv_class(i)}, pot_type, fam(:), evidence);
end

sum_over = mysetdiff(1:n, query);
order = [query sum_over]; % no attempt to optimize this

% Initialize the buckets with the product of the CPTs assigned to them.
% A potential goes to the bucket of its latest variable in 'order'.
B = cell(1,n+1);
for b=1:n+1
  B{b} = mk_initial_pot(pot_type, [], [], [], []);
end
for i=1:n
  b = bucket_num(domain_pot(CPT{i}), order);
  B{b} = multiply_pots(B{b}, CPT{i});
end

% Do the marginalization, eliminating the non-query nodes from last to first.
% loglik_const accumulates the log of constant factors produced when a
% bucket depends on nothing but the variable being summed out; they do not
% affect the marginal but do belong in loglik.
loglik_const = 0;
sum_over = sum_over(end:-1:1); % reverse
for i=sum_over(:)'
  % summing over variable i which occurs in bucket j
  j = bucket_num(i, order);
  dom = domain_pot(B{j});
  rest = mysetdiff(dom, i);
  if ~isempty(rest)
    % Sum i out and pass the resulting message to the bucket of its
    % latest remaining variable (always earlier in 'order' than j).
    temp = marginalize_pot(B{j}, rest);
    b = bucket_num(domain_pot(temp), order);
    %fprintf('summing over bucket %d (var %d), putting result into bucket %d\n', j, i, b);
    B{b} = multiply_pots(B{b}, temp);
  elseif ~isempty(dom)
    % minka: avoid marginalizing onto an empty domain. The bucket's domain
    % is exactly {i}, so summing it out yields a scalar; record its log
    % instead of silently dropping it from the likelihood.
    [junk, ll] = normalize_pot(B{j});
    loglik_const = loglik_const + ll;
  end
end

% Combine all the remaining (query) buckets into one
result = B{1};
for i=2:length(query)
  if ~isempty(domain_pot(B{i}))
    result = multiply_pots(result, B{i});
  end
end
[result, loglik] = normalize_pot(result);
loglik = loglik + loglik_const;

marginal = pot_to_marginal(result);
% minka: from jtree_inf_engine
if add_ev
  %marginal = add_ev_to_dmarginal(marginal, engine.evidence, bnet.node_sizes);
  marginal = add_evidence_to_gmarginal(marginal, engine.evidence, bnet.node_sizes, bnet.cnodes);
end

%%%%%%%%%

function b = bucket_num(domain, order)
% BUCKET_NUM Index of the bucket that owns a potential over DOMAIN.
% Takes the max of find_equiv_posns(domain, order), presumably the positions
% of DOMAIN's members within ORDER, i.e. the latest variable's position.

b = max(find_equiv_posns(domain, order));