annotate toolboxes/FullBNT-1.0.7/bnt/inference/static/@pearl_inf_engine/private/parallel_protocol.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
rev   line source
function [msg, niter] = parallel_protocol(engine, evidence, msg)
% PARALLEL_PROTOCOL Iterate Pearl message passing until the beliefs stop changing.
% [msg, niter] = parallel_protocol(engine, evidence, msg)
%
% Each iteration:
%   1. every connected node recomputes msg{n}.lambda (from its children in
%      engine.msg_dag) and msg{n}.pi (from its CPD);
%   2. every hidden node recomputes its belief; the loop is converged when no
%      belief changed by more than engine.tol (never on the first iteration);
%   3. if not converged, every connected node sends a lambda message to each
%      parent and a pi message to each child in engine.msg_dag, each blended
%      with the previous message using engine.momentum.
% The loop runs for at most engine.max_iter iterations.
%
% Inputs:
%   engine   - the pearl inference engine (fields read here include filename,
%              disconnected_nodes_bitv, max_iter, msg_dag, msg_type, storebel,
%              tol, child_index, parent_index, momentum)
%   evidence - cell array; nodes whose entry is empty are treated as hidden
%   msg      - cell array of per-node message structs, updated and returned
%
% Outputs:
%   msg   - the updated messages
%   niter - the iteration on which convergence was detected; if the loop
%           stopped because engine.max_iter was reached, this is max_iter+1
%
% If engine.filename is non-empty, the beliefs of the hidden nodes are written
% to that file at every iteration.

bnet = bnet_from_engine(engine);
N = length(bnet.dag);

if ~isempty(engine.filename)
  fid = fopen(engine.filename, 'w');
  % fopen reports failure with -1 (0 is stdin and is never returned), so the
  % previous "fid == 0" test could never detect a failed open.
  if fid < 0
    error(['could not open ' engine.filename ' for writing'])
  end
else
  fid = [];
end

converged = 0;
iter = 1;
hidden = find(isemptycell(evidence));
bel = cell(1,N);
old_bel = cell(1,N);
%nodes = mysetdiff(1:N, engine.disconnected_nodes);
nodes = find(~engine.disconnected_nodes_bitv);
while ~converged & (iter <= engine.max_iter)
  % Step 1: every connected node updates its own lambda and pi.
  for n=nodes(:)'
    cs_msg = children(engine.msg_dag, n);
    %msg{n}.lambda = compute_lambda(n, cs, msg);
    msg{n}.lambda = prod_lambda_msgs(n, cs_msg, msg, engine.msg_type);
    ps_orig = parents(bnet.dag, n);
    msg{n}.pi = CPD_to_pi(bnet.CPD{bnet.equiv_class(n)}, engine.msg_type, n, ps_orig, msg, evidence);
  end

  % Step 2: recompute the beliefs of the hidden nodes and test for convergence.
  changed = 0;
  if ~isempty(fid)
    fprintf(fid, 'ITERATION %d\n', iter);
  end
  for n=hidden(:)' % per the original author, this will not contain any disconnected nodes
    old_bel{n} = bel{n};
    bel{n} = compute_bel(engine.msg_type, msg{n}.pi, msg{n}.lambda);
    if ~isempty(fid)
      fprintf(fid, 'node %d: %s\n', n, bel_to_str(bel{n}, engine.msg_type));
    end
    if engine.storebel
      % NOTE(review): engine is a by-value local copy and is not an output of
      % this function, so this store is not visible to the caller.
      engine.bel{n,iter} = bel{n};
    end
    % "|" inside an if-condition short-circuits, so approxeq_bel is never
    % called with the empty old_bel of the first iteration.
    if (iter == 1) | ~approxeq_bel(bel{n}, old_bel{n}, engine.tol, engine.msg_type)
      changed = 1;
    end
  end
  %converged = ~changed;
  converged = ~changed & (iter > 1); % Sonia Leach changed this

  if ~converged
    % Step 3: every node sends to all its neighbours.
    % NOTE(review): msg is updated in place, so nodes visited later in this
    % loop already see messages sent earlier in the same iteration (this is
    % not a strictly synchronous schedule).
    for n=nodes(:)'
      % lambda msgs to parents
      ps_msg = parents(engine.msg_dag, n);
      ps_orig = parents(bnet.dag, n);
      for p=ps_msg(:)'
        j = engine.child_index{p}(n); % n is p's j'th child
        old_msg = msg{p}.lambda_from_child{j}(:);
        new_msg = CPD_to_lambda_msg(bnet.CPD{bnet.equiv_class(n)}, engine.msg_type, n, ps_orig, ...
                                    msg, p, evidence);
        lam_msg = convex_combination_msg(old_msg, new_msg, engine.momentum, engine.msg_type);
        msg{p}.lambda_from_child{j} = lam_msg;
      end

      % pi msgs to children
      cs_msg = children(engine.msg_dag, n);
      for c=cs_msg(:)'
        j = engine.parent_index{c}(n); % n is c's j'th parent
        old_msg = msg{c}.pi_from_parent{j}(:);
        %new_msg = compute_pi_msg(n, cs, msg, c));
        % pi message to c: n's pi combined with the lambda messages from all
        % children except c.
        new_msg = compute_bel(engine.msg_type, msg{n}.pi, prod_lambda_msgs(n, cs_msg, msg, engine.msg_type, c));
        pi_msg = convex_combination_msg(old_msg, new_msg, engine.momentum, engine.msg_type);
        msg{c}.pi_from_parent{j} = pi_msg;
      end
    end
    % iter only advances when not converged, so on convergence it still holds
    % the iteration number on which convergence was detected.
    iter = iter + 1;
  end
