annotate _FullBNT/BNT/inference/dynamic/@cbk_inf_engine/junk @ 9:4ea6619cb3f5 tip

removed log files
author matthiasm
date Fri, 11 Apr 2014 15:55:11 +0100
parents b5b38998ef3b
children
rev   line source
function engine = cbk_inf_engine(bnet, varargin)
% CBK_INF_ENGINE Cluster-based filtering engine for a DBN with (possibly overlapping) clusters.
% engine = cbk_inf_engine(bnet, ...)
%
% Just the same as bk_inf_engine, but you can specify overlapping clusters.
%
% Optional arguments (name/value pairs):
%   'clusters' - 'exact' (default): a single cluster holding all ss nodes of a slice;
%                'ff': one cluster per node;
%                or a cell array whose entries are vectors of slice-1 node indices.
%
% Fields of the returned engine include:
%   sub_engine               - jtree_inf_engine on bnet whose clusters are the given
%                              clusters plus the same clusters shifted by ss (slice 2)
%   separators               - separators between the clusters, as computed by build_jt
%   cluster_ass_to_separator - (S x 1) index of the smallest cluster (by table size)
%                              that contains each separator
%   clq_ass_to_cluster(c,t)  - clique of sub_engine containing cluster c in slice t (t=1,2)
%   clq_ass_to_node(i,t)     - clique of sub_engine containing node i in slice t
%   sub_engine1              - jtree_inf_engine for the slice-1 network only
%   clq_ass_to_cluster1(c), clq_ass_to_node1(i) - same lookups for sub_engine1

ss = length(bnet.intra);

% set default params
clusters = 'exact';

args = varargin;
nargs = length(args);
if mod(nargs, 2) ~= 0
  error('cbk_inf_engine: optional arguments must be given as name/value pairs');
end
for i=1:2:nargs
  switch args{i},
   case 'clusters', clusters = args{i+1};
   otherwise, error(['unrecognized argument ' args{i}])
  end
end

% Expand a named preset into an explicit cell array of node lists.
% (Checking ischar first avoids calling strcmp on a user-supplied cell array,
% and rejects unknown preset names instead of letting them fail later.)
if ischar(clusters)
  switch clusters
   case 'exact'
    %clusters = { compute_interface_nodes(bnet.intra, bnet.inter) };
    clusters = { 1:ss };
   case 'ff'
    clusters = num2cell(1:ss);
   otherwise
    error(['cbk_inf_engine: unrecognized clusters specification ' clusters]);
  end
end


% We need to insert the prior on the clusters in slice 1,
% and extract the posterior on the clusters in slice 2.
% We don't need to care about the separators, b/c they're subsets of the clusters.
C = length(clusters);
clusters2 = cell(1,2*C);
clusters2(1:C) = clusters;
for c=1:C
  clusters2{c+C} = clusters{c} + ss;
end

onodes = bnet.observed;
engine.sub_engine = jtree_inf_engine(bnet, 'clusters', clusters2);

%FH >>>
% Compute separators between the slice-1 clusters.
% Observed nodes are given size 1 so they contribute nothing to the
% products of node sizes used as table sizes below.
ns = bnet.node_sizes(:,1);
ns(onodes) = 1;
[clusters, separators] = build_jt(clusters, 1:length(ns), ns);
S = length(separators);
engine.separators = separators;

% Compute size of clusters (product of node sizes).
cl_sizes = zeros(1,C);
for c=1:C
  cl_sizes(c) = prod(ns(clusters{c}));
end

% Assign separators to the smallest cluster subsuming them.
engine.cluster_ass_to_separator = zeros(S, 1);
for s=1:S
  % find all clusters that contain separator s
  subsuming_clusters = [];
  for c=1:C
    if mysubset(separators{s}, clusters{c})
      subsuming_clusters(end+1) = c;
    end
  end
  if isempty(subsuming_clusters)
    error('cbk_inf_engine: separator %d is not contained in any cluster', s);
  end
  % pick the subsuming cluster with the smallest table
  c = argmin(cl_sizes(subsuming_clusters));
  engine.cluster_ass_to_separator(s) = subsuming_clusters(c);
end

%<<< FH

