Mercurial > hg > camir-aes2014
comparison toolboxes/FullBNT-1.0.7/bnt/inference/static/@pearl_inf_engine/private/parallel_protocol.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:e9a9cd732c1e |
---|---|
function [msg, niter] = parallel_protocol(engine, evidence, msg)
% PARALLEL_PROTOCOL Lambda/pi message passing using a parallel update schedule.
% [msg, niter] = parallel_protocol(engine, evidence, msg)
%
% Each iteration:
%   1. every connected node recomputes msg{n}.lambda (prod_lambda_msgs over
%      its children in engine.msg_dag) and msg{n}.pi (CPD_to_pi from its CPD
%      and the current messages);
%   2. the belief of every hidden node is recomputed (compute_bel) and compared
%      with the previous iteration's belief using tolerance engine.tol;
%   3. if not converged, every connected node sends a new lambda message to
%      each msg_dag parent and a new pi message to each msg_dag child. Each new
%      message is blended with the previous one by engine.momentum
%      (msg_type 'd' only; for 'g' the new message replaces the old one).
% The loop stops when no hidden belief changed (never on the first
% iteration) or once engine.max_iter iterations have run.
%
% Inputs:
%   engine   - inference engine; fields read here: filename,
%              disconnected_nodes_bitv, max_iter, msg_dag, msg_type, storebel,
%              tol, momentum, child_index, parent_index (the bnet is obtained
%              via bnet_from_engine).
%   evidence - cell array indexed by node; empty cells mark hidden nodes.
%   msg      - cell array of per-node message structs (fields lambda, pi,
%              lambda_from_child, pi_from_parent).
%
% Outputs:
%   msg   - the updated messages.
%   niter - value of the iteration counter on exit: the iteration at which
%           convergence was detected, or engine.max_iter+1 if the loop ran
%           out of iterations without converging.
%
% If engine.filename is non-empty, an "ITERATION k" header followed by the
% belief of every hidden node is written to that file on every iteration.

bnet = bnet_from_engine(engine);
N = length(bnet.dag);

if ~isempty(engine.filename)
  fid = fopen(engine.filename, 'w');
  % fopen reports failure by returning -1 (never 0).
  if fid < 0
    error(['could not open ' engine.filename ' for writing'])
  end
else
  fid = [];
end

converged = false;
iter = 1;
hidden = find(isemptycell(evidence));
bel = cell(1,N);
old_bel = cell(1,N);
nodes = find(~engine.disconnected_nodes_bitv);
while ~converged && (iter <= engine.max_iter)
  % Phase 1: every connected node updates its own lambda and pi.
  for n=nodes(:)'
    cs_msg = children(engine.msg_dag, n);
    msg{n}.lambda = prod_lambda_msgs(n, cs_msg, msg, engine.msg_type);
    ps_orig = parents(bnet.dag, n);
    msg{n}.pi = CPD_to_pi(bnet.CPD{bnet.equiv_class(n)}, engine.msg_type, n, ps_orig, msg, evidence);
  end

  % Phase 2: recompute hidden beliefs and test for convergence.
  changed = false;
  if ~isempty(fid)
    fprintf(fid, 'ITERATION %d\n', iter);
  end
  for n=hidden(:)' % this will not contain any disconnected nodes
    old_bel{n} = bel{n};
    bel{n} = compute_bel(engine.msg_type, msg{n}.pi, msg{n}.lambda);
    if ~isempty(fid)
      fprintf(fid, 'node %d: %s\n', n, bel_to_str(bel{n}, engine.msg_type));
    end
    if engine.storebel
      % NOTE(review): engine is not an output of this function, so beliefs
      % stored here are discarded when it returns -- confirm intent.
      engine.bel{n,iter} = bel{n};
    end
    % On the first iteration old_bel{n} is still empty; short-circuit so
    % the comparison is never attempted against it.
    if (iter == 1) || ~approxeq_bel(bel{n}, old_bel{n}, engine.tol, engine.msg_type)
      changed = true;
    end
  end
  % Never declare convergence on the first iteration, even when there are
  % no hidden nodes (change attributed to Sonia Leach in the original).
  converged = ~changed && (iter > 1);

  if ~converged
    % Phase 3: every connected node sends to all its msg_dag neighbours.
    % NOTE(review): msg is overwritten in place during this sweep, so nodes
    % visited later read messages already updated earlier in the same sweep;
    % the schedule is not strictly simultaneous -- confirm this is intended.
    for n=nodes(:)'
      % lambda msgs to parents
      ps_msg = parents(engine.msg_dag, n);
      ps_orig = parents(bnet.dag, n);
      for p=ps_msg(:)'
        j = engine.child_index{p}(n); % n is p's j'th child
        old_msg = msg{p}.lambda_from_child{j}(:);
        new_msg = CPD_to_lambda_msg(bnet.CPD{bnet.equiv_class(n)}, engine.msg_type, n, ps_orig, ...
                                    msg, p, evidence);
        lam_msg = convex_combination_msg(old_msg, new_msg, engine.momentum, engine.msg_type);
        msg{p}.lambda_from_child{j} = lam_msg;
      end

      % pi msgs to children
      cs_msg = children(engine.msg_dag, n);
      for c=cs_msg(:)'
        j = engine.parent_index{c}(n); % n is c's j'th parent
        old_msg = msg{c}.pi_from_parent{j}(:);
        % The extra argument c is passed to prod_lambda_msgs; presumably it
        % excludes c's own lambda message from the product.
        new_msg = compute_bel(engine.msg_type, msg{n}.pi, prod_lambda_msgs(n, cs_msg, msg, engine.msg_type, c));
        pi_msg = convex_combination_msg(old_msg, new_msg, engine.momentum, engine.msg_type);
        msg{c}.pi_from_parent{j} = pi_msg;
      end
    end
    % The counter only advances when another iteration is needed, so on
    % convergence niter equals the converging iteration.
    iter = iter + 1;
  end
