wolffd@0
|
1 function engine = pearl_inf_engine(bnet, varargin)
|
wolffd@0
|
2 % PEARL_INF_ENGINE Pearl's algorithm (belief propagation)
|
wolffd@0
|
3 % engine = pearl_inf_engine(bnet, ...)
|
wolffd@0
|
4 %
|
wolffd@0
|
5 % If the graph has no loops (undirected cycles), you should use the tree protocol,
|
wolffd@0
|
6 % and the results will be exact.
|
wolffd@0
|
7 % Otherwise, you should use the parallel protocol, and the results may be approximate.
|
wolffd@0
|
8 %
|
wolffd@0
|
9 % Optional arguments [default in brackets]
|
wolffd@0
|
10 % 'protocol' - tree or parallel ['parallel']
|
wolffd@0
|
11 %
|
wolffd@0
|
12 % Optional arguments for the loopy case
|
wolffd@0
|
13 % 'max_iter' - specifies the max num. iterations to perform [2*num nodes]
|
wolffd@0
|
14 % 'tol' - convergence criterion on messages [1e-3]
|
wolffd@0
|
15 % 'momentum' - msg = (m*old + (1-m)*new). [m=0]
|
wolffd@0
|
16 % 'filename' - msgs will be printed to this file, so you can assess convergence while it runs [[]]
|
wolffd@0
|
17 % 'storebel' - 1 means save engine.bel{n,t} for every iteration t and hidden node n [0]
|
wolffd@0
|
18 %
|
wolffd@0
|
19 % If there are discrete and cts nodes, we assume all the discretes are observed. In this
|
wolffd@0
|
20 % case, you must use the parallel protocol, and the evidence pattern must be fixed.
|
wolffd@0
|
21
|
wolffd@0
|
22
|
wolffd@0
|
23 N = length(bnet.dag);
|
wolffd@0
|
24 protocol = 'parallel';
|
wolffd@0
|
25 max_iter = 2*N;
|
wolffd@0
|
26 % We use N+2 for the following reason:
|
wolffd@0
|
27 % In N iterations, we get the exact answer for a tree.
|
wolffd@0
|
28 % In the N+1st iteration, we notice that the results are the same as before, and terminate.
|
wolffd@0
|
29 % In loopy_converged, we see that N+1 < max = N+2, and declare convergence.
|
wolffd@0
|
30 tol = 1e-3;
|
wolffd@0
|
31 momentum = 0;
|
wolffd@0
|
32 filename = [];
|
wolffd@0
|
33 storebel = 0;
|
wolffd@0
|
34
|
wolffd@0
|
35 args = varargin;
|
wolffd@0
|
36 for i=1:2:length(args)
|
wolffd@0
|
37 switch args{i},
|
wolffd@0
|
38 case 'protocol', protocol = args{i+1};
|
wolffd@0
|
39 case 'max_iter', max_iter = args{i+1};
|
wolffd@0
|
40 case 'tol', tol = args{i+1};
|
wolffd@0
|
41 case 'momentum', momentum = args{i+1};
|
wolffd@0
|
42 case 'filename', filename = args{i+1};
|
wolffd@0
|
43 case 'storebel', storebel = args{i+1};
|
wolffd@0
|
44 end
|
wolffd@0
|
45 end
|
wolffd@0
|
46
|
wolffd@0
|
47 engine.filename = filename;
|
wolffd@0
|
48 engine.storebel = storebel;
|
wolffd@0
|
49 engine.bel = [];
|
wolffd@0
|
50
|
wolffd@0
|
51 if strcmp(protocol, 'tree')
|
wolffd@0
|
52 % We first send messages up to the root (pivot node), and then back towards the leaves.
|
wolffd@0
|
53 % If the bnet is a singly connected graph (no loops), choosing a root induces a directed tree.
|
wolffd@0
|
54 % Peot and Shachter discuss ways to pick the root so as to minimize the work,
|
wolffd@0
|
55 % taking into account which nodes have changed.
|
wolffd@0
|
56 % For simplicity, we always pick the root to be the last node in the graph.
|
wolffd@0
|
57 % This means the first pass is equivalent to going forward in time in a DBN.
|
wolffd@0
|
58
|
wolffd@0
|
59 engine.root = N;
|
wolffd@0
|
60 [engine.adj_mat, engine.preorder, engine.postorder, loopy] = ...
|
wolffd@0
|
61 mk_rooted_tree(bnet.dag, engine.root);
|
wolffd@0
|
62 % engine.adj_mat might have different edge orientations from bnet.dag
|
wolffd@0
|
63 if loopy
|
wolffd@0
|
64 error('can only apply tree protocol to loop-less graphs')
|
wolffd@0
|
65 end
|
wolffd@0
|
66 else
|
wolffd@0
|
67 engine.root = [];
|
wolffd@0
|
68 engine.adj_mat = [];
|
wolffd@0
|
69 engine.preorder = [];
|
wolffd@0
|
70 engine.postorder = [];
|
wolffd@0
|
71 end
|
wolffd@0
|
72
|
wolffd@0
|
73 engine.niter = [];
|
wolffd@0
|
74 engine.protocol = protocol;
|
wolffd@0
|
75 engine.max_iter = max_iter;
|
wolffd@0
|
76 engine.tol = tol;
|
wolffd@0
|
77 engine.momentum = momentum;
|
wolffd@0
|
78 engine.maximize = [];
|
wolffd@0
|
79
|
wolffd@0
|
80 %onodes = find(~isemptycell(evidence));
|
wolffd@0
|
81 onodes = bnet.observed;
|
wolffd@0
|
82 engine.msg_type = determine_pot_type(bnet, onodes, 1:N); % needed also by marginal_nodes
|
wolffd@0
|
83 if strcmp(engine.msg_type, 'cg')
|
wolffd@0
|
84 error('messages must be discrete or Gaussian')
|
wolffd@0
|
85 end
|
wolffd@0
|
86 [engine.msg_dag, disconnected_nodes] = mk_msg_dag(bnet, engine.msg_type, onodes);
|
wolffd@0
|
87 engine.disconnected_nodes_bitv = zeros(1,N);
|
wolffd@0
|
88 engine.disconnected_nodes_bitv(disconnected_nodes) = 1;
|
wolffd@0
|
89
|
wolffd@0
|
90
|
wolffd@0
|
91 % this is where we store stuff between enter_evidence and marginal_nodes
|
wolffd@0
|
92 engine.marginal = cell(1,N);
|
wolffd@0
|
93 engine.evidence = [];
|
wolffd@0
|
94 engine.msg = [];
|
wolffd@0
|
95
|
wolffd@0
|
96 [engine.parent_index, engine.child_index] = mk_loopy_msg_indices(engine.msg_dag);
|
wolffd@0
|
97
|
wolffd@0
|
98 engine = class(engine, 'pearl_inf_engine', inf_engine(bnet));
|
wolffd@0
|
99
|
wolffd@0
|
100
|
wolffd@0
|
101 %%%%%%%%%
|
wolffd@0
|
102
|
wolffd@0
|
103 function [dag, disconnected_nodes] = mk_msg_dag(bnet, msg_type, onodes)
|
wolffd@0
|
104
|
wolffd@0
|
105 % If we are using Gaussian msgs, all discrete nodes must be observed;
|
wolffd@0
|
106 % they are then disconnected from the graph, so we don't try to send
|
wolffd@0
|
107 % msgs to/from them: their observed value simply serves to index into
|
wolffd@0
|
108 % the right set of parameters for the Gaussian nodes (which use CPD.ps
|
wolffd@0
|
109 % instead of parents(dag), and hence are unaffected by this "surgery").
|
wolffd@0
|
110
|
wolffd@0
|
111 disconnected_nodes = [];
|
wolffd@0
|
112 switch msg_type
|
wolffd@0
|
113 case 'd', dag = bnet.dag;
|
wolffd@0
|
114 case 'g',
|
wolffd@0
|
115 disconnected_nodes = bnet.dnodes;
|
wolffd@0
|
116 dag = bnet.dag;
|
wolffd@0
|
117 for i=disconnected_nodes(:)'
|
wolffd@0
|
118 ps = parents(bnet.dag, i);
|
wolffd@0
|
119 cs = children(bnet.dag, i);
|
wolffd@0
|
120 if ~isempty(ps), dag(ps, i) = 0; end
|
wolffd@0
|
121 if ~isempty(cs), dag(i, cs) = 0; end
|
wolffd@0
|
122 end
|
wolffd@0
|
123 end
|
wolffd@0
|
124
|
wolffd@0
|
125
|
wolffd@0
|
126 %%%%%%%%%%
|
wolffd@0
|
127 function [parent_index, child_index] = mk_loopy_msg_indices(dag)
|
wolffd@0
|
128 % MK_LOOPY_MSG_INDICES Compute "port numbers" for message passing
|
wolffd@0
|
129 % [parent_index, child_index] = mk_loopy_msg_indices(bnet)
|
wolffd@0
|
130 %
|
wolffd@0
|
131 % child_index{n}(c) = i means c is n's i'th child, i.e., i = find_equiv_posns(c, children(n))
|
wolffd@0
|
132 % child_index{n}(c) = 0 means c is not a child of n.
|
wolffd@0
|
133 % parent_index{n}{p} is defined similarly.
|
wolffd@0
|
134 % We need to use these indices since the pi_from_parent/ lambda_from_child cell arrays
|
wolffd@0
|
135 % cannot be sparse, and hence cannot be indexed by the actual number of the node.
|
wolffd@0
|
136 % Instead, we use the number of the "port" on which the message arrived.
|
wolffd@0
|
137
|
wolffd@0
|
138 N = length(dag);
|
wolffd@0
|
139 child_index = cell(1,N);
|
wolffd@0
|
140 parent_index = cell(1,N);
|
wolffd@0
|
141 for n=1:N
|
wolffd@0
|
142 cs = children(dag, n);
|
wolffd@0
|
143 child_index{n} = sparse(1,N);
|
wolffd@0
|
144 for i=1:length(cs)
|
wolffd@0
|
145 c = cs(i);
|
wolffd@0
|
146 child_index{n}(c) = i;
|
wolffd@0
|
147 end
|
wolffd@0
|
148 ps = parents(dag, n);
|
wolffd@0
|
149 parent_index{n} = sparse(1,N);
|
wolffd@0
|
150 for i=1:length(ps)
|
wolffd@0
|
151 p = ps(i);
|
wolffd@0
|
152 parent_index{n}(p) = i;
|
wolffd@0
|
153 end
|
wolffd@0
|
154 end
|
wolffd@0
|
155
|
wolffd@0
|
156
|
wolffd@0
|
157
|
wolffd@0
|
158
|