Mercurial > hg > camir-aes2014
comparison toolboxes/FullBNT-1.0.7/bnt/inference/static/@belprop_fg_inf_engine/enter_evidence.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:e9a9cd732c1e |
---|---|
function [engine, ll, niter] = enter_evidence(engine, evidence, varargin)
% ENTER_EVIDENCE Propagate evidence using belief propagation
% [engine, ll, niter] = enter_evidence(engine, evidence, ...)
%
% Inputs:
%   engine   - inference engine; this function reads engine.fgraph
%              (node_sizes, cnodes, nfactors, nvars, factors, equiv_class,
%              dom, dep), engine.max_iter and engine.tol.
%   evidence - cell array; entries that are non-empty (per isemptycell)
%              are passed on as the observed nodes.
%
% Outputs:
%   engine - the engine with engine.niter and engine.marginal_nodes{x}
%            (one normalized potential per variable) filled in.
%   ll     - the log-likelihood is not computed; ll = 0.
%   niter  - the number of iterations used.
%
% The following optional arguments can be specified in the form of name/value pairs:
% [default value in brackets]
%
% maximize - 1 means use max-product, 0 means use sum-product [0]
%
% e.g., engine = enter_evidence(engine, ev, 'maximize', 1)
%
% Iteration stops when every message is approximately equal (approxeq_pot
% with tolerance engine.tol) to its value from the previous iteration, or
% after engine.max_iter iterations. The first iteration is never treated
% as converged.

ll = 0;
maximize = 0;

% Parse optional name/value pairs.
nargs = length(varargin);
for i=1:2:nargs
  name = varargin{i};
  switch name
    case 'maximize'
      % A trailing name without a value would otherwise fail with an
      % opaque index-out-of-bounds error on varargin{i+1}.
      if i == nargs
        error('missing value for argument maximize');
      end
      maximize = varargin{i+1};
    otherwise
      error(['invalid argument name ' name]);
  end
end

verbose = 0; % set to 1 to print the iteration counter

ns = engine.fgraph.node_sizes;
onodes = find(~isemptycell(evidence));
cnodes = engine.fgraph.cnodes;
pot_type = determine_pot_type(engine.fgraph, onodes);

% prime each local kernel with evidence (if any)
nfactors = engine.fgraph.nfactors;
nvars = engine.fgraph.nvars;
factors = cell(1,nfactors);
for f=1:nfactors
  % factors are shared through equivalence classes: equiv_class(f) selects
  % the stored kernel used for factor f
  K = engine.fgraph.factors{engine.fgraph.equiv_class(f)};
  factors{f} = convert_to_pot(K, pot_type, engine.fgraph.dom{f}(:), evidence);
end

% initialise msgs
% msg_var_to_fac{x,f} is only populated for the factors f in fgraph.dep{x}
msg_var_to_fac = cell(nvars, nfactors);
for x=1:nvars
  for f=engine.fgraph.dep{x}
    msg_var_to_fac{x,f} = mk_initial_pot(pot_type, x, ns, cnodes, onodes);
  end
end
% msg_fac_to_var{f,x} is only populated for the variables x in fgraph.dom{f}
msg_fac_to_var = cell(nfactors, nvars);
dom = cell(1, nfactors);
for f=1:nfactors
  % force a row vector so that "for x=dom{f}" visits one variable at a time
  dom{f} = engine.fgraph.dom{f}(:)';
  for x=dom{f}
    msg_fac_to_var{f,x} = mk_initial_pot(pot_type, x, ns, cnodes, onodes);
    %msg_fac_to_var{f,x} = marginalize_pot(factors{f}, x);
  end
end

converged = 0;
iter = 1;
var_prod = cell(1, nvars);
fac_prod = cell(1, nfactors);

% NOTE(review): if engine.max_iter < 1 this loop never runs and var_prod
% stays a cell array of empties when the marginals are computed below.
% Assumes engine.max_iter is a scalar (required by &&) - TODO confirm.
while ~converged && (iter <= engine.max_iter)
  if verbose, fprintf('iter %d\n', iter); end

  % absorb: multiply together all messages arriving at each node
  for x=1:nvars
    var_prod{x} = mk_initial_pot(pot_type, x, ns, cnodes, onodes);
    for f=engine.fgraph.dep{x}
      var_prod{x} = multiply_by_pot(var_prod{x}, msg_fac_to_var{f,x});
    end
  end
  for f=1:nfactors
    fac_prod{f} = mk_initial_pot(pot_type, dom{f}, ns, cnodes, onodes);
    for x=dom{f}
      fac_prod{f} = multiply_by_pot(fac_prod{f}, msg_var_to_fac{x,f});
    end
  end

  % send msgs to neighbors
  % Snapshot both message sets first: the products above were built from
  % these values, so the divisions below must use them as well, even after
  % msg_var_to_fac has been overwritten in the first loop.
  old_msg_var_to_fac = msg_var_to_fac;
  old_msg_fac_to_var = msg_fac_to_var;
  converged = 1;
  for x=1:nvars
    %if verbose, disp(['var ' num2str(x) ' sending to fac ' num2str(engine.fgraph.dep{x})]); end
    for f=engine.fgraph.dep{x}
      % divide out the message that came from f itself
      temp = divide_by_pot(var_prod{x}, old_msg_fac_to_var{f,x});
      msg_var_to_fac{x,f} = normalize_pot(temp);
      if ~approxeq_pot(msg_var_to_fac{x,f}, old_msg_var_to_fac{x,f}, engine.tol)
        converged = 0;
      end
    end
  end
  for f=1:nfactors
    %if verbose, disp(['fac ' num2str(f) ' sending to var ' num2str(dom{f})]); end
    for x=dom{f}
      % divide out the message that came from x itself, combine with the
      % local factor, then marginalize (sum or max, per 'maximize') onto x
      temp = divide_by_pot(fac_prod{f}, old_msg_var_to_fac{x,f});
      temp2 = multiply_by_pot(factors{f}, temp);
      temp3 = marginalize_pot(temp2, x, maximize);
      msg_fac_to_var{f,x} = normalize_pot(temp3);
      if ~approxeq_pot(msg_fac_to_var{f,x}, old_msg_fac_to_var{f,x}, engine.tol)
        converged = 0;
      end
    end
  end

  % the first pass only compares against the initial messages, so never
  % accept it as converged
  if iter==1
    converged = 0;
  end
  iter = iter + 1;
end

niter = iter - 1;
engine.niter = niter;

% The marginals come from var_prod as computed in the absorb step of the
% last executed iteration (i.e. before that iteration's message update).
for x=1:nvars
  engine.marginal_nodes{x} = normalize_pot(var_prod{x});
end
125 | |
126 |