Mercurial > hg > camir-aes2014
comparison toolboxes/FullBNT-1.0.7/bnt/inference/static/@belprop_inf_engine/Old/enter_evidence1.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:e9a9cd732c1e |
---|---|
function engine = enter_evidence(engine, evidence)
% ENTER_EVIDENCE Add evidence and run iterative belief propagation over the factor-graph domains.
% engine = enter_evidence(engine, evidence)
%
% evidence{i} is empty if node i is hidden; a non-empty entry marks node i as observed.
%
% On return:
%   engine.dom_weight(i)       = product of the evidence-adjusted sizes of the nodes in domain i
%   engine.marginal_domains{i} = normalized belief potential for domain i
%
% Messages are updated in parallel: every message of a sweep is computed from the
% previous sweep's messages. Iteration stops when no domain's belief table differs
% from the previous sweep (as judged by approxeq with engine.tol), or after
% engine.max_iter sweeps, whichever comes first.

doms = engine.fgraph.doms;
ndoms = length(doms);
ns = engine.fgraph.node_sizes;
obs = find(~isemptycell(evidence));
cobs = myintersect(obs, engine.fgraph.cnodes);
dobs = myintersect(obs, engine.fgraph.dnodes);
% Observed nodes no longer contribute free dimensions: observed cnodes get size 0,
% observed dnodes collapse to a single state.
ns(cobs) = 0;
ns(dobs) = 1;

% recompute the weight of each domain now that we know what nodes are observed
for i=1:ndoms
  engine.dom_weight(i) = prod(ns(engine.fgraph.doms{i}));
end

% prime each local kernel with evidence (if any)
local_kernel = cell(1, ndoms);
for i=1:length(engine.fgraph.kernels_of_type)
  u = engine.fgraph.kernels_of_type{i};
  local_kernel(u) = kernel_to_dpots(engine.fgraph.kernels{i}, evidence, engine.fgraph.domains_of_type{i});
end

% initialise all msgs to 1s
msg = cell(ndoms, ndoms);
for i=1:ndoms
  nbrs = engine.fgraph.nbrs{i};
  for j=nbrs(:)'
    dom = engine.fgraph.sepset{i,j};
    msg{i,j} = dpot(dom, ns(dom));
  end
end

prod_of_msg = cell(1, ndoms);
bel = cell(1, ndoms);

converged = 0;
iter = 1;
while ~converged && (iter <= engine.max_iter)

  % each node multiplies all its incoming msgs
  % (starting from dpot(doms{i}, ...), presumably a unit potential like the msg initialisation)
  for i=1:ndoms
    prod_of_msg{i} = dpot(doms{i}, ns(doms{i}));
    nbrs = engine.fgraph.nbrs{i};
    for j=nbrs(:)'
      prod_of_msg{i} = multiply_by_pot(prod_of_msg{i}, msg{j,i});
    end
  end

  % each node computes its local belief
  old_bel = bel;
  for i=1:ndoms
    bel{i} = normalize_pot(multiply_pots(prod_of_msg{i}, local_kernel{i}));
  end

  % converged? (needs two sweeps before beliefs can be compared)
  if iter==1
    converged = 0;
  else
    converged = 1;
    for i=1:ndoms
      belT = get_params(bel{i}, 'table');
      old_belT = get_params(old_bel{i}, 'table');
      if ~approxeq(belT, old_belT, engine.tol)
        converged = 0;
        break;
      end
    end
  end

  if ~converged
    % Snapshot the messages so every update in this sweep uses the previous sweep's values.
    old_msg = msg;
    % each node sends a msg to each of its neighbors
    for i=1:ndoms
      nbrs = engine.fgraph.nbrs{i};
      for j=nbrs(:)'
        % multiply all incoming msgs except from j: divide j's contribution back out
        % of the full product instead of re-multiplying the remaining neighbours.
        temp = prod_of_msg{i};
        temp = divide_by_pot(temp, old_msg{j,i});
        % send msg from i to j
        temp = multiply_by_pot(temp, local_kernel{i});
        msg{i,j} = normalize_pot(marginalize_pot(temp, engine.fgraph.sepset{i,j}));
      end
    end
  end

  % Semicolon suppresses echoing the counter to the console on every sweep.
  iter = iter + 1;
end

engine.marginal_domains = bel;
% Disabled alternative: store the raw belief tables instead of potential objects.
%for i=1:ndoms
%engine.marginal_domains{i} = get_params(bel{i}, 'table');
%end