function [marginal, msg, loglik] = smooth_evidence(engine, evidence)
% [marginal, msg, loglik] = smooth_evidence(engine, evidence) (pearl_dbn)
%
% Run Pearl-style lambda/pi message passing on the DBN unrolled for T
% slices: one forward sweep (lambda from observed nodes, pi update, pi
% messages) followed by one backward sweep (lambda update, lambda messages)
% per iteration.
%
% Inputs:
%   engine   - pearl_dbn inference engine; engine.onodes lists the observed
%              nodes within a single slice (node i of slice t is unrolled
%              node i + (t-1)*ss).
%   evidence - ss x T cell array; evidence{i,t} is the observed value of
%              node i in slice t (used as an index into its states), or []
%              if the node is hidden.
%
% Outputs:
%   marginal - ss x T cell array; marginal{i,t} = normalise(pi .* lambda)
%              for node i in slice t.
%   msg      - cell array with one struct per unrolled node holding pi,
%              lambda and the per-edge pi/lambda messages.
%   loglik   - sum over all unrolled nodes of the log normalising constant
%              of pi .* lambda. NOTE(review): this is not the evidence
%              log-likelihood in general -- confirm before relying on it.
%
% This engine is flagged as broken (see the warning printed below).

disp('warning: broken');

[ss, T] = size(evidence);
bnet = bnet_from_engine(engine);
bnet2 = dbn_to_bnet(bnet, T);   % network unrolled for T slices
ns = bnet2.node_sizes;
onodes = engine.onodes(:)';
hnodes = mysetdiff(1:ss, engine.onodes);
hnodes = hnodes(:)';

% parent_index{c}(p): position of p among c's parents;
% child_index{p}(c): position of c among p's children.
[engine.parent_index, engine.child_index] = mk_pearl_msg_indices(bnet2);

msg = init_msgs(bnet2.dag, ns, evidence, bnet2.equiv_class, bnet2.CPD);

verbose = 0;

niter = 1;
for iter = 1:niter
  % FORWARD sweep
  for t = 1:T
    if verbose, fprintf('t=%d\n', t); end

    % Observed nodes send lambda messages to their parents.
    msg = send_lambda_msgs(bnet, bnet2.dag, engine.child_index, msg, onodes, t, ss, verbose);

    % Hidden nodes update pi from the pi messages received so far.
    for i = hnodes
      n = i + (t-1)*ss;
      ps = parents(bnet2.dag, n);
      e = slice_eclass(bnet, i, t);
      msg{n}.pi = compute_pi(bnet.CPD{e}, n, ps, msg);
      if verbose, fprintf('%d computes pi\n', n); disp(msg{n}.pi); end
    end

    % Hidden nodes send pi messages to all of their children.
    % NOTE(review): observed nodes never send pi messages, so this appears
    % to assume observed nodes are leaves -- confirm.
    for i = hnodes
      n = i + (t-1)*ss;
      cs = children(bnet2.dag, n);
      for c = cs(:)'
        j = engine.parent_index{c}(n); % n is c's j'th parent
        pi_msg = normalise(compute_pi_msg(n, cs, msg, c, ns));
        msg{c}.pi_from_parent{j} = pi_msg;
        if verbose, fprintf('%d sends pi to %d\n', n, c); disp(pi_msg); end
      end
    end
  end

  % BACKWARD sweep
  for t = T:-1:1
    if verbose, fprintf('t = %d\n', t); end

    % Hidden nodes update lambda from their children's lambda messages.
    for i = hnodes
      n = i + (t-1)*ss;
      cs = children(bnet2.dag, n);
      msg{n}.lambda = compute_lambda(n, cs, msg, ns);
      if verbose, fprintf('%d computes lambda\n', n); disp(msg{n}.lambda); end
    end

    % Hidden nodes send lambda messages to their parents.
    msg = send_lambda_msgs(bnet, bnet2.dag, engine.child_index, msg, hnodes, t, ss, verbose);
  end
end

% Beliefs are the normalised product pi .* lambda; lik(n) keeps the
% normalising constant of unrolled node n.
marginal = cell(ss, T);
lik = zeros(1, ss*T);
for t = 1:T
  for i = 1:ss
    n = i + (t-1)*ss;
    [bel, lik(n)] = normalise(msg{n}.pi .* msg{n}.lambda);
    marginal{i,t} = bel;
  end
end

loglik = sum(log(lik));


%%%%%%%

function e = slice_eclass(bnet, i, t)
% Equivalence class (index into bnet.CPD) of slice-local node i in slice t:
% column 1 of bnet.equiv_class is used for slice 1, column 2 for later slices.
if t == 1
  e = bnet.equiv_class(i, 1);
else
  e = bnet.equiv_class(i, 2);
end

%%%%%%%

