diff toolboxes/FullBNT-1.0.7/bnt/inference/dynamic/@pearl_dbn_inf_engine/Old/correct_smooth.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/toolboxes/FullBNT-1.0.7/bnt/inference/dynamic/@pearl_dbn_inf_engine/Old/correct_smooth.m	Tue Feb 10 15:05:51 2015 +0000
@@ -0,0 +1,244 @@
+function [marginal, msg, loglik] = smooth_evidence(engine, evidence)
+% [marginal, msg, loglik] = smooth_evidence(engine, evidence) (pearl_dbn)
+
+disp('warning: broken');
+
+[ss T] = size(evidence);
+bnet = bnet_from_engine(engine);
+bnet2 = dbn_to_bnet(bnet, T);
+ns = bnet2.node_sizes;
+hnodes = mysetdiff(1:ss, engine.onodes);
+hnodes = hnodes(:)';
+
+onodes2 = unroll_set(engine.onodes(:), ss, T);
+onodes2 = onodes2(:)';
+
+hnodes2 = unroll_set(hnodes(:), ss, T);
+hnodes2 = hnodes2(:)';
+
+[engine.parent_index, engine.child_index] = mk_pearl_msg_indices(bnet2);
+
+msg = init_msgs(bnet2.dag, ns, evidence, bnet2.equiv_class, bnet2.CPD);
+
+verbose = 0;
+
+niter = 1;
+for iter=1:niter
+  % FORWARD
+  for t=1:T
+    if verbose, fprintf('t=%d\n', t); end
+    % observed leaves send lambda to parents
+    for i=engine.onodes(:)'
+      n = i + (t-1)*ss;
+      ps = parents(bnet2.dag, n);
+      for p=ps(:)'
+	j = engine.child_index{p}(n); % n is p's j'th child
+	if t > 1
+	  e = bnet.equiv_class(i, 2);
+	else
+	  e = bnet.equiv_class(i, 1);
+	end
+	lam_msg = normalise(compute_lambda_msg(bnet.CPD{e}, n, ps, msg, p));
+	msg{p}.lambda_from_child{j} = lam_msg;
+	if verbose, fprintf('%d sends lambda to %d\n', n, p); disp(lam_msg); end
+      end 
+    end
+
+    % update pi
+    for i=hnodes
+      n = i + (t-1)*ss;
+      ps = parents(bnet2.dag, n);
+      if t==1
+	e = bnet.equiv_class(i,1);
+      else
+	e = bnet.equiv_class(i,2);
+      end
+      msg{n}.pi = compute_pi(bnet.CPD{e}, n, ps, msg);
+      if verbose, fprintf('%d computes pi\n', n); disp(msg{n}.pi); end
+    end
+    
+    % send pi msg to children
+    for i=hnodes
+      n = i + (t-1)*ss;
+      %cs = myintersect(children(bnet2.dag, n), hnodes2);
+      cs = children(bnet2.dag, n);
+      for c=cs(:)'
+	j = engine.parent_index{c}(n); % n is c's j'th parent
+	pi_msg = normalise(compute_pi_msg(n, cs, msg, c, ns));
+	msg{c}.pi_from_parent{j} = pi_msg;
+	if verbose, fprintf('%d sends pi to %d\n', n, c); disp(pi_msg); end
+      end
+    end
+  end
+
+  % BACKWARD
+  for t=T:-1:1
+    if verbose, fprintf('t = %d\n', t); end
+    % update lambda
+    for i=hnodes
+      n = i + (t-1)*ss;
+      cs = children(bnet2.dag, n);
+      msg{n}.lambda = compute_lambda(n, cs, msg, ns);
+      if verbose, fprintf('%d computes lambda\n', n); disp(msg{n}.lambda); end
+    end
+    % send lambda msgs to parents
+    for i=hnodes
+      n = i + (t-1)*ss;
+      %ps = myintersect(parents(bnet2.dag, n), hnodes2);
+      ps = parents(bnet2.dag, n);
+      for p=ps(:)'
+	j = engine.child_index{p}(n); % n is p's j'th child
+	if t > 1
+	  e = bnet.equiv_class(i, 2);
+	else
+	  e = bnet.equiv_class(i, 1);
+	end
+	lam_msg = normalise(compute_lambda_msg(bnet.CPD{e}, n, ps, msg, p));
+	msg{p}.lambda_from_child{j} = lam_msg;
+	if verbose, fprintf('%d sends lambda to %d\n', n, p); disp(lam_msg); end
+      end 
+    end
+  end
+  
+end
+
+
+marginal = cell(ss,T);
+lik = zeros(1,ss*T);
+for t=1:T
+  for i=1:ss
+    n = i + (t-1)*ss;
+    [bel, lik(n)] = normalise(msg{n}.pi .* msg{n}.lambda);     
+    marginal{i,t} = bel;
+  end
+end
+
+loglik = sum(log(lik));
+
+
+
+%%%%%%%
+
+function lambda = compute_lambda(n, cs, msg, ns)
+% Pearl p183 eq 4.50
+lambda = prod_lambda_msgs(n, cs, msg, ns);
+
+%%%%%%%
+
+function pi_msg = compute_pi_msg(n, cs, msg, c, ns)
+% Pearl p183 eq 4.53 and 4.51
+pi_msg = msg{n}.pi .* prod_lambda_msgs(n, cs, msg, ns, c);
+
+%%%%%%%%%
+
+function lam = prod_lambda_msgs(n, cs, msg, ns, except)
+
+if nargin < 5, except = -1; end
+
+lam = msg{n}.lambda_from_self(:);
+lam = ones(ns(n), 1);
+for i=1:length(cs)
+  c = cs(i);
+  if c ~= except
+    lam = lam .* msg{n}.lambda_from_child{i};
+  end
+end   
+
+
+%%%%%%%%%
+
+function msg = init_msgs(dag, ns, evidence, eclass, CPD)
+% INIT_MSGS Initialize the lambda/pi message and state vectors (pearl_dbn)
+% msg =  init_msgs(dag, ns, evidence)
+
+N = length(dag);
+msg = cell(1,N);
+observed = ~isemptycell(evidence(:));
+
+for n=1:N
+  ps = parents(dag, n);
+  msg{n}.pi_from_parent = cell(1, length(ps));
+  for i=1:length(ps)
+    p = ps(i);
+    msg{n}.pi_from_parent{i} = ones(ns(p), 1);
+  end
+  
+  cs = children(dag, n);
+  msg{n}.lambda_from_child = cell(1, length(cs));
+  for i=1:length(cs)
+    c = cs(i);
+    msg{n}.lambda_from_child{i} = ones(ns(n), 1);
+  end
+
+  msg{n}.lambda = ones(ns(n), 1);
+  msg{n}.lambda_from_self = ones(ns(n), 1);
+  msg{n}.pi = ones(ns(n), 1);
+
+ % Initialize the lambdas with any evidence
+  if observed(n)
+    v = evidence{n};
+    %msg{n}.lambda_from_self = zeros(ns(n), 1);
+    %msg{n}.lambda_from_self(v) = 1; % delta function
+    msg{n}.lambda = zeros(ns(n), 1);
+    msg{n}.lambda(v) = 1; % delta function
+  end      
+  
+end
+
+
+%%%%%%%%
+
+function msg = init_ev_msgs(engine, evidence, msg)
+
+[ss T] = size(evidence);
+bnet = bnet_from_engine(engine);
+pot_type = 'd';
+t = 1;
+hnodes = mysetdiff(1:ss, engine.onodes);
+for i=engine.onodes(:)'
+  fam = family(bnet.dag, i);
+  e = bnet.equiv_class(i, 1);
+  CPDpot = CPD_to_pot(pot_type, bnet.CPD{e}, fam, bnet.node_sizes(:), bnet.cnodes(:), evidence(:,1));
+  temp = pot_to_marginal(CPDpot);
+  msg{i}.lambda_from_self = temp.T;
+end
+for t=2:T
+  for i=engine.onodes(:)'
+    fam = family(bnet.dag, i, 2); % extract from slice t
+    e = bnet.equiv_class(i, 2);
+    CPDpot = CPD_to_pot(pot_type, bnet.CPD{e}, fam, bnet.node_sizes(:), bnet.cnodes(:), evidence(:,t-1:t));
+    temp = pot_to_marginal(CPDpot);
+    n = i + (t-1)*ss;
+    msg{n}.lambda_from_self = temp.T;
+  end
+end
+
+
+%%%%%%%%%%%
+
+function msg = init_ev_msgs2(engine, evidence, msg)
+
+[ss T] = size(evidence);
+bnet = bnet_from_engine(engine);
+pot_type = 'd';
+t = 1;
+hnodes = mysetdiff(1:ss, engine.onodes);
+for i=engine.onodes(:)'
+  fam = family(bnet.dag, i);
+  e = bnet.equiv_class(i, 1);
+  CPDpot = CPD_to_pot(pot_type, bnet.CPD{e}, fam, bnet.node_sizes(:), bnet.cnodes(:), evidence(:,1));
+  temp = pot_to_marginal(CPDpot);
+  msg{i}.lambda_from_self = temp.T;
+end
+for t=2:T
+  for i=engine.onodes(:)'
+    fam = family(bnet.dag, i, 2); % extract from slice t
+    e = bnet.equiv_class(i, 2);
+    CPDpot = CPD_to_pot(pot_type, bnet.CPD{e}, fam, bnet.node_sizes(:), bnet.cnodes(:), evidence(:,t-1:t));
+    temp = pot_to_marginal(CPDpot);
+    n = i + (t-1)*ss;
+    msg{n}.lambda_from_self = temp.T;
+  end
+end
+
+