function engine = cbk_inf_engine(bnet, varargin)
% CBK_INF_ENGINE Just the same as bk_inf_engine, but you can specify overlapping clusters.
% engine = cbk_inf_engine(bnet, ...)
%
% Optional arguments, given as name/value pairs:
% 'clusters' - 'exact' (default): a single cluster holding nodes 1:ss,
%                                  where ss = length(bnet.intra);
%              'ff': one singleton cluster per node 1..ss;
%              or a cell array whose entries are vectors of slice-1 node indices.
%
% The returned object has class 'cbk_inf_engine' with parent inf_engine(bnet). Fields set here:
% - sub_engine:  jtree_inf_engine on bnet, declared with the clusters and with
%                copies of them shifted by ss;
% - separators:  intersections of clusters that are adjacent in a jtree built
%                over the clusters (see build_jt below);
% - cluster_ass_to_separator(s): the smallest cluster (by product of node sizes,
%                observed nodes counted as size 1) that contains separator s;
% - clq_ass_to_cluster / clq_ass_to_node: for each cluster / node, the clique of
%                sub_engine returned by clq_containing_nodes for its unshifted
%                (column 1) and shifted-by-ss (column 2) copy;
% - sub_engine1: jtree_inf_engine on a network built from bnet.intra1 (slice 1 only);
% - clq_ass_to_cluster1 / clq_ass_to_node1: the same lookups for sub_engine1;
% - clpot, filter, maximize, T, bel, bel_clpot, slice1: initialised empty,
%   presumably filled in during inference (clpot: by enter_evidence).

ss = length(bnet.intra);

% Parse optional name/value arguments (default params first).
clusters = 'exact';
if mod(length(varargin), 2) ~= 0
  error('optional arguments must be given as name/value pairs');
end
for i=1:2:length(varargin)
  switch varargin{i}
    case 'clusters', clusters = varargin{i+1};
    otherwise, error(['unrecognized argument ' varargin{i}])
  end
end

% Expand the named cluster presets into explicit cell arrays of node indices.
% A cell array supplied by the caller is used as is.
if ischar(clusters)
  switch clusters
    case 'exact'
      %clusters = { compute_interface_nodes(bnet.intra, bnet.inter) };
      clusters = { 1:ss };
    case 'ff'
      clusters = num2cell(1:ss);
    otherwise
      error(['unrecognized clusters option ' clusters]);
  end
end

% We need to insert the prior on the clusters in slice 1,
% and extract the posterior on the clusters in slice 2.
% We don't need to care about the separators, b/c they're subsets of the clusters.
C = length(clusters);
clusters2 = cell(1,2*C);
clusters2(1:C) = clusters;
for c=1:C
  clusters2{c+C} = clusters{c} + ss;
end

onodes = bnet.observed;
engine.sub_engine = jtree_inf_engine(bnet, 'clusters', clusters2);

%FH >>>
% Compute the separators between the clusters. Observed nodes get size 1,
% so they do not enlarge any cluster or separator size computed below.
ns = bnet.node_sizes(:,1);
ns(onodes) = 1;
[clusters, separators] = build_jt(clusters, 1:length(ns), ns);
S = length(separators);
engine.separators = separators;

% Size of each cluster: product of the sizes of its nodes.
cl_sizes = zeros(1,C);
for c=1:C
  cl_sizes(c) = prod(ns(clusters{c}));
end

% Assign each separator to the smallest cluster subsuming it.
% Every separator is the intersection of two clusters, so at least one
% cluster always subsumes it.
engine.cluster_ass_to_separator = zeros(S, 1);
for s=1:S
  % find the clusters that contain separator s
  subsuming_clusters = [];
  for c=1:C
    if mysubset(separators{s}, clusters{c})
      subsuming_clusters(end+1) = c;
    end
  end
  % ... and keep the smallest of them
  c = argmin(cl_sizes(subsuming_clusters));
  engine.cluster_ass_to_separator(s) = subsuming_clusters(c);
end
%<<< FH

% Clique of sub_engine holding each cluster (column 1) and its copy shifted by ss (column 2).
engine.clq_ass_to_cluster = zeros(C, 2);
for c=1:C
  engine.clq_ass_to_cluster(c,1) = clq_containing_nodes(engine.sub_engine, clusters{c});
  engine.clq_ass_to_cluster(c,2) = clq_containing_nodes(engine.sub_engine, clusters{c}+ss);