end

if ~isempty(fid), fclose(fid); end
niter = iter;
86 | |
87 %%%%%%%%%% | |
88 | |
function str = bel_to_str(bel, type)
% BEL_TO_STR Format a belief as a row of '%9.4f ' numbers for the trace file.
% type 'd': BEL is a numeric array and all of its entries are printed.
% type 'g': BEL is a struct and only its mu field is printed.
% Any other type raises error 'parallel_protocol:unknownMsgType'.

switch type
  case 'd', str = sprintf('%9.4f ', bel(:)');
  case 'g', str = sprintf('%9.4f ', bel.mu(:)');
  otherwise
    % Without this branch str would be left unassigned, giving a cryptic error.
    error('parallel_protocol:unknownMsgType', 'unknown msg_type ''%s''', type);
end
95 | |
96 | |
97 %%%%%%% | |
98 | |
function a = approxeq_bel(bel1, bel2, tol, type)
% APPROXEQ_BEL Compare two beliefs with approxeq at tolerance TOL.
% type 'd': the two arrays are compared directly.
% type 'g': both the mu and the Sigma fields must agree.
% Any other type raises error 'parallel_protocol:unknownMsgType'.

switch type
  case 'd', a = approxeq(bel1, bel2, tol);
  case 'g', a = approxeq(bel1.mu, bel2.mu, tol) & approxeq(bel1.Sigma, bel2.Sigma, tol);
  otherwise
    % Without this branch a would be left unassigned, giving a cryptic error.
    error('parallel_protocol:unknownMsgType', 'unknown msg_type ''%s''', type);
end
105 | |
106 | |
107 %%%%%%% | |
108 | |
function msg = convex_combination_msg(old_msg, new_msg, old_weight, type)
% CONVEX_COMBINATION_MSG Damp a message update.
% type 'd': returns old_weight*old_msg + (1-old_weight)*new_msg, so
%           old_weight = 0 returns new_msg unchanged.
% type 'g': no damping is applied; new_msg is returned as-is and
%           old_msg / old_weight are ignored.
% Any other type raises error 'parallel_protocol:unknownMsgType'.

switch type
  case 'd', msg = old_weight * old_msg + (1-old_weight)*new_msg;
  case 'g', msg = new_msg;
  otherwise
    % Without this branch msg would be left unassigned, giving a cryptic error.
    error('parallel_protocol:unknownMsgType', 'unknown msg_type ''%s''', type);
end