annotate toolboxes/FullBNT-1.0.7/bnt/inference/static/@belprop_fg_inf_engine/enter_evidence.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
rev   line source
function [engine, ll, niter] = enter_evidence(engine, evidence, varargin)
% ENTER_EVIDENCE Propagate evidence using belief propagation on a factor graph
% [engine, ll, niter] = enter_evidence(engine, evidence, ...)
%
% Inputs:
%   engine   - belprop_fg inference engine; this function reads engine.fgraph,
%              engine.max_iter and engine.tol.
%   evidence - cell array indexed by variable; an empty cell means the
%              variable is hidden, a non-empty cell means it is observed.
%
% Outputs:
%   engine - the engine with engine.marginal_nodes{x} set to the normalized
%            belief of each variable x, and engine.niter set to niter.
%   ll     - the log-likelihood is not computed; ll = 0.
%   niter  - the number of iterations used.
%
% The following optional arguments can be specified in the form of name/value pairs:
% [default value in brackets]
%
% maximize - 1 means use max-product, 0 means use sum-product [0]
%
% e.g., engine = enter_evidence(engine, ev, 'maximize', 1)

ll = 0;
maximize = 0;

% Parse optional name/value pairs.
if nargin >= 3
  args = varargin;
  nargs = length(args);
  % Fail with a clear message instead of an index-out-of-bounds error when
  % a name is supplied without its value.
  if mod(nargs, 2) ~= 0
    error('optional arguments must be specified as name/value pairs');
  end
  for i = 1:2:nargs
    switch args{i},
      case 'maximize', maximize = args{i+1};
      otherwise,
        error(['invalid argument name ' args{i}]);
    end
  end
end

verbose = 0;

fg = engine.fgraph;
ns = fg.node_sizes;
onodes = find(~isemptycell(evidence));  % observed variables
cnodes = fg.cnodes;
pot_type = determine_pot_type(fg, onodes);

% Prime each local kernel with evidence (if any).
nfactors = fg.nfactors;
nvars = fg.nvars;
factors = cell(1, nfactors);
for f = 1:nfactors
  K = fg.factors{fg.equiv_class(f)};
  factors{f} = convert_to_pot(K, pot_type, fg.dom{f}(:), evidence);
end

% Initialise the messages in both directions with initial potentials.
msg_var_to_fac = cell(nvars, nfactors);
for x = 1:nvars
  for f = fg.dep{x}
    msg_var_to_fac{x,f} = mk_initial_pot(pot_type, x, ns, cnodes, onodes);
  end
end
msg_fac_to_var = cell(nfactors, nvars);
dom = cell(1, nfactors);
for f = 1:nfactors
  dom{f} = fg.dom{f}(:)';  % force a row vector so "for x=dom{f}" iterates over elements
  for x = dom{f}
    msg_fac_to_var{f,x} = mk_initial_pot(pot_type, x, ns, cnodes, onodes);
  end
end

converged = 0;
iter = 1;
var_prod = cell(1, nvars);
fac_prod = cell(1, nfactors);

while ~converged && (iter <= engine.max_iter)
  if verbose, fprintf('iter %d\n', iter); end

  % Absorb: each variable multiplies together all incoming factor messages,
  % and each factor multiplies together all incoming variable messages.
  for x = 1:nvars
    var_prod{x} = mk_initial_pot(pot_type, x, ns, cnodes, onodes);
    for f = fg.dep{x}
      var_prod{x} = multiply_by_pot(var_prod{x}, msg_fac_to_var{f,x});
    end
  end
  for f = 1:nfactors
    fac_prod{f} = mk_initial_pot(pot_type, dom{f}, ns, cnodes, onodes);
    for x = dom{f}
      fac_prod{f} = multiply_by_pot(fac_prod{f}, msg_var_to_fac{x,f});
    end
  end

  % Send messages to neighbors. Both directions are computed from the
  % previous iteration's messages (snapshotted here).
  old_msg_var_to_fac = msg_var_to_fac;
  old_msg_fac_to_var = msg_fac_to_var;
  converged = 1;
  for x = 1:nvars
    for f = fg.dep{x}
      % Divide out the recipient's own message so that the outgoing message
      % is the product of the messages from all the other factors.
      temp = divide_by_pot(var_prod{x}, old_msg_fac_to_var{f,x});
      msg_var_to_fac{x,f} = normalize_pot(temp);
      if ~approxeq_pot(msg_var_to_fac{x,f}, old_msg_var_to_fac{x,f}, engine.tol)
        converged = 0;
      end
    end
  end
  for f = 1:nfactors
    for x = dom{f}
      % Divide out the recipient's own message, multiply in the local
      % factor, then marginalize (sum or max) down to the recipient.
      temp = divide_by_pot(fac_prod{f}, old_msg_var_to_fac{x,f});
      temp2 = multiply_by_pot(factors{f}, temp);
      temp3 = marginalize_pot(temp2, x, maximize);
      msg_fac_to_var{f,x} = normalize_pot(temp3);
      if ~approxeq_pot(msg_fac_to_var{f,x}, old_msg_fac_to_var{f,x}, engine.tol)
        converged = 0;
      end
    end
  end

  % Never declare convergence on the first pass.
  % NOTE(review): presumably because the first comparison is against the
  % uninformative initial messages -- confirm.
  if iter == 1
    converged = 0;
  end
  iter = iter + 1;
end

niter = iter - 1;
engine.niter = niter;

% The reported marginals are the variable products formed in the absorb step
% of the last iteration that ran.
for x = 1:nvars
  engine.marginal_nodes{x} = normalize_pot(var_prod{x});
end
wolffd@0 125
wolffd@0 126