wolffd@0
|
function [msg, niter] = parallel_protocol(engine, evidence, msg)
% PARALLEL_PROTOCOL Pi/lambda message passing in which all nodes update and send in lock-step.
% [msg, niter] = parallel_protocol(engine, evidence, msg)
%
% Each iteration has two phases:
%   1. Every connected node recomputes its lambda and pi from the messages it currently holds,
%      and the belief of every hidden node is recomputed from them.
%   2. Unless convergence was detected, every connected node sends a lambda message to each
%      of its parents and a pi message to each of its children in engine.msg_dag.
%
% Convergence is declared when, on some iteration after the first, no hidden node's belief
% differs from its previous value according to approxeq_bel with tolerance engine.tol.
%
% Inputs:
%   engine   - inference engine. Fields read here: filename, max_iter, tol, momentum,
%              msg_type, msg_dag, child_index, parent_index, disconnected_nodes_bitv, storebel.
%   evidence - cell array; nodes whose cell is empty are treated as hidden.
%   msg      - cell array of per-node message structs (fields lambda, pi,
%              lambda_from_child, pi_from_parent).
%
% Outputs:
%   msg   - the messages after the last iteration.
%   niter - the iteration at which convergence was detected, or engine.max_iter+1 if the
%           loop ran out of iterations first.
%
% If engine.filename is non-empty, the belief of every hidden node is also written to that
% file on each iteration. An error is raised if the file cannot be opened.

bnet = bnet_from_engine(engine);
N = length(bnet.dag);

% Optional trace file. fid stays [] when no tracing was requested.
if ~isempty(engine.filename)
  fid = fopen(engine.filename, 'w');
  % fopen signals failure with -1 (it never returns 0 for a failed open).
  if fid == -1
    error(['could not open ' engine.filename ' for writing'])
  end
else
  fid = [];
end

converged = 0;
iter = 1;
hidden = find(isemptycell(evidence));
bel = cell(1,N);
old_bel = cell(1,N);
nodes = find(~engine.disconnected_nodes_bitv);

while ~converged & (iter <= engine.max_iter)
  % Everybody updates their state in parallel
  for n=nodes(:)'
    cs_msg = children(engine.msg_dag, n);
    msg{n}.lambda = prod_lambda_msgs(n, cs_msg, msg, engine.msg_type);
    ps_orig = parents(bnet.dag, n);
    msg{n}.pi = CPD_to_pi(bnet.CPD{bnet.equiv_class(n)}, engine.msg_type, n, ps_orig, msg, evidence);
  end

  changed = 0;
  if ~isempty(fid)
    fprintf(fid, 'ITERATION %d\n', iter);
  end
  for n=hidden(:)' % this will not contain any disconnected nodes
    old_bel{n} = bel{n};
    bel{n} = compute_bel(engine.msg_type, msg{n}.pi, msg{n}.lambda);
    if ~isempty(fid)
      fprintf(fid, 'node %d: %s\n', n, bel_to_str(bel{n}, engine.msg_type));
    end
    if engine.storebel
      % NOTE(review): engine is not an output of this function, so this history is only
      % visible to the caller if engine has reference semantics - confirm.
      engine.bel{n,iter} = bel{n};
    end
    if iter == 1
      % There is no previous belief to compare against on the first pass.
      changed = 1;
    elseif ~approxeq_bel(bel{n}, old_bel{n}, engine.tol, engine.msg_type)
      changed = 1;
    end
  end
  converged = ~changed & (iter > 1); % Sonia Leach changed this

  if ~converged
    % Everybody sends to all their neighbors in parallel
    for n=nodes(:)'
      % lambda msgs to parents
      ps_msg = parents(engine.msg_dag, n);
      ps_orig = parents(bnet.dag, n);
      for p=ps_msg(:)'
        j = engine.child_index{p}(n); % n is p's j'th child
        old_msg = msg{p}.lambda_from_child{j}(:);
        new_msg = CPD_to_lambda_msg(bnet.CPD{bnet.equiv_class(n)}, engine.msg_type, n, ps_orig, ...
                                    msg, p, evidence);
        lam_msg = convex_combination_msg(old_msg, new_msg, engine.momentum, engine.msg_type);
        msg{p}.lambda_from_child{j} = lam_msg;
      end

      % pi msgs to children
      cs_msg = children(engine.msg_dag, n);
      for c=cs_msg(:)'
        j = engine.parent_index{c}(n); % n is c's j'th parent
        old_msg = msg{c}.pi_from_parent{j}(:);
        new_msg = compute_bel(engine.msg_type, msg{n}.pi, prod_lambda_msgs(n, cs_msg, msg, engine.msg_type, c));
        pi_msg = convex_combination_msg(old_msg, new_msg, engine.momentum, engine.msg_type);
        msg{c}.pi_from_parent{j} = pi_msg;
      end
    end
    % Only advance the counter when another pass is needed, so that on convergence
    % iter still names the iteration at which it was detected.
    iter = iter + 1;
  end
end

% Close the trace file if one was opened.
if ~isempty(fid), fclose(fid); end
niter = iter;
|
wolffd@0
|
86
|
wolffd@0
|
87 %%%%%%%%%%
|
wolffd@0
|
88
|
wolffd@0
|
function str = bel_to_str(bel, type)
% BEL_TO_STR Format a belief as a string of fixed-width numbers for the trace file.
% str = bel_to_str(bel, type)
%
% type 'd': bel is a numeric array; all of its elements are printed.
% type 'g': bel is a struct; the elements of its mu field are printed.
% Each number is printed with the format '%9.4f ' (note the trailing space).
% For any other type, str is left unassigned.

switch type
  case 'd'
    vals = bel(:)';
    str = sprintf('%9.4f ', vals);
  case 'g'
    vals = bel.mu(:)';
    str = sprintf('%9.4f ', vals);
end
|
wolffd@0
|
95
|
wolffd@0
|
96
|
wolffd@0
|
97 %%%%%%%
|
wolffd@0
|
98
|
wolffd@0
|
function a = approxeq_bel(bel1, bel2, tol, type)
% APPROXEQ_BEL Test whether two beliefs agree to within a tolerance.
% a = approxeq_bel(bel1, bel2, tol, type)
%
% type 'd': the beliefs are compared directly with approxeq.
% type 'g': the beliefs are structs; both their mu fields and their Sigma fields
%           must pass approxeq.
% For any other type, a is left unassigned.

switch type
  case 'd'
    a = approxeq(bel1, bel2, tol);
  case 'g'
    % Both comparisons are always evaluated, then combined element-wise.
    mu_close = approxeq(bel1.mu, bel2.mu, tol);
    sigma_close = approxeq(bel1.Sigma, bel2.Sigma, tol);
    a = mu_close & sigma_close;
end
|
wolffd@0
|
105
|
wolffd@0
|
106
|
wolffd@0
|
107 %%%%%%%
|
wolffd@0
|
108
|
wolffd@0
|
function msg = convex_combination_msg(old_msg, new_msg, old_weight, type)
% CONVEX_COMBINATION_MSG Blend the previous message with the newly computed one.
% msg = convex_combination_msg(old_msg, new_msg, old_weight, type)
%
% type 'd': msg = old_weight*old_msg + (1-old_weight)*new_msg.
% type 'g': the old message is ignored and new_msg is returned unchanged.
% For any other type, msg is left unassigned.

switch type
  case 'd'
    new_weight = 1-old_weight;
    msg = old_weight * old_msg + new_weight * new_msg;
  case 'g'
    msg = new_msg;
end
|