function engine = cbk_inf_engine(bnet, varargin)
% CBK_INF_ENGINE Just the same as bk_inf_engine, but you can specify overlapping clusters.
% engine = cbk_inf_engine(bnet, ...)
%
% bnet is the two-slice network; ss = length(bnet.intra) is used as the
% offset between a node in slice 1 and the same node in slice 2.
%
% Optional arguments are passed as name/value pairs [default in brackets]:
%   'clusters' - 'exact' (a single cluster holding nodes 1:ss),
%                'ff'    (one singleton cluster per node), or
%                a cell array of node-index vectors, one per cluster ['exact']
%
% An unknown argument name, or a name without a value, raises an error.

ss = length(bnet.intra);

% set default params
clusters = 'exact';

% Parse the name/value pairs (the loop is empty when no options are given).
args = varargin;
nargs = length(args);
for i=1:2:nargs
  switch args{i},
    case 'clusters',
      if i == nargs
        % A trailing name without a value would otherwise fail with an
        % uninformative index-out-of-bounds error.
        error('missing value for argument ''clusters''');
      end
      clusters = args{i+1};
    otherwise, error(['unrecognized argument ' args{i}])
  end
end

% Expand the symbolic cluster specifications into explicit node lists.
if strcmp(clusters, 'exact')
  %clusters = { compute_interface_nodes(bnet.intra, bnet.inter) };
  clusters = { 1:ss };
elseif strcmp(clusters, 'ff')
  clusters = num2cell(1:ss);
end


% We need to insert the prior on the clusters in slice 1,
% and extract the posterior on the clusters in slice 2.
% We don't need to care about the separators, b/c they're subsets of the clusters.
C = length(clusters);
clusters2 = cell(1,2*C);
clusters2(1:C) = clusters;
for c=1:C
  clusters2{c+C} = clusters{c} + ss;   % same cluster, shifted into slice 2
end

onodes = bnet.observed;

% NOTE: the order in which the fields of 'engine' are created is kept exactly
% as in the original code, because the struct is handed to class() below.
engine.sub_engine = jtree_inf_engine(bnet, 'clusters', clusters2);

%FH >>>
%Compute separators.
ns = bnet.node_sizes(:,1);
ns(onodes) = 1;   % observed nodes are given size 1
[clusters, separators] = build_jt(clusters, 1:length(ns), ns);
S = length(separators);
engine.separators = separators;

%Compute size of clusters.
cl_sizes = zeros(1,C);
for c=1:C
  cl_sizes(c) = prod(ns(clusters{c}));
end

%Assign separators to the smallest cluster subsuming them.
engine.cluster_ass_to_separator = zeros(S, 1);
for s=1:S
  % Collect every cluster that contains this separator ...
  subsuming_clusters = [];
  for c=1:C
    if mysubset(separators{s}, clusters{c})
      subsuming_clusters(end+1) = c;
    end
  end
  % ... and keep the one with the smallest size.
  best = argmin(cl_sizes(subsuming_clusters));
  engine.cluster_ass_to_separator(s) = subsuming_clusters(best);
end

%<<< FH

% For each cluster: a clique of the 2-slice engine containing it in slice 1
% (column 1) and in slice 2 (column 2).
engine.clq_ass_to_cluster = zeros(C, 2);
for c=1:C
  engine.clq_ass_to_cluster(c,1) = clq_containing_nodes(engine.sub_engine, clusters{c});
  engine.clq_ass_to_cluster(c,2) = clq_containing_nodes(engine.sub_engine, clusters{c}+ss);
end
engine.clusters = clusters;

% The same lookup for every individual node.
engine.clq_ass_to_node = zeros(ss, 2);
for i=1:ss
  engine.clq_ass_to_node(i, 1) = clq_containing_nodes(engine.sub_engine, i);
  engine.clq_ass_to_node(i, 2) = clq_containing_nodes(engine.sub_engine, i+ss);
end



% Also create an engine just for slice 1
bnet1 = mk_bnet(bnet.intra1, bnet.node_sizes_slice, 'discrete', myintersect(bnet.dnodes, 1:ss), ...
                'equiv_class', bnet.equiv_class(:,1), 'observed', onodes);
for i=1:max(bnet1.equiv_class)
  bnet1.CPD{i} = bnet.CPD{i};
end

engine.sub_engine1 = jtree_inf_engine(bnet1, 'clusters', clusters);

engine.clq_ass_to_cluster1 = zeros(1,C);
for c=1:C
  engine.clq_ass_to_cluster1(c) = clq_containing_nodes(engine.sub_engine1, clusters{c});
end

engine.clq_ass_to_node1 = zeros(1, ss);
for i=1:ss
  engine.clq_ass_to_node1(i) = clq_containing_nodes(engine.sub_engine1, i);
end

engine.clpot = []; % this is where we store the results between enter_evidence and marginal_nodes
engine.filter = [];
engine.maximize = [];
engine.T = [];

engine.bel = [];
engine.bel_clpot = [];
engine.slice1 = [];
%engine.pot_type = 'cg';
% hack for online inference so we can cope with hidden Gaussians and discrete
% it will not affect the pot type used in enter_evidence
engine.pot_type = determine_pot_type(bnet, onodes);

engine = class(engine, 'cbk_inf_engine', inf_engine(bnet));





function [cliques, seps, jt_size] = build_jt(cliques, vars, ns)
% BUILD_JT connects the cliques into a jtree, computes the respective
% separators and the size of the resulting jtree.
%
% [cliques, seps, jt_size] = build_jt(cliques, vars, ns)
% ns(i) has to hold the size of vars(i)
% vars has to be a superset of the union of cliques.
%
% cliques is returned unchanged. seps is a cell array holding the non-empty
% intersections of cliques that are adjacent in the jtree (one entry per
% edge). jt_size is the summed weight of all cliques plus the summed
% weight of all separators.

%======== Compute the jtree with tool from BNT. This wants the vars to be 1:N.
%==== Map from nodes to their indices.
%disp('Computing jtree for cliques with vars and ns:');
%cliques
%vars
%ns'

inv_nodes = sparse(1,max(vars));
N = length(vars);
for i=1:N
  inv_nodes(vars(i)) = i;
end

tmp_cliques = cell(1,length(cliques));
%==== Temporarily map clique vars to their indices.
for i=1:length(cliques)
  tmp_cliques{i} = inv_nodes(cliques{i});
end

%=== Compute the jtree, using BNT.
[jtree, root, B, w] = cliques_to_jtree(tmp_cliques, ns);


%======== Now, compute the separators between connected cliques and their weights.
seps = {};
s_w = [];
[is,js] = find(jtree > 0);
for k=1:length(is)
  i = is(k); j = js(k);
  % Each edge is listed in both directions; handle it once, as (i,j) with i<=j.
  if i > j, continue; end
  sep = vars(find(B(i,:) & B(j,:))); % intersect(cliques{i}, cliques{j});
  if isempty(sep), continue; end
  seps{end+1} = sep;
  s_w(end+1) = prod(ns(inv_nodes(sep)));
end

cl_w = sum(w);
sep_w = sum(s_w);
assert(cl_w > sep_w, 'Weight of cliques must be bigger than weight of separators');

jt_size = cl_w + sep_w;
% jt.cliques = cliques;
% jt.seps = seps;
% jt.size = jt_size;
% jt.ns = ns';
% jt;