function [engine, loglik, niter] = enter_evidence(engine, evidence, varargin)
% ENTER_EVIDENCE Add the specified evidence to the network (pearl)
% [engine, loglik, num_iter] = enter_evidence(engine, evidence, ...)
% evidence{i} = [] if X(i) is hidden, and otherwise contains its observed value (scalar or column vector)
%
% The following optional arguments can be specified in the form of name/value pairs:
% [default value in brackets]
%
% maximize - if 1, does max-product instead of sum-product [0]
%            NOTE: max-product is not implemented yet; any nonzero value raises an error.
% 'filename' - msgs will be printed to this file, so you can assess convergence while it runs [engine.filename]
%
% e.g., [engine, loglik] = enter_evidence(engine, ev)
%
% For discrete nodes (engine.msg_type 'd'), loglik is the value returned by
% bethe_free_energy(engine, evidence) at the final beliefs (the negative Bethe free energy).
% For Gaussian nodes, loglik is currently always 0.
%
% 'num_iter' returns the number of iterations used (always 1 for the 'tree' protocol).
%
% Side effects on the returned engine: maximize, filename, bel (reset to []),
% niter, marginal, evidence and msg are (re)assigned.

maximize = 0;
filename = engine.filename;

% parse optional name/value params
args = varargin;
nargs = length(args);
if mod(nargs, 2) ~= 0
  % without this check an odd-length list fails with an opaque index error on args{i+1}
  error('optional arguments must be given as name/value pairs');
end
for i=1:2:nargs
  switch args{i},
   case 'maximize', maximize = args{i+1};
   case 'filename', filename = args{i+1};
   otherwise,
    error(['invalid argument name ' args{i}]);
  end
end

if maximize
  error('can''t handle max-prop yet')
end

engine.maximize = maximize;
engine.filename = filename;
engine.bel = []; % reset if necessary

bnet = bnet_from_engine(engine);
N = length(bnet.dag);
ns = bnet.node_sizes(:);

% nodes flagged in engine.disconnected_nodes_bitv are excluded from message passing,
% so they are required to be observed
observed_bitv = ~isemptycell(evidence);
disconnected = find(engine.disconnected_nodes_bitv);
if ~all(observed_bitv(disconnected))
  error(['The following discrete nodes must be observed: ' num2str(disconnected)])
end
msg = init_pearl_msgs(engine.msg_type, engine.msg_dag, ns, evidence);

niter = 1;
switch engine.protocol
 case 'parallel', [msg, niter] = parallel_protocol(engine, evidence, msg);
 case 'tree', msg = tree_protocol(engine, evidence, msg);
 otherwise,
  error(['unrecognized protocol ' engine.protocol])
end
engine.niter = niter;

% beliefs are only computed for connected nodes; disconnected entries stay []
engine.marginal = cell(1,N);
nodes = find(~engine.disconnected_nodes_bitv);
for n=nodes(:)'
  engine.marginal{n} = compute_bel(engine.msg_type, msg{n}.pi, msg{n}.lambda);
end

engine.evidence = evidence; % needed by marginal_nodes and marginal_family
engine.msg = msg;  % needed by marginal_family

if (nargout >= 2)
  % strcmp instead of == so a non-scalar msg_type cannot produce an array condition
  if strcmp(engine.msg_type, 'd')
    loglik = bethe_free_energy(engine, evidence);
  else
    loglik = 0;
  end
end



%%%%%%%%%%%

function msg = init_pearl_msgs(msg_type, dag, ns, evidence)
% INIT_PEARL_MSGS Initialize the lambda/pi message and state vectors
% msg = init_pearl_msgs(msg_type, dag, ns, evidence)
%
% msg{n} gets the fields pi_from_parent (one msg per parent), lambda_from_child
% (one msg per child, sized ns(n)), lambda, pi, and lambda_from_self. For an
% observed node, lambda_from_self encodes the evidence; otherwise it is a
% default lambda message.

N = length(dag);
msg = cell(1,N);
observed = ~isemptycell(evidence);
is_lambda = 1; % passed to mk_msg to request the lambda-message representation

for n=1:N
  ps = parents(dag, n);
  msg{n}.pi_from_parent = cell(1, length(ps));
  for i=1:length(ps)
    p = ps(i);
    msg{n}.pi_from_parent{i} = mk_msg(msg_type, ns(p));
  end

  cs = children(dag, n);
  msg{n}.lambda_from_child = cell(1, length(cs));
  for i=1:length(cs)
    % lambda messages from children live in the state space of n itself
    msg{n}.lambda_from_child{i} = mk_msg(msg_type, ns(n), is_lambda);
  end

  msg{n}.lambda = mk_msg(msg_type, ns(n), is_lambda);
  msg{n}.pi = mk_msg(msg_type, ns(n));

  if observed(n)
    msg{n}.lambda_from_self = mk_msg_with_evidence(msg_type, ns(n), evidence{n});
  else
    msg{n}.lambda_from_self = mk_msg(msg_type, ns(n), is_lambda);
  end
end



%%%%%%%%%

function msg = mk_msg(msg_type, sz, is_lambda_msg)
% MK_MSG Create an initial message of dimension sz
% msg_type 'd': an all-ones column vector of length sz (same for pi and lambda).
% msg_type 'g': lambda msgs get zero precision/info_state (information form);
%               pi msgs get zero Sigma/mu (moment form).
% is_lambda_msg defaults to 0.

if nargin < 3, is_lambda_msg = 0; end

switch msg_type
 case 'd', msg = ones(sz, 1);
 case 'g',
  if is_lambda_msg
    msg.precision = zeros(sz, sz);
    msg.info_state = zeros(sz, 1);
  else
    msg.Sigma = zeros(sz, sz);
    msg.mu = zeros(sz,1);
  end
 otherwise,
  % previously fell through leaving msg unassigned
  error(['unknown message type ' msg_type]);
end

%%%%%%%%%%%%

function msg = mk_msg_with_evidence(msg_type, sz, val)
% MK_MSG_WITH_EVIDENCE Create the self lambda message for an observed node
% msg_type 'd': indicator column vector of length sz with a 1 at index val.
% msg_type 'g': precision = inf with mu = val(:), i.e. a point mass at the
%               observed value (presumably special-cased by the belief
%               computation elsewhere -- confirm in compute_bel).

switch msg_type
 case 'd',
  msg = zeros(sz, 1);
  msg(val) = 1;
 case 'g',
  msg.precision = inf;
  msg.mu = val(:);
 otherwise,
  error(['unknown message type ' msg_type]);
end