annotate toolboxes/FullBNT-1.0.7/bnt/inference/static/@pearl_inf_engine/pearl_inf_engine.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
rev   line source
function engine = pearl_inf_engine(bnet, varargin)
% PEARL_INF_ENGINE Pearl's algorithm (belief propagation)
% engine = pearl_inf_engine(bnet, ...)
%
% If the graph has no loops (undirected cycles), you should use the tree protocol,
% and the results will be exact.
% Otherwise, you should use the parallel protocol, and the results may be approximate.
%
% Optional arguments [default in brackets]
% 'protocol' - tree or parallel ['parallel']
%
% Optional arguments for the loopy case
% 'max_iter' - specifies the max num. iterations to perform [2*num nodes]
% 'tol' - convergence criterion on messages [1e-3]
% 'momentum' - msg = (m*old + (1-m)*new). [m=0]
% 'filename' - msgs will be printed to this file, so you can assess convergence while it runs [[]]
% 'storebel' - 1 means save engine.bel{n,t} for every iteration t and hidden node n [0]
%
% If there are discrete and cts nodes, we assume all the discretes are observed. In this
% case, you must use the parallel protocol, and the evidence pattern must be fixed.
%
% Option names that are not recognized are ignored, with a warning.
% A recognized option name that is not followed by a value raises an error.

N = length(bnet.dag);

% ---- default option values ----
protocol = 'parallel';
% Default iteration cap is 2*N (as documented above).
% NOTE: the rationale below was written for a cap of N+2; 2*N is at least
% N+2 whenever N >= 2, so the same argument carries over in that case:
%   In N iterations, we get the exact answer for a tree.
%   In the N+1st iteration, we notice that the results are the same as before, and terminate.
%   In loopy_converged, we see that N+1 < max = N+2, and declare convergence.
max_iter = 2*N;
tol = 1e-3;
momentum = 0;
filename = [];
storebel = 0;

% ---- parse name/value pairs ----
known_opts = {'protocol', 'max_iter', 'tol', 'momentum', 'filename', 'storebel'};
args = varargin;
nargs = length(args);
for i = 1:2:nargs
    name = args{i};
    if ~ischar(name) || ~any(strcmp(name, known_opts))
        % Previously such arguments were dropped silently, which hid typos
        % such as 'maxiter'. Keep ignoring them, but say so.
        if ischar(name)
            warning('pearl_inf_engine:unknownOption', ...
                'ignoring unrecognized option ''%s''', name);
        else
            warning('pearl_inf_engine:unknownOption', ...
                'ignoring non-string option name at optional argument %d', i);
        end
        continue
    end
    if i == nargs
        % A trailing option name with no value used to fail with an
        % index-out-of-range error on args{i+1}; report the real cause.
        error('pearl_inf_engine:missingValue', ...
            'option ''%s'' must be followed by a value', name);
    end
    value = args{i+1};
    switch name
        case 'protocol', protocol = value;
        case 'max_iter', max_iter = value;
        case 'tol',      tol = value;
        case 'momentum', momentum = value;
        case 'filename', filename = value;
        case 'storebel', storebel = value;
    end
end

% NOTE: the order in which the fields of 'engine' are created below is kept
% exactly as in the original constructor, since the struct is turned into an
% object by class() at the end.
engine.filename = filename;
engine.storebel = storebel;
engine.bel = [];

if strcmp(protocol, 'tree')
    % We first send messages up to the root (pivot node), and then back towards the leaves.
    % If the bnet is a singly connected graph (no loops), choosing a root induces a directed tree.
    % Peot and Shachter discuss ways to pick the root so as to minimize the work,
    % taking into account which nodes have changed.
    % For simplicity, we always pick the root to be the last node in the graph.
    % This means the first pass is equivalent to going forward in time in a DBN.

    engine.root = N;
    [engine.adj_mat, engine.preorder, engine.postorder, loopy] = ...
        mk_rooted_tree(bnet.dag, engine.root);
    % engine.adj_mat might have different edge orientations from bnet.dag
    if loopy
        error('can only apply tree protocol to loop-less graphs')
    end
else
    % Any protocol other than 'tree' gets empty tree-specific fields.
    engine.root = [];
    engine.adj_mat = [];
    engine.preorder = [];
    engine.postorder = [];
end

engine.niter = [];
engine.protocol = protocol;
engine.max_iter = max_iter;
engine.tol = tol;
engine.momentum = momentum;
engine.maximize = [];

