Mercurial > hg > camir-aes2014
diff toolboxes/FullBNT-1.0.7/bnt/inference/static/@belprop_fg_inf_engine/enter_evidence.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/toolboxes/FullBNT-1.0.7/bnt/inference/static/@belprop_fg_inf_engine/enter_evidence.m Tue Feb 10 15:05:51 2015 +0000 @@ -0,0 +1,126 @@ +function [engine, ll, niter] = enter_evidence(engine, evidence, varargin) +% ENTER_EVIDENCE Propagate evidence using belief propagation +% [engine, ll, niter] = enter_evidence(engine, evidence, ...) +% +% The log-likelihood is not computed; ll = 0. +% niter contains the number of iterations used +% +% The following optional arguments can be specified in the form of name/value pairs: +% [default value in brackets] +% +% maximize - 1 means use max-product, 0 means use sum-product [0] +% +% e.g., engine = enter_evidence(engine, ev, 'maximize', 1) + +ll = 0; +maximize = 0; + +if nargin >= 3 + args = varargin; + nargs = length(args); + for i=1:2:nargs + switch args{i}, + case 'maximize', maximize = args{i+1}; + otherwise, + error(['invalid argument name ' args{i}]); + end + end +end + +verbose = 0; + +ns = engine.fgraph.node_sizes; +onodes = find(~isemptycell(evidence)); +hnodes = find(isemptycell(evidence)); +cnodes = engine.fgraph.cnodes; +pot_type = determine_pot_type(engine.fgraph, onodes); + +% prime each local kernel with evidence (if any) +nfactors = engine.fgraph.nfactors; +nvars = engine.fgraph.nvars; +factors = cell(1,nfactors); +for f=1:nfactors + K = engine.fgraph.factors{engine.fgraph.equiv_class(f)}; + factors{f} = convert_to_pot(K, pot_type, engine.fgraph.dom{f}(:), evidence); +end + +% initialise msgs +msg_var_to_fac = cell(nvars, nfactors); +for x=1:nvars + for f=engine.fgraph.dep{x} + msg_var_to_fac{x,f} = mk_initial_pot(pot_type, x, ns, cnodes, onodes); + end +end +msg_fac_to_var = cell(nfactors, nvars); +dom = cell(1, nfactors); +for f=1:nfactors + %hdom{f} = myintersect(engine.fgraph.dom{f}, hnodes); + dom{f} = engine.fgraph.dom{f}(:)'; + for x=dom{f} + msg_fac_to_var{f,x} = mk_initial_pot(pot_type, x, ns, cnodes, onodes); + %msg_fac_to_var{f,x} = marginalize_pot(factors{f}, x); + end +end + + + +converged = 0; +iter = 1; +var_prod = cell(1, nvars); +fac_prod = cell(1, nfactors); + +while ~converged & (iter <= engine.max_iter) + if verbose, fprintf('iter %d\n', iter); end + + % absorb + old_var_prod = var_prod; + for x=1:nvars + var_prod{x} = mk_initial_pot(pot_type, x, ns, cnodes, onodes); + for f=engine.fgraph.dep{x} + var_prod{x} = multiply_by_pot(var_prod{x}, msg_fac_to_var{f,x}); + end + end + for f=1:nfactors + fac_prod{f} = mk_initial_pot(pot_type, dom{f}, ns, cnodes, onodes); + for x=dom{f} + fac_prod{f} = multiply_by_pot(fac_prod{f}, msg_var_to_fac{x,f}); + end + end + + % send msgs to neighbors + old_msg_var_to_fac = msg_var_to_fac; + old_msg_fac_to_var = msg_fac_to_var; + converged = 1; + for x=1:nvars + %if verbose, disp(['var ' num2str(x) ' sending to fac ' num2str(engine.fgraph.dep{x})]); end + for f=engine.fgraph.dep{x} + temp = divide_by_pot(var_prod{x}, old_msg_fac_to_var{f,x}); + msg_var_to_fac{x,f} = normalize_pot(temp); + if ~approxeq_pot(msg_var_to_fac{x,f}, old_msg_var_to_fac{x,f}, engine.tol), converged = 0; end + end + end + for f=1:nfactors + %if verbose, disp(['fac ' num2str(f) ' sending to var ' num2str(dom{f})]); end + for x=dom{f} + temp = divide_by_pot(fac_prod{f}, old_msg_var_to_fac{x,f}); + temp2 = multiply_by_pot(factors{f}, temp); + temp3 = marginalize_pot(temp2, x, maximize); + msg_fac_to_var{f,x} = normalize_pot(temp3); + if ~approxeq_pot(msg_fac_to_var{f,x}, old_msg_fac_to_var{f,x}, engine.tol), converged = 0; end + end + end + + if iter==1 + converged = 0; + end + iter = iter + 1; +end + +niter = iter - 1; +engine.niter = niter; + +for x=1:nvars + engine.marginal_nodes{x} = normalize_pot(var_prod{x}); +end + +