Mercurial > hg > camir-aes2014
diff toolboxes/FullBNT-1.0.7/bnt/inference/static/@pearl_inf_engine/pearl_inf_engine.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/toolboxes/FullBNT-1.0.7/bnt/inference/static/@pearl_inf_engine/pearl_inf_engine.m Tue Feb 10 15:05:51 2015 +0000 @@ -0,0 +1,158 @@ +function engine = pearl_inf_engine(bnet, varargin) +% PEARL_INF_ENGINE Pearl's algorithm (belief propagation) +% engine = pearl_inf_engine(bnet, ...) +% +% If the graph has no loops (undirected cycles), you should use the tree protocol, +% and the results will be exact. +% Otherwise, you should use the parallel protocol, and the results may be approximate. +% +% Optional arguments [default in brackets] +% 'protocol' - tree or parallel ['parallel'] +% +% Optional arguments for the loopy case +% 'max_iter' - specifies the max num. iterations to perform [2*num nodes] +% 'tol' - convergence criterion on messages [1e-3] +% 'momentum' - msg = (m*old + (1-m)*new). [m=0] +% 'filename' - msgs will be printed to this file, so you can assess convergence while it runs [[]] +% 'storebel' - 1 means save engine.bel{n,t} for every iteration t and hidden node n [0] +% +% If there are discrete and cts nodes, we assume all the discretes are observed. In this +% case, you must use the parallel protocol, and the evidence pattern must be fixed. + + +N = length(bnet.dag); +protocol = 'parallel'; +max_iter = 2*N; +% We use N+2 for the following reason: +% In N iterations, we get the exact answer for a tree. +% In the N+1st iteration, we notice that the results are the same as before, and terminate. +% In loopy_converged, we see that N+1 < max = N+2, and declare convergence. +tol = 1e-3; +momentum = 0; +filename = []; +storebel = 0; + +args = varargin; +for i=1:2:length(args) + switch args{i}, + case 'protocol', protocol = args{i+1}; + case 'max_iter', max_iter = args{i+1}; + case 'tol', tol = args{i+1}; + case 'momentum', momentum = args{i+1}; + case 'filename', filename = args{i+1}; + case 'storebel', storebel = args{i+1}; + end +end + +engine.filename = filename; +engine.storebel = storebel; +engine.bel = []; + +if strcmp(protocol, 'tree') + % We first send messages up to the root (pivot node), and then back towards the leaves. + % If the bnet is a singly connected graph (no loops), choosing a root induces a directed tree. + % Peot and Shachter discuss ways to pick the root so as to minimize the work, + % taking into account which nodes have changed. + % For simplicity, we always pick the root to be the last node in the graph. + % This means the first pass is equivalent to going forward in time in a DBN. + + engine.root = N; + [engine.adj_mat, engine.preorder, engine.postorder, loopy] = ... + mk_rooted_tree(bnet.dag, engine.root); + % engine.adj_mat might have different edge orientations from bnet.dag + if loopy + error('can only apply tree protocol to loop-less graphs') + end +else + engine.root = []; + engine.adj_mat = []; + engine.preorder = []; + engine.postorder = []; +end + +engine.niter = []; +engine.protocol = protocol; +engine.max_iter = max_iter; +engine.tol = tol; +engine.momentum = momentum; +engine.maximize = []; + +%onodes = find(~isemptycell(evidence)); +onodes = bnet.observed; +engine.msg_type = determine_pot_type(bnet, onodes, 1:N); % needed also by marginal_nodes +if strcmp(engine.msg_type, 'cg') + error('messages must be discrete or Gaussian') +end +[engine.msg_dag, disconnected_nodes] = mk_msg_dag(bnet, engine.msg_type, onodes); +engine.disconnected_nodes_bitv = zeros(1,N); +engine.disconnected_nodes_bitv(disconnected_nodes) = 1; + + +% this is where we store stuff between enter_evidence and marginal_nodes +engine.marginal = cell(1,N); +engine.evidence = []; +engine.msg = []; + +[engine.parent_index, engine.child_index] = mk_loopy_msg_indices(engine.msg_dag); + +engine = class(engine, 'pearl_inf_engine', inf_engine(bnet)); + + +%%%%%%%%% + +function [dag, disconnected_nodes] = mk_msg_dag(bnet, msg_type, onodes) + +% If we are using Gaussian msgs, all discrete nodes must be observed; +% they are then disconnected from the graph, so we don't try to send +% msgs to/from them: their observed value simply serves to index into +% the right set of parameters for the Gaussian nodes (which use CPD.ps +% instead of parents(dag), and hence are unaffected by this "surgery"). + +disconnected_nodes = []; +switch msg_type + case 'd', dag = bnet.dag; + case 'g', + disconnected_nodes = bnet.dnodes; + dag = bnet.dag; + for i=disconnected_nodes(:)' + ps = parents(bnet.dag, i); + cs = children(bnet.dag, i); + if ~isempty(ps), dag(ps, i) = 0; end + if ~isempty(cs), dag(i, cs) = 0; end + end +end + + +%%%%%%%%%% +function [parent_index, child_index] = mk_loopy_msg_indices(dag) +% MK_LOOPY_MSG_INDICES Compute "port numbers" for message passing +% [parent_index, child_index] = mk_loopy_msg_indices(bnet) +% +% child_index{n}(c) = i means c is n's i'th child, i.e., i = find_equiv_posns(c, children(n)) +% child_index{n}(c) = 0 means c is not a child of n. +% parent_index{n}{p} is defined similarly. +% We need to use these indices since the pi_from_parent/ lambda_from_child cell arrays +% cannot be sparse, and hence cannot be indexed by the actual number of the node. +% Instead, we use the number of the "port" on which the message arrived. + +N = length(dag); +child_index = cell(1,N); +parent_index = cell(1,N); +for n=1:N + cs = children(dag, n); + child_index{n} = sparse(1,N); + for i=1:length(cs) + c = cs(i); + child_index{n}(c) = i; + end + ps = parents(dag, n); + parent_index{n} = sparse(1,N); + for i=1:length(ps) + p = ps(i); + parent_index{n}(p) = i; + end +end + + + +