function engine = cbk_inf_engine(bnet, varargin)
% CBK_INF_ENGINE Just the same as bk_inf_engine, but you can specify overlapping clusters.
% engine = cbk_inf_engine(bnet, ...)
%
% Optional arguments are passed as name/value pairs [default in brackets]:
%
%  clusters - 'exact' : a single cluster holding all nodes 1:ss of a slice,
%             'ff'    : one cluster per node of a slice,
%             or a cell array in which clusters{c} lists the nodes of
%             cluster c, numbered within one slice (1..ss).      ['exact']
%
% where ss = length(bnet.intra).
%
% NOTE: the fields of 'engine' are created in a fixed order; keep that order
% when editing, because the struct is converted to an object by class() below.

ss = length(bnet.intra);

% set default params
clusters = 'exact';

% Parse the optional name/value pairs.
nargs = length(varargin);
for i=1:2:nargs
  switch varargin{i},
   case 'clusters',
    if i == nargs
      % A name without a value used to fail with an index-out-of-range error.
      error('argument ''clusters'' has no value')
    end
    clusters = varargin{i+1};
   otherwise, error(['unrecognized argument ' varargin{i}])
  end
end

% Expand the symbolic cluster specifications.
if strcmp(clusters, 'exact')
  %clusters = { compute_interface_nodes(bnet.intra, bnet.inter) };
  clusters = { 1:ss };
elseif strcmp(clusters, 'ff')
  clusters = num2cell(1:ss);
end


% We need to insert the prior on the clusters in slice 1,
% and extract the posterior on the clusters in slice 2.
% We don't need to care about the separators, b/c they're subsets of the clusters.
C = length(clusters);
clusters2 = cell(1,2*C);
clusters2(1:C) = clusters;
for c=1:C
  % The same cluster, shifted by ss to address the nodes of slice 2.
  clusters2{c+C} = clusters{c} + ss;
end

onodes = bnet.observed;
engine.sub_engine = jtree_inf_engine(bnet, 'clusters', clusters2);

%FH >>>
%Compute separators.
ns = bnet.node_sizes(:,1);
ns(onodes) = 1;  % observed nodes are counted with size 1
[clusters, separators] = build_jt(clusters, 1:length(ns), ns);
S = length(separators);
engine.separators = separators;

%Compute size of clusters.
cl_sizes = zeros(1,C);
for c=1:C
  cl_sizes(c) = prod(ns(clusters{c}));
end

%Assign separators to the smallest cluster subsuming them.
engine.cluster_ass_to_separator = zeros(S, 1);
for s=1:S
  % Collect every cluster that contains all nodes of separator s ...
  subsuming_clusters = [];
  for c=1:C
    if mysubset(separators{s}, clusters{c})
      subsuming_clusters(end+1) = c;
    end
  end
  % ... and keep the one with the smallest size.
  c = argmin(cl_sizes(subsuming_clusters));
  engine.cluster_ass_to_separator(s) = subsuming_clusters(c);
end

%<<< FH

% Column 1 refers to the slice-1 copy of a cluster, column 2 to its slice-2 copy.
engine.clq_ass_to_cluster = zeros(C, 2);
for c=1:C
  engine.clq_ass_to_cluster(c,1) = clq_containing_nodes(engine.sub_engine, clusters{c});
  engine.clq_ass_to_cluster(c,2) = clq_containing_nodes(engine.sub_engine, clusters{c}+ss);
end
engine.clusters = clusters;

% Same layout as above, but for the individual nodes i and i+ss.
engine.clq_ass_to_node = zeros(ss, 2);
for i=1:ss
  engine.clq_ass_to_node(i, 1) = clq_containing_nodes(engine.sub_engine, i);
  engine.clq_ass_to_node(i, 2) = clq_containing_nodes(engine.sub_engine, i+ss);
end



% Also create an engine just for slice 1
bnet1 = mk_bnet(bnet.intra1, bnet.node_sizes_slice, 'discrete', myintersect(bnet.dnodes, 1:ss), ...
		'equiv_class', bnet.equiv_class(:,1), 'observed', onodes);
for i=1:max(bnet1.equiv_class)
  bnet1.CPD{i} = bnet.CPD{i};
end

engine.sub_engine1 = jtree_inf_engine(bnet1, 'clusters', clusters);

engine.clq_ass_to_cluster1 = zeros(1,C);
for c=1:C
  engine.clq_ass_to_cluster1(c) = clq_containing_nodes(engine.sub_engine1, clusters{c});
end

engine.clq_ass_to_node1 = zeros(1, ss);
for i=1:ss
  engine.clq_ass_to_node1(i) = clq_containing_nodes(engine.sub_engine1, i);
end

engine.clpot = []; % this is where we store the results between enter_evidence and marginal_nodes
engine.filter = [];
engine.maximize = [];
engine.T = [];

engine.bel = [];
engine.bel_clpot = [];
engine.slice1 = [];
%engine.pot_type = 'cg';
% hack for online inference so we can cope with hidden Gaussians and discrete
% it will not affect the pot type used in enter_evidence
engine.pot_type = determine_pot_type(bnet, onodes);

engine = class(engine, 'cbk_inf_engine', inf_engine(bnet));





function [cliques, seps, jt_size] = build_jt(cliques, vars, ns)
% BUILD_JT connects the cliques into a jtree, computes the respective
% separators and the size of the resulting jtree.
%
% [cliques, seps, jt_size] = build_jt(cliques, vars, ns)
%   ns(i) has to hold the size of vars(i)
%   vars has to be a superset of the union of cliques.
%
% cliques is returned unchanged. seps is a cell array with one non-empty
% separator (expressed in terms of vars) per pair of connected cliques.
% jt_size is the summed weight of all cliques plus that of all separators.

%======== Compute the jtree with tool from BNT. This wants the vars to be 1:N.
%==== Map from nodes to their indices.
inv_nodes = sparse(1,max(vars));
N = length(vars);
for i=1:N
  inv_nodes(vars(i)) = i;
end

%==== Temporarily map clique vars to their indices.
tmp_cliques = cell(1,length(cliques));
for i=1:length(cliques)
  tmp_cliques{i} = inv_nodes(cliques{i});
end

%=== Compute the jtree, using BNT.
[jtree, root, B, w] = cliques_to_jtree(tmp_cliques, ns);


%======== Now, compute the separators between connected cliques and their weights.
seps = {};
s_w = [];
[is,js] = find(jtree > 0);
for k=1:length(is)
  i = is(k); j = js(k);
  % Every pair shows up as (i,j) and (j,i); handle it once, before doing
  % the intersection work.
  if i > j, continue; end
  sep = vars(find(B(i,:) & B(j,:))); % intersect(cliques{i}, cliques{j});
  if isempty(sep), continue; end
  seps{end+1} = sep;
  s_w(end+1) = prod(ns(inv_nodes(sep)));
end

cl_w = sum(w);
sep_w = sum(s_w);
assert(cl_w > sep_w, 'Weight of cliques must be bigger than weight of separators');

jt_size = cl_w + sep_w;