function engine = pearl_inf_engine(bnet, varargin)
% PEARL_INF_ENGINE Pearl's algorithm (belief propagation)
% engine = pearl_inf_engine(bnet, ...)
%
% If the graph has no loops (undirected cycles), you should use the tree protocol,
% and the results will be exact.
% Otherwise, you should use the parallel protocol, and the results may be approximate.
%
% Optional arguments [default in brackets]
% 'protocol' - tree or parallel ['parallel']
%
% Optional arguments for the loopy case
% 'max_iter' - specifies the max num. iterations to perform [2*num nodes]
% 'tol' - convergence criterion on messages [1e-3]
% 'momentum' - msg = (m*old + (1-m)*new). [m=0]
% 'filename' - msgs will be printed to this file, so you can assess convergence while it runs [[]]
% 'storebel' - 1 means save engine.bel{n,t} for every iteration t and hidden node n [0]
%
% Optional arguments are given as name/value pairs. An unrecognized name is
% skipped with a warning; a recognized name with no value is an error.
%
% If there are discrete and cts nodes, we assume all the discretes are observed. In this
% case, you must use the parallel protocol, and the evidence pattern must be fixed.


N = length(bnet.dag);

% Defaults, as listed in the help text above.
protocol = 'parallel';
max_iter = 2*N;
% NOTE(review): an older version of this comment justified a default of N+2:
% N iterations give the exact answer on a tree, the (N+1)st iteration notices
% that nothing changed, and loopy_converged then sees N+1 < max and declares
% convergence. The default actually used (and documented) is 2*N.
tol = 1e-3;
momentum = 0;
filename = [];
storebel = 0;

% Parse the name/value pairs.
args = varargin;
nargs = length(args);
known_names = {'protocol', 'max_iter', 'tol', 'momentum', 'filename', 'storebel'};
for i=1:2:nargs
  name = args{i};
  recognized = 0;
  if ischar(name)
    recognized = any(strcmp(name, known_names));
  end
  if ~recognized
    % Previously such names were dropped silently, which hid typos.
    if ischar(name)
      warning(['pearl_inf_engine: ignoring unrecognized option ''' name(:)' ''''])
    else
      warning('pearl_inf_engine: ignoring option name that is not a string')
    end
  else
    if i == nargs
      % Would otherwise fail below with an index-out-of-range error.
      error(['pearl_inf_engine: option ''' name ''' has no value'])
    end
    value = args{i+1};
    switch name,
     case 'protocol', protocol = value;
     case 'max_iter', max_iter = value;
     case 'tol',      tol = value;
     case 'momentum', momentum = value;
     case 'filename', filename = value;
     case 'storebel', storebel = value;
    end
  end
end

% NOTE: the order in which the fields of 'engine' are created is kept fixed,
% since the struct is turned into an object by class() at the end.
engine.filename = filename;
engine.storebel = storebel;
engine.bel = [];

if strcmp(protocol, 'tree')
  % We first send messages up to the root (pivot node), and then back towards the leaves.
  % If the bnet is a singly connected graph (no loops), choosing a root induces a directed tree.
  % Peot and Shachter discuss ways to pick the root so as to minimize the work,
  % taking into account which nodes have changed.
  % For simplicity, we always pick the root to be the last node in the graph.
  % This means the first pass is equivalent to going forward in time in a DBN.

  engine.root = N;
  [engine.adj_mat, engine.preorder, engine.postorder, loopy] = ...
      mk_rooted_tree(bnet.dag, engine.root);
  % engine.adj_mat might have different edge orientations from bnet.dag
  if loopy
    error('can only apply tree protocol to loop-less graphs')
  end
else
  % Any protocol other than 'tree' leaves the tree-specific fields empty.
  engine.root = [];
  engine.adj_mat = [];
  engine.preorder = [];
  engine.postorder = [];
end

engine.niter = [];
engine.protocol = protocol;
engine.max_iter = max_iter;
engine.tol = tol;
engine.momentum = momentum;
engine.maximize = [];

%onodes = find(~isemptycell(evidence));
onodes = bnet.observed;
engine.msg_type = determine_pot_type(bnet, onodes, 1:N); % needed also by marginal_nodes
if strcmp(engine.msg_type, 'cg')
  error('messages must be discrete or Gaussian')
end
[engine.msg_dag, disconnected_nodes] = mk_msg_dag(bnet, engine.msg_type, onodes);
% Bit vector with a 1 for every node that was cut out of the message graph.
engine.disconnected_nodes_bitv = zeros(1,N);
engine.disconnected_nodes_bitv(disconnected_nodes) = 1;


% this is where we store stuff between enter_evidence and marginal_nodes
engine.marginal = cell(1,N);
engine.evidence = [];
engine.msg = [];

[engine.parent_index, engine.child_index] = mk_loopy_msg_indices(engine.msg_dag);

engine = class(engine, 'pearl_inf_engine', inf_engine(bnet));


%%%%%%%%%

function [dag, disconnected_nodes] = mk_msg_dag(bnet, msg_type, onodes)
% MK_MSG_DAG Build the graph along which messages are passed
% [dag, disconnected_nodes] = mk_msg_dag(bnet, msg_type, onodes)
%
% msg_type = 'd': dag is bnet.dag and disconnected_nodes is [].
% msg_type = 'g': dag is bnet.dag with every edge into or out of a node in
%   bnet.dnodes removed; disconnected_nodes is bnet.dnodes.
% Any other msg_type is an error.
% onodes is accepted but not used here.
%
% If we are using Gaussian msgs, all discrete nodes must be observed;
% they are then disconnected from the graph, so we don't try to send
% msgs to/from them: their observed value simply serves to index into
% the right set of parameters for the Gaussian nodes (which use CPD.ps
% instead of parents(dag), and hence are unaffected by this "surgery").

disconnected_nodes = [];
switch msg_type
 case 'd',
  dag = bnet.dag;
 case 'g',
  disconnected_nodes = bnet.dnodes;
  dag = bnet.dag;
  for i=disconnected_nodes(:)'
    ps = parents(bnet.dag, i);
    cs = children(bnet.dag, i);
    if ~isempty(ps), dag(ps, i) = 0; end
    if ~isempty(cs), dag(i, cs) = 0; end
  end
 otherwise,
  % Without this branch 'dag' would be left unassigned.
  error('mk_msg_dag: msg_type must be ''d'' or ''g''')
end


%%%%%%%%%%
function [parent_index, child_index] = mk_loopy_msg_indices(dag)
% MK_LOOPY_MSG_INDICES Compute "port numbers" for message passing
% [parent_index, child_index] = mk_loopy_msg_indices(dag)
%
% child_index{n}(c) = i means c is n's i'th child, i.e., i = find_equiv_posns(c, children(n))
% child_index{n}(c) = 0 means c is not a child of n.
% parent_index{n}(p) is defined similarly.
% We need to use these indices since the pi_from_parent/ lambda_from_child cell arrays
% cannot be sparse, and hence cannot be indexed by the actual number of the node.
% Instead, we use the number of the "port" on which the message arrived.

N = length(dag);
child_index = cell(1,N);
parent_index = cell(1,N);
for n=1:N
  % Port numbers of n's children, stored as a sparse 1xN row.
  cs = children(dag, n);
  child_index{n} = sparse(1,N);
  for i=1:length(cs)
    c = cs(i);
    child_index{n}(c) = i;
  end
  % Port numbers of n's parents, stored the same way.
  ps = parents(dag, n);
  parent_index{n} = sparse(1,N);
  for i=1:length(ps)
    p = ps(i);
    parent_index{n}(p) = i;
  end
end