% The observed set is taken from the network itself (bnet.observed), not from
% a particular evidence pattern.
onodes = bnet.observed;
engine.msg_type = determine_pot_type(bnet, onodes, 1:N); % needed also by marginal_nodes
if strcmp(engine.msg_type, 'cg')
    error('messages must be discrete or Gaussian')
end
[engine.msg_dag, disconnected_nodes] = mk_msg_dag(bnet, engine.msg_type, onodes);
% Bit vector over nodes: 1 marks a node that mk_msg_dag reported as disconnected.
engine.disconnected_nodes_bitv = zeros(1,N);
engine.disconnected_nodes_bitv(disconnected_nodes) = 1;

% this is where we store stuff between enter_evidence and marginal_nodes
engine.marginal = cell(1,N);
engine.evidence = [];
engine.msg = [];

[engine.parent_index, engine.child_index] = mk_loopy_msg_indices(engine.msg_dag);

engine = class(engine, 'pearl_inf_engine', inf_engine(bnet));
wolffd@0 100
wolffd@0 101 %%%%%%%%%
wolffd@0 102
function [dag, disconnected_nodes] = mk_msg_dag(bnet, msg_type, onodes)
% MK_MSG_DAG Build the graph over which messages are actually passed
% [dag, disconnected_nodes] = mk_msg_dag(bnet, msg_type, onodes)
%
% msg_type = 'd': dag is bnet.dag and disconnected_nodes is [].
% msg_type = 'g': dag is bnet.dag with every edge into or out of a node in
%                 bnet.dnodes removed; disconnected_nodes is bnet.dnodes.
% Any other msg_type raises an error.
%
% onodes is accepted for interface compatibility but is not used here.
%
% If we are using Gaussian msgs, all discrete nodes must be observed;
% they are then disconnected from the graph, so we don't try to send
% msgs to/from them: their observed value simply serves to index into
% the right set of parameters for the Gaussian nodes (which use CPD.ps
% instead of parents(dag), and hence are unaffected by this "surgery").

disconnected_nodes = [];
switch msg_type
    case 'd'
        dag = bnet.dag;
    case 'g'
        disconnected_nodes = bnet.dnodes;
        dag = bnet.dag;
        % Cutting all incoming edges of a node clears its column, and cutting
        % all outgoing edges clears its row; do this for all nodes at once.
        dag(:, disconnected_nodes) = 0;
        dag(disconnected_nodes, :) = 0;
    otherwise
        % Without this branch 'dag' would be left unassigned and the caller
        % would see an unhelpful "output argument not assigned" failure.
        error('pearl_inf_engine:badMsgType', ...
            'unsupported message type ''%s'' (expected ''d'' or ''g'')', msg_type);
end
wolffd@0 124
wolffd@0 125
wolffd@0 126 %%%%%%%%%%
function [parent_index, child_index] = mk_loopy_msg_indices(dag)
% MK_LOOPY_MSG_INDICES Compute "port numbers" for message passing
% [parent_index, child_index] = mk_loopy_msg_indices(dag)
%
% Both outputs are 1-by-N cell arrays (N = length(dag)); each cell holds a
% 1-by-N sparse row vector:
%   child_index{n}(c) = i  means c is the i'th entry of children(dag, n)
%   child_index{n}(c) = 0  means c is not a child of n
%   parent_index{n}(p)     is defined the same way using parents(dag, n)
%
% We need to use these indices since the pi_from_parent / lambda_from_child cell arrays
% cannot be sparse, and hence cannot be indexed by the actual number of the node.
% Instead, we use the number of the "port" on which the message arrived.

num_nodes = length(dag);
child_index = cell(1, num_nodes);
parent_index = cell(1, num_nodes);

for node = 1:num_nodes
    % Position of every child of 'node' within its child list.
    kids = children(dag, node);
    kid_ports = sparse(1, num_nodes);
    if ~isempty(kids)
        kid_ports(kids) = 1:length(kids);
    end
    child_index{node} = kid_ports;

    % Position of every parent of 'node' within its parent list.
    folks = parents(dag, node);
    folk_ports = sparse(1, num_nodes);
    if ~isempty(folks)
        folk_ports(folks) = 1:length(folks);
    end
    parent_index{node} = folk_ports;
end
wolffd@0 155
wolffd@0 156
wolffd@0 157
wolffd@0 158