Mercurial > hg > camir-aes2014
comparison toolboxes/FullBNT-1.0.7/bnt/inference/dynamic/@frontier_inf_engine/frontier_inf_engine.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:e9a9cd732c1e |
---|---|
function engine = frontier_inf_engine(bnet)
% FRONTIER_INF_ENGINE Inference engine for DBNs which uses the frontier algorithm.
% engine = frontier_inf_engine(bnet)
%
% The frontier algorithm extends the forwards-backwards algorithm to DBNs in the obvious way,
% maintaining a joint distribution (frontier) over all the nodes in a time slice.
% When all the hidden nodes in the DBN are persistent (have children in the next time slice),
% its theoretical running time is often similar to that of the junction tree algorithm,
% although in practice this algorithm seems to be very slow (at least in matlab).
% However, it is extremely simple to describe and implement.
%
% Suppose there are n binary nodes per slice, so the frontier takes O(2^n) space.
% Each time step takes between O(n 2^{n+1}) and O(n 2^{2n}) operations, depending on the graph structure.
% The lower bound is achieved by a set of n independent chains, as in a factorial HMM.
% The upper bound is achieved by a set of n fully interconnected chains, as in an HMM.
%
% The factor of n arises because we need to multiply in each CPD from slice t+1.
% The second factor depends on the size of the frontier to which we add the new node.
% In an FHMM, once we have added X(i,t+1), we can marginalize out X(i,t) from the frontier, since
% no other nodes depend on it; hence the frontier never contains more than n+1 nodes.
% In a fully coupled HMM, we must leave X(i,t) in the frontier until all X(j,t+1) have been
% added; hence the frontier will contain 2*n nodes at its peak.
%
% Fields of the returned engine (set here):
%   ops, fdom     - greedy add/remove sequence and the frontier after each step, used to
%                   advance from slice t to slice t+1 (see best_first_frontier_seq)
%   ops1, fdom1   - for the first slice: nodes 1..ss are added in order, so fdom1{s} = 1:s
%   fwdback, fwd_frontier, back_frontier - initialised to []; presumably filled in by the
%                   inference methods of this class (not shown here) — TODO confirm
%
% The engine is a 'frontier_inf_engine' object whose parent class is inf_engine(bnet).
%
% For details, see
% "The Factored Frontier Algorithm for Approximate Inference in DBNs",
% Kevin Murphy and Yair Weiss, UAI 01.

ns = bnet.node_sizes_slice;
onodes = bnet.observed;
% Observed nodes contribute a factor of 1 to the frontier size when costing the
% elimination order (presumably because their values are fixed by the evidence).
ns(onodes) = 1;
ss = length(bnet.intra);

[engine.ops, engine.fdom] = best_first_frontier_seq(ns, bnet.dag);
% Slice 1 has no previous slice to eliminate, so its nodes are simply added one by one.
engine.ops1 = 1:ss;

engine.fwdback = [];
engine.fwd_frontier = [];
engine.back_frontier = [];

% Frontier domain after adding node s of slice 1: nodes 1..s.
engine.fdom1 = cell(1,ss);
for s=1:ss
  engine.fdom1{s} = 1:s;
end

engine = class(engine, 'frontier_inf_engine', inf_engine(bnet));
46 | |
47 | |
48 %%%%%%%%% | |
49 | |
function [ops, frontier_set] = best_first_frontier_seq(ns, dag)
% BEST_FIRST_FRONTIER_SEQ Do a greedy search for the sequence of additions/removals to the frontier.
% [ops, frontier_set] = best_first_frontier_seq(ns, dag)
%
% We maintain 3 sets: the frontier (F), the right set (R), and the left set (L).
% The invariant is that the nodes in R are d-separated from L given F.
% We start with slice 1 in F and slice 2 in R.
% The goal is to move slice 1 from F to L, and slice 2 from R to F, so as to minimize the size
% of the frontier at each step, where the size(F) = product of the node-sizes of nodes in F.
% A node may be removed (from F to L) if it has no children in R.
% A node may be added (from R to F) if its parents are in F.
% Removal is always preferred over addition when some removal is legal. The set L is
% only implicit here: it is exactly the slice-1 nodes no longer in F.
%
% Inputs:
%   ns(i) = num. discrete values node i can take on (i=1..ss, where ss = slice size)
%   dag is the (2*ss) x (2*ss) adjacency matrix for the 2-slice DBN.
%
% Outputs (there are exactly 2*ss steps: each slice-1 node is removed once and
% each slice-2 node is added once):
%   ops(s) < 0 : node -ops(s) (slice 1) was removed from the frontier at step s
%   ops(s) > 0 : node ops(s) (slice 2) was added to the frontier at step s
%   frontier_set{s} : the frontier F after step s
%
% Ties in cost are broken in favour of the lowest-numbered node.

% Example:
%
% 4   9
% ^   ^
% |   |
% 2 -> 7
% ^   ^
% |   |
% 1 -> 6
% |   |
% v   v
% 3 -> 8
% |   |
% v   V
% 5   10
%
% ops = -4, -5, 6, -1, 7, -2, 8, -3, 9, 10

ss = length(ns);
ns = [ns(:)' ns(:)'];          % slice-2 nodes share the sizes of their slice-1 counterparts
nsteps = 2*ss;
ops = zeros(1, nsteps);        % one op per step (was under-allocated as 1 x ss)
F = 1:ss;                      % frontier: starts as all of slice 1
R = (1:ss)+ss;                 % right set: all of slice 2
frontier_set = cell(1, nsteps);
for s=1:nsteps
  % Cost of each legal removal = size of the frontier that would remain.
  remcost = inf(1, nsteps);
  maybe_removable = myintersect(F, 1:ss);
  for n=maybe_removable(:)'
    if isempty(myintersect(children(dag, n), R))
      remcost(n) = prod(ns(mysetdiff(F, n)));
    end
  end
  % min returns the first index on ties, matching the previous argmin behaviour.
  [mincost, n] = min(remcost);
  if mincost < inf
    ops(s) = -n;
    F = mysetdiff(F, n);
  else
    % No removal is legal, so add the cheapest node of R whose parents are all in F.
    addcost = inf(1, nsteps);
    for n=R(:)'
      ps = parents(dag, n);
      if mysubset(ps, F)
        % ps is already contained in F, so the new frontier is just F plus n.
        addcost(n) = prod(ns(myunion(F, n)));
      end
    end
    [mincost, n] = min(addcost);
    assert(mincost < inf, 'best_first_frontier_seq: no node can be added to or removed from the frontier');
    ops(s) = n;
    R = mysetdiff(R, n);
    F = myunion(F, n);
  end
  frontier_set{s} = F;
end