Mercurial > hg > camir-aes2014
comparison toolboxes/FullBNT-1.0.7/bnt/inference/static/@pearl_inf_engine/pearl_inf_engine.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:e9a9cd732c1e |
---|---|
function engine = pearl_inf_engine(bnet, varargin)
% PEARL_INF_ENGINE Pearl's algorithm (belief propagation)
% engine = pearl_inf_engine(bnet, ...)
%
% If the graph has no loops (undirected cycles), you should use the tree protocol,
% and the results will be exact.
% Otherwise, you should use the parallel protocol, and the results may be approximate.
%
% Optional arguments [default in brackets]
% 'protocol' - 'tree' or 'parallel' ['parallel']; any other value is an error
%
% Optional arguments for the loopy case
% 'max_iter' - specifies the max num. iterations to perform [max(2*N, N+2), N = num. nodes]
% 'tol' - convergence criterion on messages [1e-3]
% 'momentum' - msg = (m*old + (1-m)*new). [m=0]
% 'filename' - msgs will be printed to this file, so you can assess convergence while it runs [[]]
% 'storebel' - 1 means save engine.bel{n,t} for every iteration t and hidden node n [0]
%
% Optional arguments must come in name/value pairs (an odd count is an error).
% Unrecognized names produce a warning and are otherwise ignored.
%
% If there are discrete and cts nodes, we assume all the discretes are observed. In this
% case, you must use the parallel protocol, and the evidence pattern must be fixed.

N = length(bnet.dag);

% ---- defaults ----
protocol = 'parallel';
% Iteration cap. In N iterations we get the exact answer for a tree. In the
% (N+1)st iteration we notice that the results are the same as before, and
% terminate. loopy_converged then needs N+1 < max_iter to declare convergence,
% so the cap must be at least N+2. 2*N satisfies this for N >= 2; the max()
% also covers a single-node network, where 2*N = 2 would be too small.
max_iter = max(2*N, N+2);
tol = 1e-3;
momentum = 0;
filename = [];
storebel = 0;

% ---- parse name/value options ----
args = varargin;
if mod(length(args), 2) ~= 0
  error('optional arguments must be given as name/value pairs');
end
for i=1:2:length(args)
  switch args{i}
    case 'protocol', protocol = args{i+1};
    case 'max_iter', max_iter = args{i+1};
    case 'tol',      tol = args{i+1};
    case 'momentum', momentum = args{i+1};
    case 'filename', filename = args{i+1};
    case 'storebel', storebel = args{i+1};
    otherwise
      % Do not drop unknown names silently: a typo would otherwise go unnoticed.
      warning('pearl_inf_engine:unknownArg', ...
              'ignoring unrecognized argument name %s', args{i});
  end
end

% Anything other than these two names would silently fall through to the
% parallel branch below, so reject it up front.
if ~any(strcmp(protocol, {'tree', 'parallel'}))
  error('protocol must be ''tree'' or ''parallel''');
end

engine.filename = filename;
engine.storebel = storebel;
engine.bel = [];

% Both branches set the same fields in the same order, so the struct handed to
% class() below has a consistent layout whichever protocol is chosen.
if strcmp(protocol, 'tree')
  % We first send messages up to the root (pivot node), and then back towards the leaves.
  % If the bnet is a singly connected graph (no loops), choosing a root induces a directed tree.
  % Peot and Shachter discuss ways to pick the root so as to minimize the work,
  % taking into account which nodes have changed.
  % For simplicity, we always pick the root to be the last node in the graph.
  % This means the first pass is equivalent to going forward in time in a DBN.
  engine.root = N;
  [engine.adj_mat, engine.preorder, engine.postorder, loopy] = ...
      mk_rooted_tree(bnet.dag, engine.root);
  % engine.adj_mat might have different edge orientations from bnet.dag
  if loopy
    error('can only apply tree protocol to loop-less graphs')
  end
else
  engine.root = [];
  engine.adj_mat = [];
  engine.preorder = [];
  engine.postorder = [];
end

engine.niter = [];
engine.protocol = protocol;
engine.max_iter = max_iter;
engine.tol = tol;
engine.momentum = momentum;
engine.maximize = [];