% For each cluster, the clique of sub_engine containing it in slice 1 and slice 2.
engine.clq_ass_to_cluster = zeros(C, 2);
for c=1:C
  engine.clq_ass_to_cluster(c,1) = clq_containing_nodes(engine.sub_engine, clusters{c});
  engine.clq_ass_to_cluster(c,2) = clq_containing_nodes(engine.sub_engine, clusters{c}+ss);
end
engine.clusters = clusters;

% For each node, the clique of sub_engine containing it in slice 1 and slice 2.
engine.clq_ass_to_node = zeros(ss, 2);
for i=1:ss
  engine.clq_ass_to_node(i, 1) = clq_containing_nodes(engine.sub_engine, i);
  engine.clq_ass_to_node(i, 2) = clq_containing_nodes(engine.sub_engine, i+ss);
end



% Also create an engine just for slice 1
bnet1 = mk_bnet(bnet.intra1, bnet.node_sizes_slice, 'discrete', myintersect(bnet.dnodes, 1:ss), ...
                'equiv_class', bnet.equiv_class(:,1), 'observed', onodes);
for i=1:max(bnet1.equiv_class)
  bnet1.CPD{i} = bnet.CPD{i};
end

engine.sub_engine1 = jtree_inf_engine(bnet1, 'clusters', clusters);

engine.clq_ass_to_cluster1 = zeros(1,C);
for c=1:C
  engine.clq_ass_to_cluster1(c) = clq_containing_nodes(engine.sub_engine1, clusters{c});
end

engine.clq_ass_to_node1 = zeros(1, ss);
for i=1:ss
  engine.clq_ass_to_node1(i) = clq_containing_nodes(engine.sub_engine1, i);
end

engine.clpot = []; % this is where we store the results between enter_evidence and marginal_nodes
engine.filter = [];
engine.maximize = [];
engine.T = [];

engine.bel = [];
engine.bel_clpot = [];
engine.slice1 = [];
%engine.pot_type = 'cg';
% hack for online inference so we can cope with hidden Gaussians and discrete
% it will not affect the pot type used in enter_evidence
engine.pot_type = determine_pot_type(bnet, onodes);

engine = class(engine, 'cbk_inf_engine', inf_engine(bnet));
matthiasm@8 121
matthiasm@8 122
matthiasm@8 123
function [cliques, seps, jt_size] = build_jt(cliques, vars, ns)
% BUILD_JT Connect a set of cliques into a junction tree and report its separators.
%
% [cliques, seps, jt_size] = build_jt(cliques, vars, ns)
%
% Inputs:
%   cliques - cell array; cliques{k} is a vector of variable ids taken from VARS
%   vars    - vector of variable ids; must be a superset of the union of cliques
%   ns      - ns(i) is the size of variable vars(i)
%
% Outputs:
%   cliques - the input cliques, returned unchanged
%   seps    - cell array of separators (variable ids), one per jtree edge whose
%             two endpoint cliques share at least one variable
%   jt_size - sum of the clique weights reported by cliques_to_jtree plus,
%             for every separator, the product of its variable sizes
%
% cliques_to_jtree (BNT) wants variables numbered 1:N, so clique members are
% first re-expressed as positions within VARS.

% pos_of(v) = position of variable v inside vars.
pos_of = sparse(1, max(vars));
for k = 1:length(vars)
  pos_of(vars(k)) = k;
end

% Rewrite every clique in position space.
n_cliques = length(cliques);
pos_cliques = cell(1, n_cliques);
for k = 1:n_cliques
  pos_cliques{k} = pos_of(cliques{k});
end

% root is not needed here but is kept as an output placeholder.
[jtree, root, B, w] = cliques_to_jtree(pos_cliques, ns);

% Visit each tree edge once (upper triangle, same order as a full scan that
% skips row > column) and keep the non-empty intersections.
seps = {};
sep_weights = [];
[rows, cols] = find(triu(jtree > 0));
for e = 1:length(rows)
  shared = vars(find(B(rows(e),:) & B(cols(e),:)));
  if ~isempty(shared)
    seps{end+1} = shared;
    sep_weights(end+1) = prod(ns(pos_of(shared)));
  end
end

clique_total = sum(w);
sep_total = sum(sep_weights);
assert(clique_total > sep_total, 'Weight of cliques must be bigger than weight of separators');

jt_size = clique_total + sep_total;