diff toolboxes/FullBNT-1.0.7/bnt/inference/static/@pearl_inf_engine/private/parallel_protocol.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/toolboxes/FullBNT-1.0.7/bnt/inference/static/@pearl_inf_engine/private/parallel_protocol.m	Tue Feb 10 15:05:51 2015 +0000
@@ -0,0 +1,114 @@
+function [msg, niter] = parallel_protocol(engine, evidence, msg)
+
+bnet = bnet_from_engine(engine);
+N = length(bnet.dag);
+ns = bnet.node_sizes(:);
+
+if ~isempty(engine.filename)
+  fid = fopen(engine.filename, 'w');
+  if fid == 0
+    error(['could not open ' engine.filename ' for writing'])
+  end
+else
+  fid = [];
+end
+
+converged = 0;
+iter = 1;
+hidden = find(isemptycell(evidence));
+bel = cell(1,N);
+old_bel = cell(1,N);
+%nodes = mysetdiff(1:N, engine.disconnected_nodes);
+nodes = find(~engine.disconnected_nodes_bitv);
+while ~converged & (iter <= engine.max_iter)
+  % Everybody updates their state in parallel
+  for n=nodes(:)'
+    cs_msg = children(engine.msg_dag, n);
+    %msg{n}.lambda = compute_lambda(n, cs, msg);
+    msg{n}.lambda = prod_lambda_msgs(n, cs_msg, msg, engine.msg_type);
+    ps_orig = parents(bnet.dag, n);
+    msg{n}.pi = CPD_to_pi(bnet.CPD{bnet.equiv_class(n)}, engine.msg_type, n, ps_orig, msg, evidence);
+  end
+  
+  changed = 0;
+  if ~isempty(fid)
+    fprintf(fid, 'ITERATION %d\n', iter);
+  end
+  for n=hidden(:)' % this will not contain any disconnected nodes
+    old_bel{n} = bel{n};
+    bel{n}  = compute_bel(engine.msg_type, msg{n}.pi, msg{n}.lambda);
+    if ~isempty(fid)
+      fprintf(fid, 'node %d: %s\n', n, bel_to_str(bel{n}, engine.msg_type));
+    end
+    if engine.storebel
+      engine.bel{n,iter} = bel{n};
+    end
+    if (iter == 1) | ~approxeq_bel(bel{n}, old_bel{n}, engine.tol, engine.msg_type)
+      changed = 1;
+    end
+  end
+  %converged = ~changed;
+  converged = ~changed & (iter > 1);  % Sonia Leach changed this
+
+  if ~converged
+    % Everybody sends to all their neighbors in parallel
+    for n=nodes(:)'
+      % lambda msgs to parents
+      ps_msg = parents(engine.msg_dag, n);
+      ps_orig = parents(bnet.dag, n);
+      for p=ps_msg(:)'
+	j = engine.child_index{p}(n); % n is p's j'th child
+	old_msg = msg{p}.lambda_from_child{j}(:);
+	new_msg = CPD_to_lambda_msg(bnet.CPD{bnet.equiv_class(n)}, engine.msg_type, n, ps_orig, ...
+				    msg, p, evidence);
+	lam_msg = convex_combination_msg(old_msg, new_msg, engine.momentum, engine.msg_type);
+	msg{p}.lambda_from_child{j} = lam_msg;
+      end 
+
+      % pi msgs to children
+      cs_msg = children(engine.msg_dag, n);
+      for c=cs_msg(:)'
+	j = engine.parent_index{c}(n); % n is c's j'th parent
+	old_msg = msg{c}.pi_from_parent{j}(:);
+	%new_msg = compute_pi_msg(n, cs, msg, c));
+	new_msg = compute_bel(engine.msg_type, msg{n}.pi, prod_lambda_msgs(n, cs_msg, msg, engine.msg_type, c));
+	pi_msg = convex_combination_msg(old_msg, new_msg, engine.momentum, engine.msg_type);
+	msg{c}.pi_from_parent{j} = pi_msg;
+      end
+    end
+    iter = iter + 1;
+  end
+end
+
+if fid > 0, fclose(fid); end
+%niter = iter - 1;
+niter = iter;
+
+%%%%%%%%%%
+
+function str = bel_to_str(bel, type)
+
+switch type
+ case 'd', str = sprintf('%9.4f ', bel(:)');
+ case 'g', str = sprintf('%9.4f ', bel.mu(:)');
+end
+
+
+%%%%%%%
+
+function a = approxeq_bel(bel1, bel2, tol, type)
+
+switch type
+ case 'd', a = approxeq(bel1, bel2, tol);
+ case 'g', a = approxeq(bel1.mu, bel2.mu, tol) & approxeq(bel1.Sigma, bel2.Sigma, tol);
+end
+
+
+%%%%%%%
+
+function msg = convex_combination_msg(old_msg, new_msg, old_weight, type)
+
+switch type
+ case 'd', msg = old_weight * old_msg + (1-old_weight)*new_msg;
+ case 'g', msg = new_msg;
+end