Mercurial > hg > camir-aes2014
diff toolboxes/FullBNT-1.0.7/bnt/inference/dynamic/@pearl_dbn_inf_engine/Old/correct_smooth.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/toolboxes/FullBNT-1.0.7/bnt/inference/dynamic/@pearl_dbn_inf_engine/Old/correct_smooth.m Tue Feb 10 15:05:51 2015 +0000 @@ -0,0 +1,244 @@ +function [marginal, msg, loglik] = smooth_evidence(engine, evidence) +% [marginal, msg, loglik] = smooth_evidence(engine, evidence) (pearl_dbn) + +disp('warning: broken'); + +[ss T] = size(evidence); +bnet = bnet_from_engine(engine); +bnet2 = dbn_to_bnet(bnet, T); +ns = bnet2.node_sizes; +hnodes = mysetdiff(1:ss, engine.onodes); +hnodes = hnodes(:)'; + +onodes2 = unroll_set(engine.onodes(:), ss, T); +onodes2 = onodes2(:)'; + +hnodes2 = unroll_set(hnodes(:), ss, T); +hnodes2 = hnodes2(:)'; + +[engine.parent_index, engine.child_index] = mk_pearl_msg_indices(bnet2); + +msg = init_msgs(bnet2.dag, ns, evidence, bnet2.equiv_class, bnet2.CPD); + +verbose = 0; + +niter = 1; +for iter=1:niter + % FORWARD + for t=1:T + if verbose, fprintf('t=%d\n', t); end + % observed leaves send lambda to parents + for i=engine.onodes(:)' + n = i + (t-1)*ss; + ps = parents(bnet2.dag, n); + for p=ps(:)' + j = engine.child_index{p}(n); % n is p's j'th child + if t > 1 + e = bnet.equiv_class(i, 2); + else + e = bnet.equiv_class(i, 1); + end + lam_msg = normalise(compute_lambda_msg(bnet.CPD{e}, n, ps, msg, p)); + msg{p}.lambda_from_child{j} = lam_msg; + if verbose, fprintf('%d sends lambda to %d\n', n, p); disp(lam_msg); end + end + end + + % update pi + for i=hnodes + n = i + (t-1)*ss; + ps = parents(bnet2.dag, n); + if t==1 + e = bnet.equiv_class(i,1); + else + e = bnet.equiv_class(i,2); + end + msg{n}.pi = compute_pi(bnet.CPD{e}, n, ps, msg); + if verbose, fprintf('%d computes pi\n', n); disp(msg{n}.pi); end + end + + % send pi msg to children + for i=hnodes + n = i + (t-1)*ss; + %cs = myintersect(children(bnet2.dag, n), hnodes2); + cs = children(bnet2.dag, n); + for c=cs(:)' + j = engine.parent_index{c}(n); % n is c's j'th parent + pi_msg = normalise(compute_pi_msg(n, cs, msg, c, ns)); + msg{c}.pi_from_parent{j} = pi_msg; + if verbose, fprintf('%d sends pi to %d\n', n, c); disp(pi_msg); end + end + end + end + + % BACKWARD + for t=T:-1:1 + if verbose, fprintf('t = %d\n', t); end + % update lambda + for i=hnodes + n = i + (t-1)*ss; + cs = children(bnet2.dag, n); + msg{n}.lambda = compute_lambda(n, cs, msg, ns); + if verbose, fprintf('%d computes lambda\n', n); disp(msg{n}.lambda); end + end + % send lambda msgs to parents + for i=hnodes + n = i + (t-1)*ss; + %ps = myintersect(parents(bnet2.dag, n), hnodes2); + ps = parents(bnet2.dag, n); + for p=ps(:)' + j = engine.child_index{p}(n); % n is p's j'th child + if t > 1 + e = bnet.equiv_class(i, 2); + else + e = bnet.equiv_class(i, 1); + end + lam_msg = normalise(compute_lambda_msg(bnet.CPD{e}, n, ps, msg, p)); + msg{p}.lambda_from_child{j} = lam_msg; + if verbose, fprintf('%d sends lambda to %d\n', n, p); disp(lam_msg); end + end + end + end + +end + + +marginal = cell(ss,T); +lik = zeros(1,ss*T); +for t=1:T + for i=1:ss + n = i + (t-1)*ss; + [bel, lik(n)] = normalise(msg{n}.pi .* msg{n}.lambda); + marginal{i,t} = bel; + end +end + +loglik = sum(log(lik)); + + + +%%%%%%% + +function lambda = compute_lambda(n, cs, msg, ns) +% Pearl p183 eq 4.50 +lambda = prod_lambda_msgs(n, cs, msg, ns); + +%%%%%%% + +function pi_msg = compute_pi_msg(n, cs, msg, c, ns) +% Pearl p183 eq 4.53 and 4.51 +pi_msg = msg{n}.pi .* prod_lambda_msgs(n, cs, msg, ns, c); + +%%%%%%%%% + +function lam = prod_lambda_msgs(n, cs, msg, ns, except) + +if nargin < 5, except = -1; end + +lam = msg{n}.lambda_from_self(:); +lam = ones(ns(n), 1); +for i=1:length(cs) + c = cs(i); + if c ~= except + lam = lam .* msg{n}.lambda_from_child{i}; + end +end + + +%%%%%%%%% + +function msg = init_msgs(dag, ns, evidence, eclass, CPD) +% INIT_MSGS Initialize the lambda/pi message and state vectors (pearl_dbn) +% msg = init_msgs(dag, ns, evidence) + +N = length(dag); +msg = cell(1,N); +observed = ~isemptycell(evidence(:)); + +for n=1:N + ps = parents(dag, n); + msg{n}.pi_from_parent = cell(1, length(ps)); + for i=1:length(ps) + p = ps(i); + msg{n}.pi_from_parent{i} = ones(ns(p), 1); + end + + cs = children(dag, n); + msg{n}.lambda_from_child = cell(1, length(cs)); + for i=1:length(cs) + c = cs(i); + msg{n}.lambda_from_child{i} = ones(ns(n), 1); + end + + msg{n}.lambda = ones(ns(n), 1); + msg{n}.lambda_from_self = ones(ns(n), 1); + msg{n}.pi = ones(ns(n), 1); + + % Initialize the lambdas with any evidence + if observed(n) + v = evidence{n}; + %msg{n}.lambda_from_self = zeros(ns(n), 1); + %msg{n}.lambda_from_self(v) = 1; % delta function + msg{n}.lambda = zeros(ns(n), 1); + msg{n}.lambda(v) = 1; % delta function + end + +end + + +%%%%%%%% + +function msg = init_ev_msgs(engine, evidence, msg) + +[ss T] = size(evidence); +bnet = bnet_from_engine(engine); +pot_type = 'd'; +t = 1; +hnodes = mysetdiff(1:ss, engine.onodes); +for i=engine.onodes(:)' + fam = family(bnet.dag, i); + e = bnet.equiv_class(i, 1); + CPDpot = CPD_to_pot(pot_type, bnet.CPD{e}, fam, bnet.node_sizes(:), bnet.cnodes(:), evidence(:,1)); + temp = pot_to_marginal(CPDpot); + msg{i}.lambda_from_self = temp.T; +end +for t=2:T + for i=engine.onodes(:)' + fam = family(bnet.dag, i, 2); % extract from slice t + e = bnet.equiv_class(i, 2); + CPDpot = CPD_to_pot(pot_type, bnet.CPD{e}, fam, bnet.node_sizes(:), bnet.cnodes(:), evidence(:,t-1:t)); + temp = pot_to_marginal(CPDpot); + n = i + (t-1)*ss; + msg{n}.lambda_from_self = temp.T; + end +end + + +%%%%%%%%%%% + +function msg = init_ev_msgs2(engine, evidence, msg) + +[ss T] = size(evidence); +bnet = bnet_from_engine(engine); +pot_type = 'd'; +t = 1; +hnodes = mysetdiff(1:ss, engine.onodes); +for i=engine.onodes(:)' + fam = family(bnet.dag, i); + e = bnet.equiv_class(i, 1); + CPDpot = CPD_to_pot(pot_type, bnet.CPD{e}, fam, bnet.node_sizes(:), bnet.cnodes(:), evidence(:,1)); + temp = pot_to_marginal(CPDpot); + msg{i}.lambda_from_self = temp.T; +end +for t=2:T + for i=engine.onodes(:)' + fam = family(bnet.dag, i, 2); % extract from slice t + e = bnet.equiv_class(i, 2); + CPDpot = CPD_to_pot(pot_type, bnet.CPD{e}, fam, bnet.node_sizes(:), bnet.cnodes(:), evidence(:,t-1:t)); + temp = pot_to_marginal(CPDpot); + n = i + (t-1)*ss; + msg{n}.lambda_from_self = temp.T; + end +end + +