function msg = send_lambda_msgs(bnet, dag, child_index, msg, nodes, t, ss, verbose)
% Every slice-local node i in 'nodes', taken in slice t, sends a normalised
% lambda message to each of its parents in the unrolled DAG. The message is
% stored in msg{p}.lambda_from_child at the position child_index{p}(n).
for i = nodes(:)'
  n = i + (t-1)*ss;
  ps = parents(dag, n);
  e = slice_eclass(bnet, i, t);
  for p = ps(:)'
    j = child_index{p}(n); % n is p's j'th child
    lam_msg = normalise(compute_lambda_msg(bnet.CPD{e}, n, ps, msg, p));
    msg{p}.lambda_from_child{j} = lam_msg;
    if verbose, fprintf('%d sends lambda to %d\n', n, p); disp(lam_msg); end
  end
end

%%%%%%%

function lambda = compute_lambda(n, cs, msg, ns)
% Pearl p183 eq 4.50: lambda of node n is the product of the lambda
% messages from all of its children cs.
lambda = prod_lambda_msgs(n, cs, msg, ns);

%%%%%%%

function pi_msg = compute_pi_msg(n, cs, msg, c, ns)
% Pearl p183 eq 4.53 and 4.51: the (unnormalised) pi message from n to
% child c is pi(n) times the lambda messages from all children except c.
pi_msg = msg{n}.pi .* prod_lambda_msgs(n, cs, msg, ns, c);

%%%%%%%%%

function lam = prod_lambda_msgs(n, cs, msg, ns, except)
% Product of the lambda messages node n received from its children cs,
% skipping child 'except' when given. Messages are looked up by position,
% so cs must be in the same order as children(dag, n) used by init_msgs.
if nargin < 5, except = -1; end

% msg{n}.lambda_from_self is not folded in: init_msgs initialises it to
% all-ones and records evidence directly in msg{n}.lambda instead.
lam = ones(ns(n), 1);
for k = 1:length(cs)
  if cs(k) ~= except
    lam = lam .* msg{n}.lambda_from_child{k};
  end
end


%%%%%%%%%

function msg = init_msgs(dag, ns, evidence, eclass, CPD)
% INIT_MSGS Initialize the lambda/pi message and state vectors (pearl_dbn)
% msg = init_msgs(dag, ns, evidence, eclass, CPD)
%
% Every per-edge message and every pi/lambda vector starts as all-ones.
% evidence is read in column-major order (evidence(:)), so entry n lines up
% with unrolled node n; for each observed node, lambda becomes a delta
% function on the observed value. eclass and CPD are currently unused.

N = length(dag);
msg = cell(1, N);
observed = ~isemptycell(evidence(:));

for n = 1:N
  ps = parents(dag, n);
  msg{n}.pi_from_parent = cell(1, length(ps));
  for k = 1:length(ps)
    msg{n}.pi_from_parent{k} = ones(ns(ps(k)), 1);
  end

  cs = children(dag, n);
  msg{n}.lambda_from_child = cell(1, length(cs));
  for k = 1:length(cs)
    msg{n}.lambda_from_child{k} = ones(ns(n), 1);
  end

  msg{n}.lambda = ones(ns(n), 1);
  msg{n}.lambda_from_self = ones(ns(n), 1);
  msg{n}.pi = ones(ns(n), 1);

  % Initialize the lambdas with any evidence
  if observed(n)
    v = evidence{n};
    msg{n}.lambda = zeros(ns(n), 1);
    msg{n}.lambda(v) = 1; % delta function
  end
end


%%%%%%%%

function msg = init_ev_msgs(engine, evidence, msg)
% For every observed node in every slice, evaluate its CPD on the evidence
% as a discrete potential and store the resulting marginal table in
% msg{n}.lambda_from_self. Not called from smooth_evidence in this file.

[ss, T] = size(evidence);
bnet = bnet_from_engine(engine);
pot_type = 'd';
for t = 1:T
  for i = engine.onodes(:)'
    if t == 1
      fam = family(bnet.dag, i);
      ev = evidence(:,1);
    else
      fam = family(bnet.dag, i, 2); % extract from slice t
      ev = evidence(:,t-1:t);
    end
    e = slice_eclass(bnet, i, t);
    CPDpot = CPD_to_pot(pot_type, bnet.CPD{e}, fam, bnet.node_sizes(:), bnet.cnodes(:), ev);
    temp = pot_to_marginal(CPDpot);
    n = i + (t-1)*ss;
    msg{n}.lambda_from_self = temp.T;
  end
end


%%%%%%%%%%%

function msg = init_ev_msgs2(engine, evidence, msg)
% Identical to init_ev_msgs; kept as an alias so existing references work.
msg = init_ev_msgs(engine, evidence, msg);