wolffd@0
|
1 function [engine, ll, niter] = enter_evidence(engine, evidence, varargin)
|
wolffd@0
|
2 % ENTER_EVIDENCE Propagate evidence using belief propagation
|
wolffd@0
|
3 % [engine, ll, niter] = enter_evidence(engine, evidence, ...)
|
wolffd@0
|
4 %
|
wolffd@0
|
5 % The log-likelihood is not computed; ll = 0.
|
wolffd@0
|
6 % niter contains the number of iterations used
|
wolffd@0
|
7 %
|
wolffd@0
|
8 % The following optional arguments can be specified in the form of name/value pairs:
|
wolffd@0
|
9 % [default value in brackets]
|
wolffd@0
|
10 %
|
wolffd@0
|
11 % maximize - 1 means use max-product, 0 means use sum-product [0]
|
wolffd@0
|
12 %
|
wolffd@0
|
13 % e.g., engine = enter_evidence(engine, ev, 'maximize', 1)
|
wolffd@0
|
14
|
wolffd@0
|
15 ll = 0;
|
wolffd@0
|
16 maximize = 0;
|
wolffd@0
|
17
|
wolffd@0
|
18 if nargin >= 3
|
wolffd@0
|
19 args = varargin;
|
wolffd@0
|
20 nargs = length(args);
|
wolffd@0
|
21 for i=1:2:nargs
|
wolffd@0
|
22 switch args{i},
|
wolffd@0
|
23 case 'maximize', maximize = args{i+1};
|
wolffd@0
|
24 otherwise,
|
wolffd@0
|
25 error(['invalid argument name ' args{i}]);
|
wolffd@0
|
26 end
|
wolffd@0
|
27 end
|
wolffd@0
|
28 end
|
wolffd@0
|
29
|
wolffd@0
|
30 verbose = 0;
|
wolffd@0
|
31
|
wolffd@0
|
32 ns = engine.fgraph.node_sizes;
|
wolffd@0
|
33 onodes = find(~isemptycell(evidence));
|
wolffd@0
|
34 hnodes = find(isemptycell(evidence));
|
wolffd@0
|
35 cnodes = engine.fgraph.cnodes;
|
wolffd@0
|
36 pot_type = determine_pot_type(engine.fgraph, onodes);
|
wolffd@0
|
37
|
wolffd@0
|
38 % prime each local kernel with evidence (if any)
|
wolffd@0
|
39 nfactors = engine.fgraph.nfactors;
|
wolffd@0
|
40 nvars = engine.fgraph.nvars;
|
wolffd@0
|
41 factors = cell(1,nfactors);
|
wolffd@0
|
42 for f=1:nfactors
|
wolffd@0
|
43 K = engine.fgraph.factors{engine.fgraph.equiv_class(f)};
|
wolffd@0
|
44 factors{f} = convert_to_pot(K, pot_type, engine.fgraph.dom{f}(:), evidence);
|
wolffd@0
|
45 end
|
wolffd@0
|
46
|
wolffd@0
|
47 % initialise msgs
|
wolffd@0
|
48 msg_var_to_fac = cell(nvars, nfactors);
|
wolffd@0
|
49 for x=1:nvars
|
wolffd@0
|
50 for f=engine.fgraph.dep{x}
|
wolffd@0
|
51 msg_var_to_fac{x,f} = mk_initial_pot(pot_type, x, ns, cnodes, onodes);
|
wolffd@0
|
52 end
|
wolffd@0
|
53 end
|
wolffd@0
|
54 msg_fac_to_var = cell(nfactors, nvars);
|
wolffd@0
|
55 dom = cell(1, nfactors);
|
wolffd@0
|
56 for f=1:nfactors
|
wolffd@0
|
57 %hdom{f} = myintersect(engine.fgraph.dom{f}, hnodes);
|
wolffd@0
|
58 dom{f} = engine.fgraph.dom{f}(:)';
|
wolffd@0
|
59 for x=dom{f}
|
wolffd@0
|
60 msg_fac_to_var{f,x} = mk_initial_pot(pot_type, x, ns, cnodes, onodes);
|
wolffd@0
|
61 %msg_fac_to_var{f,x} = marginalize_pot(factors{f}, x);
|
wolffd@0
|
62 end
|
wolffd@0
|
63 end
|
wolffd@0
|
64
|
wolffd@0
|
65
|
wolffd@0
|
66
|
wolffd@0
|
67 converged = 0;
|
wolffd@0
|
68 iter = 1;
|
wolffd@0
|
69 var_prod = cell(1, nvars);
|
wolffd@0
|
70 fac_prod = cell(1, nfactors);
|
wolffd@0
|
71
|
wolffd@0
|
72 while ~converged & (iter <= engine.max_iter)
|
wolffd@0
|
73 if verbose, fprintf('iter %d\n', iter); end
|
wolffd@0
|
74
|
wolffd@0
|
75 % absorb
|
wolffd@0
|
76 old_var_prod = var_prod;
|
wolffd@0
|
77 for x=1:nvars
|
wolffd@0
|
78 var_prod{x} = mk_initial_pot(pot_type, x, ns, cnodes, onodes);
|
wolffd@0
|
79 for f=engine.fgraph.dep{x}
|
wolffd@0
|
80 var_prod{x} = multiply_by_pot(var_prod{x}, msg_fac_to_var{f,x});
|
wolffd@0
|
81 end
|
wolffd@0
|
82 end
|
wolffd@0
|
83 for f=1:nfactors
|
wolffd@0
|
84 fac_prod{f} = mk_initial_pot(pot_type, dom{f}, ns, cnodes, onodes);
|
wolffd@0
|
85 for x=dom{f}
|
wolffd@0
|
86 fac_prod{f} = multiply_by_pot(fac_prod{f}, msg_var_to_fac{x,f});
|
wolffd@0
|
87 end
|
wolffd@0
|
88 end
|
wolffd@0
|
89
|
wolffd@0
|
90 % send msgs to neighbors
|
wolffd@0
|
91 old_msg_var_to_fac = msg_var_to_fac;
|
wolffd@0
|
92 old_msg_fac_to_var = msg_fac_to_var;
|
wolffd@0
|
93 converged = 1;
|
wolffd@0
|
94 for x=1:nvars
|
wolffd@0
|
95 %if verbose, disp(['var ' num2str(x) ' sending to fac ' num2str(engine.fgraph.dep{x})]); end
|
wolffd@0
|
96 for f=engine.fgraph.dep{x}
|
wolffd@0
|
97 temp = divide_by_pot(var_prod{x}, old_msg_fac_to_var{f,x});
|
wolffd@0
|
98 msg_var_to_fac{x,f} = normalize_pot(temp);
|
wolffd@0
|
99 if ~approxeq_pot(msg_var_to_fac{x,f}, old_msg_var_to_fac{x,f}, engine.tol), converged = 0; end
|
wolffd@0
|
100 end
|
wolffd@0
|
101 end
|
wolffd@0
|
102 for f=1:nfactors
|
wolffd@0
|
103 %if verbose, disp(['fac ' num2str(f) ' sending to var ' num2str(dom{f})]); end
|
wolffd@0
|
104 for x=dom{f}
|
wolffd@0
|
105 temp = divide_by_pot(fac_prod{f}, old_msg_var_to_fac{x,f});
|
wolffd@0
|
106 temp2 = multiply_by_pot(factors{f}, temp);
|
wolffd@0
|
107 temp3 = marginalize_pot(temp2, x, maximize);
|
wolffd@0
|
108 msg_fac_to_var{f,x} = normalize_pot(temp3);
|
wolffd@0
|
109 if ~approxeq_pot(msg_fac_to_var{f,x}, old_msg_fac_to_var{f,x}, engine.tol), converged = 0; end
|
wolffd@0
|
110 end
|
wolffd@0
|
111 end
|
wolffd@0
|
112
|
wolffd@0
|
113 if iter==1
|
wolffd@0
|
114 converged = 0;
|
wolffd@0
|
115 end
|
wolffd@0
|
116 iter = iter + 1;
|
wolffd@0
|
117 end
|
wolffd@0
|
118
|
wolffd@0
|
119 niter = iter - 1;
|
wolffd@0
|
120 engine.niter = niter;
|
wolffd@0
|
121
|
wolffd@0
|
122 for x=1:nvars
|
wolffd@0
|
123 engine.marginal_nodes{x} = normalize_pot(var_prod{x});
|
wolffd@0
|
124 end
|
wolffd@0
|
125
|
wolffd@0
|
126
|