Mercurial > hg > camir-aes2014
comparison toolboxes/FullBNT-1.0.7/bnt/inference/dynamic/@cbk_inf_engine/junk @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:e9a9cd732c1e |
---|---|
function engine = cbk_inf_engine(bnet, varargin)
% CBK_INF_ENGINE Just the same as bk_inf_engine, but you can specify overlapping clusters.
%
% engine = cbk_inf_engine(bnet, 'clusters', clusters)
%
% Optional name/value arguments:
%   'clusters' - 'exact' (default): a single cluster holding all nodes of a slice (1:ss);
%                'ff': one singleton cluster per node of a slice;
%                anything else is used as given and is indexed below as a cell
%                array of node-index vectors (clusters{c}), numbered within slice 1.
%
% Raises an error for an unrecognized option name or for an option name
% that is not followed by a value.
%
% NOTE: the fields of 'engine' are created in a fixed order on purpose; the
% struct is turned into an object with class() at the end, so do not reorder
% the assignments without checking the other constructors of this class.

ss = length(bnet.intra);

% set default params
clusters = 'exact';

% Options must come as name/value pairs; check this up front so that a
% dangling name gives a clear message instead of an index-out-of-range error.
nargs = length(varargin);
if mod(nargs, 2) ~= 0
  error('optional arguments must be given as name/value pairs')
end
for i=1:2:nargs
  switch varargin{i},
   case 'clusters', clusters = varargin{i+1};
   otherwise, error(['unrecognized argument ' varargin{i}])
  end
end

% Expand the two symbolic cluster specifications.
if strcmp(clusters, 'exact')
  %clusters = { compute_interface_nodes(bnet.intra, bnet.inter) };
  clusters = { 1:ss };
elseif strcmp(clusters, 'ff')
  clusters = num2cell(1:ss);
end


% We need to insert the prior on the clusters in slice 1,
% and extract the posterior on the clusters in slice 2.
% We don't need to care about the separators, b/c they're subsets of the clusters.
C = length(clusters);
clusters2 = cell(1,2*C);
clusters2(1:C) = clusters;
for c=1:C
  % same cluster, shifted by one slice
  clusters2{c+C} = clusters{c} + ss;
end

onodes = bnet.observed;
engine.sub_engine = jtree_inf_engine(bnet, 'clusters', clusters2);

%FH >>>
%Compute separators.
ns = bnet.node_sizes(:,1);
ns(onodes) = 1; % observed nodes are given size 1
[clusters, separators] = build_jt(clusters, 1:length(ns), ns);
S = length(separators);
engine.separators = separators;

%Compute size of clusters.
cl_sizes = zeros(1,C);
for c=1:C
  cl_sizes(c) = prod(ns(clusters{c}));
end

%Assign separators to the smallest cluster subsuming them.
engine.cluster_ass_to_separator = zeros(S, 1);
for s=1:S
  % collect every cluster that contains separator s ...
  subsuming_clusters = [];
  for c=1:C
    if mysubset(separators{s}, clusters{c})
      subsuming_clusters(end+1) = c;
    end
  end
  % ... and pick the one with the smallest size
  smallest = argmin(cl_sizes(subsuming_clusters));
  engine.cluster_ass_to_separator(s) = subsuming_clusters(smallest);
end

%<<< FH

% For each cluster, the clique of the 2-slice engine containing it
% (column 1: slice 1 copy, column 2: slice 2 copy).
engine.clq_ass_to_cluster = zeros(C, 2);
for c=1:C
  engine.clq_ass_to_cluster(c,1) = clq_containing_nodes(engine.sub_engine, clusters{c});
  engine.clq_ass_to_cluster(c,2) = clq_containing_nodes(engine.sub_engine, clusters{c}+ss);
end
engine.clusters = clusters;

% Same lookup for every single node.
engine.clq_ass_to_node = zeros(ss, 2);
for i=1:ss
  engine.clq_ass_to_node(i, 1) = clq_containing_nodes(engine.sub_engine, i);
  engine.clq_ass_to_node(i, 2) = clq_containing_nodes(engine.sub_engine, i+ss);
end



% Also create an engine just for slice 1
bnet1 = mk_bnet(bnet.intra1, bnet.node_sizes_slice, 'discrete', myintersect(bnet.dnodes, 1:ss), ...
		'equiv_class', bnet.equiv_class(:,1), 'observed', onodes);
for i=1:max(bnet1.equiv_class)
  bnet1.CPD{i} = bnet.CPD{i};
end

engine.sub_engine1 = jtree_inf_engine(bnet1, 'clusters', clusters);

engine.clq_ass_to_cluster1 = zeros(1,C);
for c=1:C
  engine.clq_ass_to_cluster1(c) = clq_containing_nodes(engine.sub_engine1, clusters{c});
end

engine.clq_ass_to_node1 = zeros(1, ss);
for i=1:ss
  engine.clq_ass_to_node1(i) = clq_containing_nodes(engine.sub_engine1, i);
end

engine.clpot = []; % this is where we store the results between enter_evidence and marginal_nodes
engine.filter = [];
engine.maximize = [];
engine.T = [];

engine.bel = [];
engine.bel_clpot = [];
engine.slice1 = [];
%engine.pot_type = 'cg';
% hack for online inference so we can cope with hidden Gaussians and discrete
% it will not affect the pot type used in enter_evidence
engine.pot_type = determine_pot_type(bnet, onodes);

engine = class(engine, 'cbk_inf_engine', inf_engine(bnet));
120 | |
121 | |
122 | |
123 | |
function [cliques, seps, jt_size] = build_jt(cliques, vars, ns)
% BUILD_JT Link the given cliques into a jtree and report its separators and size.
%
% [cliques, seps, jt_size] = build_jt(cliques, vars, ns)
%
% Inputs:
%   cliques - cell array of cliques, each a vector of variable labels
%             (returned unchanged)
%   vars    - variable labels; has to be a superset of the union of cliques
%   ns      - ns(i) has to hold the size of vars(i)
%
% Outputs:
%   seps    - cell array with the non-empty separators of cliques that are
%             connected in the jtree, expressed in the original labels
%   jt_size - sum of the clique weights plus sum of the separator weights

% The BNT jtree tool wants the variables to be 1:N, so first build a lookup
% from each variable label to its position in vars.
num_vars = length(vars);
label2pos = sparse(1, max(vars));
for k = 1:num_vars
  label2pos(vars(k)) = k;
end

% Express every clique in terms of positions instead of labels.
num_cliques = length(cliques);
pos_cliques = cell(1, num_cliques);
for k = 1:num_cliques
  pos_cliques{k} = label2pos(cliques{k});
end

% Let BNT connect the cliques.
[jtree, root, B, w] = cliques_to_jtree(pos_cliques, ns);

% Walk over the edges of the jtree and collect separators and their weights.
seps = {};
sep_weights = [];
[rows, cols] = find(jtree > 0);
for k = 1:length(rows)
  a = rows(k);
  b = cols(k);
  % variables shared by both cliques, mapped back to their labels
  shared = vars(find(B(a,:) & B(b,:)));
  % keep only edges with a <= b, and ignore empty separators
  if a <= b && ~isempty(shared)
    seps{end+1} = shared;
    sep_weights(end+1) = prod(ns(label2pos(shared)));
  end
end

clique_total = sum(w);
sep_total = sum(sep_weights);
assert(clique_total > sep_total, 'Weight of cliques must be bigger than weight of separators');

jt_size = clique_total + sep_total;