end
engine.clusters = clusters;

% Same lookup for each individual node.
engine.clq_ass_to_node = zeros(ss, 2);
for i=1:ss
  engine.clq_ass_to_node(i, 1) = clq_containing_nodes(engine.sub_engine, i);
  engine.clq_ass_to_node(i, 2) = clq_containing_nodes(engine.sub_engine, i+ss);
end

% Also create an engine just for slice 1
bnet1 = mk_bnet(bnet.intra1, bnet.node_sizes_slice, 'discrete', myintersect(bnet.dnodes, 1:ss), ...
                'equiv_class', bnet.equiv_class(:,1), 'observed', onodes);
% Copy the CPDs of the equivalence classes used by the slice-1 network.
for i=1:max(bnet1.equiv_class)
  bnet1.CPD{i} = bnet.CPD{i};
end

engine.sub_engine1 = jtree_inf_engine(bnet1, 'clusters', clusters);

engine.clq_ass_to_cluster1 = zeros(1,C);
for c=1:C
  engine.clq_ass_to_cluster1(c) = clq_containing_nodes(engine.sub_engine1, clusters{c});
end

engine.clq_ass_to_node1 = zeros(1, ss);
for i=1:ss
  engine.clq_ass_to_node1(i) = clq_containing_nodes(engine.sub_engine1, i);
end

engine.clpot = []; % this is where we store the results between enter_evidence and marginal_nodes
engine.filter = [];
engine.maximize = [];
engine.T = [];

engine.bel = [];
engine.bel_clpot = [];
engine.slice1 = [];
%engine.pot_type = 'cg';
% hack for online inference so we can cope with hidden Gaussians and discrete
% it will not affect the pot type used in enter_evidence
engine.pot_type = determine_pot_type(bnet, onodes);

engine = class(engine, 'cbk_inf_engine', inf_engine(bnet));


function [cliques, seps, jt_size] = build_jt(cliques, vars, ns)
% BUILD_JT connects the cliques into a jtree, computes the respective
% separators and the size of the resulting jtree.
%
% [cliques, seps, jt_size] = build_jt(cliques, vars, ns)
% ns(i) has to hold the size of vars(i)
% vars has to be a superset of the union of cliques.
%
% cliques is returned unchanged.
% seps is a cell array with one entry per jtree edge whose two cliques share
%   variables: that shared set, expressed in the original variable numbering.
% jt_size is the sum of the clique weights w returned by cliques_to_jtree plus,
%   for every separator, the product of the sizes of its variables.

%======== Compute the jtree with tool from BNT. This wants the vars to be 1:N.
%==== Map from nodes to their indices: inv_nodes(vars(i)) == i.
% A dense table keeps the mapped clique vectors (and hence the input to
% cliques_to_jtree) full rather than sparse.
N = length(vars);
inv_nodes = zeros(1, max(vars));
inv_nodes(vars) = 1:N;

%==== Temporarily map clique vars to their indices.
tmp_cliques = cell(1, length(cliques));
for i=1:length(cliques)
  tmp_cliques{i} = inv_nodes(cliques{i});
end

%=== Compute the jtree, using BNT.
% Row k of B marks the (index-space) variables of clique k: AND-ing two rows
% yields the two cliques' intersection. w holds the clique weights.
[jtree, root, B, w] = cliques_to_jtree(tmp_cliques, ns);

%======== Now, compute the separators between connected cliques and their weights.
% Only the upper triangle of the adjacency matrix is visited, so an edge
% stored as both (i,j) and (j,i) is counted once.
seps = {};
s_w = [];
[is, js] = find(triu(jtree > 0));
for k=1:length(is)
  i = is(k); j = js(k);
  sep = vars(find(B(i,:) & B(j,:))); % intersect(cliques{i}, cliques{j});
  if isempty(sep), continue; end
  seps{end+1} = sep;
  s_w(end+1) = prod(ns(inv_nodes(sep)));
end

cl_w = sum(w);
sep_w = sum(s_w);
assert(cl_w > sep_w, 'Weight of cliques must be bigger than weight of separators');

jt_size = cl_w + sep_w;