% The observed-node pattern is fixed at construction time: it is read from
% bnet.observed, not from the evidence that is entered later.
onodes = bnet.observed;
engine.msg_type = determine_pot_type(bnet, onodes, 1:N); % needed also by marginal_nodes
if strcmp(engine.msg_type, 'cg')
  error('messages must be discrete or Gaussian')
end
[engine.msg_dag, disconnected_nodes] = mk_msg_dag(bnet, engine.msg_type, onodes);
% 0/1 row marking which nodes were cut out of the message graph.
engine.disconnected_nodes_bitv = zeros(1,N);
engine.disconnected_nodes_bitv(disconnected_nodes) = 1;

% this is where we store stuff between enter_evidence and marginal_nodes
engine.marginal = cell(1,N);
engine.evidence = [];
engine.msg = [];

[engine.parent_index, engine.child_index] = mk_loopy_msg_indices(engine.msg_dag);

engine = class(engine, 'pearl_inf_engine', inf_engine(bnet));
99 | |
100 | |
101 %%%%%%%%% | |
102 | |
function [dag, disconnected_nodes] = mk_msg_dag(bnet, msg_type, onodes)
% MK_MSG_DAG Build the graph along which Pearl messages are passed.
% [dag, disconnected_nodes] = mk_msg_dag(bnet, msg_type, onodes)
%
% msg_type 'd' (discrete msgs): the message graph is bnet.dag itself and
%   disconnected_nodes is [].
% msg_type 'g' (Gaussian msgs): all discrete nodes must be observed; they are
%   disconnected from the graph (all their incoming and outgoing edges are
%   removed), so we don't try to send msgs to/from them. Their observed value
%   simply serves to index into the right set of parameters for the Gaussian
%   nodes (which use CPD.ps instead of parents(dag), and hence are unaffected
%   by this "surgery"). disconnected_nodes is bnet.dnodes.
% Any other msg_type is an error.
%
% onodes is not used here; it is kept so the signature stays unchanged.

disconnected_nodes = [];
dag = bnet.dag;
switch msg_type
  case 'd'
    % Discrete messages travel along the original graph unchanged.
  case 'g'
    disconnected_nodes = bnet.dnodes;
    % The only nonzeros in node i's row/column are its child/parent edges, so
    % zeroing the whole row and column removes exactly those edges.
    dag(disconnected_nodes, :) = 0;
    dag(:, disconnected_nodes) = 0;
  otherwise
    % Without this branch dag would be left unassigned, giving an obscure error.
    error(['unsupported message type ' msg_type]);
end
124 | |
125 | |
126 %%%%%%%%%% | |
function [parent_index, child_index] = mk_loopy_msg_indices(dag)
% MK_LOOPY_MSG_INDICES Compute "port numbers" for message passing
% [parent_index, child_index] = mk_loopy_msg_indices(dag)
%
% child_index{n}(c) = i means c is n's i'th child, i.e., i = find_equiv_posns(c, children(dag, n))
% child_index{n}(c) = 0 means c is not a child of n.
% parent_index{n}(p) = i means p is n's i'th parent in parents(dag, n), and 0 means
% p is not a parent of n.
% Each child_index{n} and parent_index{n} is a sparse 1xN row vector.
% We need to use these indices since the pi_from_parent/ lambda_from_child cell arrays
% cannot be sparse, and hence cannot be indexed by the actual number of the node.
% Instead, we use the number of the "port" on which the message arrived.

N = length(dag);
child_index = cell(1,N);
parent_index = cell(1,N);
for n=1:N
  child_index{n} = port_numbers(children(dag, n), N);
  parent_index{n} = port_numbers(parents(dag, n), N);
end


%%%%%%%%%%
function idx = port_numbers(nbrs, N)
% PORT_NUMBERS Sparse 1xN row with idx(nbrs(i)) = i and zeros elsewhere.
idx = sparse(1,N);
for i=1:length(nbrs)
  idx(nbrs(i)) = i;
end
155 | |
156 | |
157 | |
158 |