function [msg, niter, engine] = parallel_protocol(engine, evidence, msg)
% PARALLEL_PROTOCOL Iterate pi/lambda message passing with a parallel (synchronous) schedule.
% [msg, niter] = parallel_protocol(engine, evidence, msg)
% [msg, niter, engine] = parallel_protocol(engine, evidence, msg)
%
% Each iteration:
%   1. every connected node recomputes msg{n}.lambda (product of the lambda
%      messages from its children in engine.msg_dag) and msg{n}.pi (via
%      CPD_to_pi on its CPD and its parents in bnet.dag);
%   2. the belief of every hidden node is recomputed and compared with the
%      previous iteration's belief (tolerance engine.tol);
%   3. if anything changed, every node sends a new lambda message to each of its
%      msg_dag parents and a new pi message to each of its msg_dag children.
%      New messages are damped by engine.momentum via convex_combination_msg.
% Iteration stops when no hidden belief changed (checked from the 2nd
% iteration on) or after engine.max_iter iterations.
%
% Inputs:
%   engine   - inference engine. Fields read here: filename, max_iter, tol,
%              msg_type ('d' or 'g'), msg_dag, child_index, parent_index,
%              momentum, storebel, disconnected_nodes_bitv.
%   evidence - cell array over nodes; entries for which isemptycell is true
%              are treated as hidden.
%   msg      - cell array of per-node message structs with fields pi, lambda,
%              lambda_from_child{j}, pi_from_parent{j}.
%
% Outputs:
%   msg    - the updated messages.
%   niter  - number of iterations actually performed (never more than
%            engine.max_iter).
%   engine - (optional) the engine; when engine.storebel is set it carries
%            engine.bel{n,iter} for every hidden node n and iteration iter.
%
% If engine.filename is non-empty, the beliefs of the hidden nodes at every
% iteration are written to that file (it is overwritten).

bnet = bnet_from_engine(engine);
N = length(bnet.dag);

if ~isempty(engine.filename)
  fid = fopen(engine.filename, 'w');
  if fid < 0  % fopen signals failure with -1, not 0
    error(['could not open ' engine.filename ' for writing'])
  end
else
  fid = [];
end

converged = 0;
iter = 1;
hidden = find(isemptycell(evidence));
bel = cell(1,N);
old_bel = cell(1,N);
nodes = find(~engine.disconnected_nodes_bitv);
while ~converged && (iter <= engine.max_iter)
  % Everybody updates their state in parallel
  for n=nodes(:)'
    cs_msg = children(engine.msg_dag, n);
    msg{n}.lambda = prod_lambda_msgs(n, cs_msg, msg, engine.msg_type);
    ps_orig = parents(bnet.dag, n);
    msg{n}.pi = CPD_to_pi(bnet.CPD{bnet.equiv_class(n)}, engine.msg_type, n, ps_orig, msg, evidence);
  end

  changed = 0;
  if ~isempty(fid)
    fprintf(fid, 'ITERATION %d\n', iter);
  end
  for n=hidden(:)' % this will not contain any disconnected nodes
    old_bel{n} = bel{n};
    bel{n} = compute_bel(engine.msg_type, msg{n}.pi, msg{n}.lambda);
    if ~isempty(fid)
      fprintf(fid, 'node %d: %s\n', n, bel_to_str(bel{n}, engine.msg_type));
    end
    if engine.storebel
      engine.bel{n,iter} = bel{n};
    end
    % On the first iteration old_bel{n} is still empty; the short-circuit ||
    % ensures approxeq_bel is never evaluated against it.
    if (iter == 1) || ~approxeq_bel(bel{n}, old_bel{n}, engine.tol, engine.msg_type)
      changed = 1;
    end
  end
  % Convergence is only declared once there is a previous belief to compare with.
  converged = ~changed && (iter > 1); % Sonia Leach changed this

  if ~converged
    % Everybody sends to all their neighbors in parallel
    for n=nodes(:)'
      % lambda msgs to parents
      ps_msg = parents(engine.msg_dag, n);
      ps_orig = parents(bnet.dag, n);
      for p=ps_msg(:)'
        j = engine.child_index{p}(n); % n is p's j'th child
        old_msg = msg{p}.lambda_from_child{j}(:);
        new_msg = CPD_to_lambda_msg(bnet.CPD{bnet.equiv_class(n)}, engine.msg_type, n, ps_orig, ...
                                    msg, p, evidence);
        lam_msg = convex_combination_msg(old_msg, new_msg, engine.momentum, engine.msg_type);
        msg{p}.lambda_from_child{j} = lam_msg;
      end

      % pi msgs to children
      cs_msg = children(engine.msg_dag, n);
      for c=cs_msg(:)'
        j = engine.parent_index{c}(n); % n is c's j'th parent
        old_msg = msg{c}.pi_from_parent{j}(:);
        % pi message = n's pi combined with the lambda messages from all children except c
        new_msg = compute_bel(engine.msg_type, msg{n}.pi, prod_lambda_msgs(n, cs_msg, msg, engine.msg_type, c));
        pi_msg = convex_combination_msg(old_msg, new_msg, engine.momentum, engine.msg_type);
        msg{c}.pi_from_parent{j} = pi_msg;
      end
    end
    iter = iter + 1;
  end
end

if ~isempty(fid), fclose(fid); end
% When the loop exits without converging, iter has been incremented one past
% the last iteration run; clamp so niter counts the iterations actually done.
niter = min(iter, engine.max_iter);

%%%%%%%%%%

function str = bel_to_str(bel, type)
% BEL_TO_STR Format a belief as a row of numbers for the log file.
% For discrete ('d') beliefs all entries are printed; for Gaussian ('g')
% beliefs only the mean (bel.mu) is printed.

switch type
  case 'd', str = sprintf('%9.4f ', bel(:)');
  case 'g', str = sprintf('%9.4f ', bel.mu(:)');
  otherwise, error(['unknown message type ' type])
end


%%%%%%%

function a = approxeq_bel(bel1, bel2, tol, type)
% APPROXEQ_BEL True if two beliefs agree to within tol (as judged by approxeq).
% Discrete ('d') beliefs are compared directly; Gaussian ('g') beliefs are
% compared on both their mean (mu) and covariance (Sigma).

switch type
  case 'd', a = approxeq(bel1, bel2, tol);
  case 'g', a = approxeq(bel1.mu, bel2.mu, tol) && approxeq(bel1.Sigma, bel2.Sigma, tol);
  otherwise, error(['unknown message type ' type])
end


%%%%%%%

function msg = convex_combination_msg(old_msg, new_msg, old_weight, type)
% CONVEX_COMBINATION_MSG Damp a message update.
% For discrete ('d') messages returns old_weight*old_msg + (1-old_weight)*new_msg.
% For Gaussian ('g') messages no damping is applied: new_msg is returned as is
% and old_weight is ignored.

switch type
  case 'd', msg = old_weight * old_msg + (1-old_weight)*new_msg;
  case 'g', msg = new_msg;
  otherwise, error(['unknown message type ' type])
end