diff toolboxes/FullBNT-1.0.7/bnt/inference/dynamic/@frontier_inf_engine/frontier_inf_engine.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/toolboxes/FullBNT-1.0.7/bnt/inference/dynamic/@frontier_inf_engine/frontier_inf_engine.m	Tue Feb 10 15:05:51 2015 +0000
@@ -0,0 +1,121 @@
+function engine = frontier_inf_engine(bnet)
+% FRONTIER_INF_ENGINE Inference engine for DBNs which which uses the frontier algorithm.
+% engine = frontier_inf_engine(bnet)
+%
+% The frontier algorithm extends the forwards-backwards algorithm to DBNs in the obvious way,
+% maintaining a joint distribution (frontier) over all the nodes in a time slice.
+% When all the hidden nodes in the DBN are persistent (have children in the next time slice),
+% its theoretical running time is often similar to that of the junction tree algorithm,
+% although in practice, this algorithm seems to very slow (at least in matlab).
+% However, it is extremely simple to describe and implement.
+%
+% Suppose there are n binary nodes per slice, so the frontier takes O(2^n) space.
+% Each time step takes between O(n 2^{n+1}) and O(n 2^{2n}) operations, depending on the graph structure.
+% The lower bound is achieved by a set of n independent chains, as in a factorial HMM.
+% The upper bound is achieved by a set of n fully interconnected chains, as in an HMM.
+%
+% The factor of n arises because we need to multiply in each CPD from slice t+1.
+% The second factor depends on the size of the frontier to which we add the new node.
+% In an FHMM, once we have added X(i,t+1), we can marginalize out X(i,t) from the frontier, since
+% no other nodes depend on it; hence the frontier never contains more than n+1 nodes.
+% In a fully coupled HMM, we must leave X(i,t) in the frontier until all X(j,t+1) have been
+% added; hence the frontier will contain 2*n nodes at its peak.
+%
+% For details, see
+%   "The Factored Frontier Algorithm for Approximate Inference in DBNs",
+%   Kevin Murphy and Yair Weiss, UAI 01.
+
+ns = bnet.node_sizes_slice;
+onodes = bnet.observed;
+ns(onodes) = 1;
+ss = length(bnet.intra);
+
+[engine.ops, engine.fdom] = best_first_frontier_seq(ns, bnet.dag);
+engine.ops1 = 1:ss;
+
+engine.fwdback = [];
+engine.fwd_frontier = [];
+engine.back_frontier = [];
+
+engine.fdom1 = cell(1,ss);
+for s=1:ss
+  engine.fdom1{s} = 1:s;
+end
+
+engine = class(engine, 'frontier_inf_engine', inf_engine(bnet));
+
+
+%%%%%%%%%
+
+function [ops, frontier_set] = best_first_frontier_seq(ns, dag)
+% BEST_FIRST_FRONTIER_SEQ Do a greedy search for the sequence of additions/removals to the frontier.
+% [ops, frontier_set] = best_first_frontier_seq(ns, dag)
+%
+% We maintain 3 sets: the frontier (F), the right set (R), and the left set (L).
+% The invariant is that the nodes in R are d-separated from L given F.
+% We start with slice 1 in F and slice 2 in R.
+% The goal is to move slice 1 from F to L, and slice 2 from R to F, so as to minimize the size
+% of the frontier at each step, where the size(F) = product of the node-sizes of nodes in F.
+% A node may be removed (from F to L) if it has no children in R.
+% A node may be added (from R to F) if its parents are in F.
+%
+% ns(i) = num. discrete values node i can take on (i=1..ss, where ss = slice size)
+% dag is the (2*ss) x (2*ss) adjacency matrix for the 2-slice DBN.
+
+% Example:
+%
+% 4    9
+% ^    ^
+% |    |
+% 2 -> 7
+% ^    ^
+% |    |
+% 1 -> 6
+% |    |
+% v    v
+% 3 -> 8
+% |    |
+% v    V
+% 5    10
+%
+% ops = -4, -5, 6, -1, 7, -2, 8, -3, 9, 10
+
+ss = length(ns);
+ns = [ns(:)' ns(:)'];
+ops = zeros(1,ss);
+L = []; F = 1:ss; R = (1:ss)+ss;
+frontier_set = cell(1,2*ss);
+for s=1:2*ss
+  remcost = inf*ones(1,2*ss);
+  %disp(['L: ' num2str(L) ', F: ' num2str(F) ', R: ' num2str(R)]);
+  maybe_removable = myintersect(F, 1:ss);
+  for n=maybe_removable(:)'
+    cs = children(dag, n);
+    if isempty(myintersect(cs, R))
+      remcost(n) = prod(ns(mysetdiff(F, n)));
+    end
+  end
+  %remcost
+  if any(remcost < inf)
+    n = argmin(remcost);
+    ops(s) = -n;
+    L = myunion(L, n);
+    F = mysetdiff(F, n);
+  else
+    addcost = inf*ones(1,2*ss);
+    for n=R(:)'
+      ps = parents(dag, n);
+      if mysubset(ps, F)
+	addcost(n) = prod(ns(myunion(F, [ps n])));
+      end
+    end
+    %addcost
+    assert(any(addcost < inf));
+    n = argmin(addcost);
+    ops(s) = n;
+    R  = mysetdiff(R, n);
+    F = myunion(F, n);
+  end
+  %fprintf('op at step %d = %d\n\n', s, ops(s));
+  frontier_set{s} = F;
+end