comparison toolboxes/FullBNT-1.0.7/bnt/inference/static/@belprop_inf_engine/Old/enter_evidence1.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:e9a9cd732c1e
function engine = enter_evidence(engine, evidence)
% ENTER_EVIDENCE Absorb evidence and run iterative belief propagation over the factor graph.
% engine = enter_evidence(engine, evidence)
%
% evidence{i} is the observed value of node i, or [] if node i is hidden
% (nodes are treated as observed exactly when their cell is non-empty).
%
% Messages are passed between neighbouring domains of engine.fgraph until no
% domain belief changes (as judged by approxeq with tolerance engine.tol)
% between two consecutive iterations, or until engine.max_iter iterations
% have been run, whichever comes first.
%
% On return:
%   engine.dom_weight(i)    - product of the evidence-adjusted sizes of the
%                             nodes in domain i
%   engine.marginal_domains - 1 x ndoms cell array of normalized belief
%                             potentials, one per domain (the cells stay empty
%                             if engine.max_iter < 1, since no iteration runs)

doms = engine.fgraph.doms;
ndoms = length(doms);

% Adjust node sizes for the evidence: observed continuous nodes get size 0 and
% observed discrete nodes get size 1.
ns = engine.fgraph.node_sizes;
obs = find(~isemptycell(evidence));
cobs = myintersect(obs, engine.fgraph.cnodes);
dobs = myintersect(obs, engine.fgraph.dnodes);
ns(cobs) = 0;
ns(dobs) = 1;

% recompute the weight of each domain now that we know which nodes are observed
for i=1:ndoms
  engine.dom_weight(i) = prod(ns(doms{i}));
end

% prime each local kernel with evidence (if any)
local_kernel = cell(1, ndoms);
for i=1:length(engine.fgraph.kernels_of_type)
  u = engine.fgraph.kernels_of_type{i};
  local_kernel(u) = kernel_to_dpots(engine.fgraph.kernels{i}, evidence, engine.fgraph.domains_of_type{i});
end

% initialise all msgs to 1s; msg{i,j} is the message from domain i to domain j,
% defined over their separator set
msg = cell(ndoms, ndoms);
for i=1:ndoms
  nbrs = engine.fgraph.nbrs{i};
  for j=nbrs(:)'
    dom = engine.fgraph.sepset{i,j};
    msg{i,j} = dpot(dom, ns(dom));
  end
end

prod_of_msg = cell(1, ndoms);
bel = cell(1, ndoms);

converged = 0;
iter = 1;
while ~converged && (iter <= engine.max_iter)

  % each domain multiplies all its incoming msgs
  for i=1:ndoms
    prod_of_msg{i} = dpot(doms{i}, ns(doms{i}));
    nbrs = engine.fgraph.nbrs{i};
    for j=nbrs(:)'
      prod_of_msg{i} = multiply_by_pot(prod_of_msg{i}, msg{j,i});
    end
  end

  % each domain computes its local belief
  old_bel = bel;
  for i=1:ndoms
    bel{i} = normalize_pot(multiply_pots(prod_of_msg{i}, local_kernel{i}));
  end

  % converged? On the first pass there is no previous belief to compare with.
  if iter==1
    converged = 0;
  else
    converged = 1;
    for i=1:ndoms
      belT = get_params(bel{i}, 'table');
      old_belT = get_params(old_bel{i}, 'table');
      if ~approxeq(belT, old_belT, engine.tol)
        converged = 0;
        break;
      end
    end
  end

  if ~converged
    % snapshot: prod_of_msg was built from these msgs, so dividing by
    % old_msg{j,i} below removes exactly j's contribution
    old_msg = msg;
    % each domain sends a msg to each of its neighbours
    for i=1:ndoms
      nbrs = engine.fgraph.nbrs{i};
      for j=nbrs(:)'
        % multiply all incoming msgs except the one from j
        temp = prod_of_msg{i};
        temp = divide_by_pot(temp, old_msg{j,i});
        % send msg from i to j: include the local kernel, project onto the sepset
        temp = multiply_by_pot(temp, local_kernel{i});
        msg{i,j} = normalize_pot(marginalize_pot(temp, engine.fgraph.sepset{i,j}));
      end
    end
  end

  % terminated with ';' so the counter is not echoed every iteration
  iter = iter + 1;
end

% beliefs are stored as potentials; get_params(bel{i}, 'table') extracts the raw table
engine.marginal_domains = bel;