end

if ~isempty(fid), fclose(fid); end
%niter = iter - 1;
niter = iter;
wolffd@0 86
wolffd@0 87 %%%%%%%%%%
wolffd@0 88
function str = bel_to_str(bel, type)
% BEL_TO_STR Render a belief as a row of '%9.4f ' fields for the trace file.
%   type 'd': the entries of BEL are printed.
%   type 'g': the entries of BEL.mu are printed.
%   For any other TYPE, STR is left unassigned.
fmt = '%9.4f ';
if strcmp(type, 'd')
  str = sprintf(fmt, bel(:)');
elseif strcmp(type, 'g')
  str = sprintf(fmt, bel.mu(:)');
end
wolffd@0 95
wolffd@0 96
wolffd@0 97 %%%%%%%
wolffd@0 98
function a = approxeq_bel(bel1, bel2, tol, type)
% APPROXEQ_BEL Compare two beliefs with approxeq at tolerance TOL.
%   type 'd': the beliefs themselves are compared.
%   type 'g': both the .mu and the .Sigma fields must compare equal.
%   For any other TYPE, A is left unassigned.
if strcmp(type, 'd')
  a = approxeq(bel1, bel2, tol);
elseif strcmp(type, 'g')
  same_mu = approxeq(bel1.mu, bel2.mu, tol);
  same_sigma = approxeq(bel1.Sigma, bel2.Sigma, tol);
  a = same_mu & same_sigma;
end
wolffd@0 105
wolffd@0 106
wolffd@0 107 %%%%%%%
wolffd@0 108
function msg = convex_combination_msg(old_msg, new_msg, old_weight, type)
% CONVEX_COMBINATION_MSG Damp a message update.
%   type 'd': returns old_weight*OLD_MSG + (1-old_weight)*NEW_MSG.
%   type 'g': returns NEW_MSG unchanged (no damping is applied).
%   For any other TYPE, MSG is left unassigned.
if strcmp(type, 'g')
  msg = new_msg;
elseif strcmp(type, 'd')
  new_weight = 1 - old_weight;
  msg = old_weight * old_msg + new_weight * new_msg;
end