Mercurial > hg > camir-aes2014
comparison toolboxes/FullBNT-1.0.7/bnt/inference/dynamic/@frontier_inf_engine/enter_soft_evidence.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:e9a9cd732c1e |
---|---|
function [fwdback, loglik, fwd_frontier, back_frontier] = enter_soft_evidence(engine, CPD, onodes, pot_type, filter)
% ENTER_SOFT_EVIDENCE Add soft evidence to network (frontier algorithm)
% [fwdback, loglik, fwd_frontier, back_frontier] = ...
%     enter_soft_evidence(engine, CPD, onodes, pot_type, filter)
%
% Runs a forwards pass (and, unless FILTER is set, a backwards pass) that moves
% a frontier of nodes through the unrolled network one node at a time.
%
% Inputs:
%   engine   - frontier inference engine; its bnet, fdom1, fdom and ops fields are used
%   CPD      - ss-by-T cell array; CPD{i,t} is the potential multiplied in for
%              node i of slice t (unrolled node index i + (t-1)*ss)
%   onodes   - observed nodes, forwarded to mk_initial_pot
%   pot_type - potential type, forwarded to mk_initial_pot
%   filter   - if nonzero, skip the backwards pass (default 0)
%
% Outputs:
%   fwdback       - ss-by-T cell array. When filter == 0, fwdback{i,t} is the
%                   normalized product of the forward and backward frontiers saved
%                   for node i of slice t. When filter is set, it is the forward
%                   frontier captured right after node i of slice t was added.
%   loglik        - sum of the per-slice normalization terms returned by
%                   normalize_pot in the forwards pass (presumably log normalizers
%                   -- TODO confirm against normalize_pot)
%   fwd_frontier  - S-by-T cell array of forward frontiers, S = 2*ss
%   back_frontier - S-by-(T+1) cell array of backward frontiers; {} when filter is set

% filter is the 5th input, so it must be defaulted whenever fewer than 5 inputs
% are given. Otherwise "if filter" below would resolve to the builtin FILTER.
if nargin < 5 || isempty(filter), filter = 0; end

[ss, T] = size(CPD);
bnet = bnet_from_engine(engine);
ns = repmat(bnet.node_sizes_slice(:), 1, T);
cnodes = unroll_set(bnet.cnodes(:), ss, T);

% FORWARDS
fwd = cell(ss,T);
ll = zeros(1,T);
S = 2*ss; % num. intermediate frontiers to get from t to t+1
frontier = cell(S,T);

% Slice 1: start with an empty frontier and add each node of slice 1 in turn.
prev = mk_initial_pot(pot_type, [], ns, cnodes, onodes);
for s=1:ss
  j = s; % add node j at step s
  frontier{s,1} = update(prev, j, 1, CPD{j}, engine.fdom1{s}, pot_type, ns, cnodes, onodes);
  fwd{j} = frontier{s,1};
  prev = frontier{s,1};
end
frontier{S,1} = frontier{ss,1};
[frontier{S,1}, ll(1)] = normalize_pot(frontier{S,1});

% Slices 2..T: apply engine.ops, where a positive entry adds a node and a
% non-positive entry removes (marginalizes out) a node.
OPS = engine.ops;
add = OPS>0;
nodes = [zeros(S,1) unroll_set(abs(OPS(:)), ss, T-1)];
for t=2:T
  offset = (t-2)*ss;
  for s=1:S
    if s==1
      prev_ndx = (t-2)*S + S; % linear index of frontier{S,t-1}
    else
      prev_ndx = (t-1)*S + s-1; % linear index of frontier{s-1,t}
    end
    j = nodes(s,t);
    frontier{s,t} = update(frontier{prev_ndx}, j, add(s), CPD{j}, engine.fdom{s}+offset, pot_type, ns, cnodes, onodes);
    if add(s)
      fwd{j} = frontier{s,t}; % frontier right after node j was added
    end
  end
  [frontier{S,t}, ll(t)] = normalize_pot(frontier{S,t});
end
loglik = sum(ll);
fwd_frontier = frontier;

if filter
  fwdback = fwd;
  back_frontier = {}; % no backwards pass; assign so 4-output calls do not error
  return;
end

% BACKWARDS
back = cell(ss,T);
add = ~add; % forwards add = backwards remove
frontier = cell(S,T+1);
dom = (1:ss) + (T-1)*ss; % nodes of the last slice
frontier{1,T+1} = mk_initial_pot(pot_type, dom, ns, cnodes, onodes); % all 1s for last slice
for t=T:-1:2
  offset = (t-2)*ss;
  for s=S:-1:1 % reverse order
    if s==S
      prev_ndx = t*S + 1; % linear index of frontier{1,t+1}
    else
      prev_ndx = (t-1)*S + (s+1); % linear index of frontier{s+1,t}
    end
    j = nodes(s,t);
    if ~add(s)
      back{j} = frontier{prev_ndx}; % save frontier before removing
    end
    frontier{s,t} = rev_update(frontier{prev_ndx}, t, s, j, add(s), CPD{j}, engine.fdom{s}+offset, pot_type, ns, cnodes, onodes);
  end
  frontier{1,t} = normalize_pot(frontier{1,t});
end

% Slice 1: remove each node of slice 1 until left with the empty set.
frontier{ss+1,1} = frontier{1,2};
for s=ss:-1:1
  j = s; % remove node j at step s
  back{j} = frontier{s+1,1};
  frontier{s,1} = rev_update(frontier{s+1,1}, 1, s, j, 0, CPD{j}, 1:s, pot_type, ns, cnodes, onodes);
end

% COMBINE
fwdback = cell(ss,T);
for t=1:T
  for i=1:ss
    fwdback{i,t} = normalize_pot(multiply_pots(fwd{i,t}, back{i,t}));
  end
end

back_frontier = frontier;
107 | |
108 %%%%%%%%%% | |
function new_frontier = update(old_frontier, j, add, CPD, newdom, pot_type, ns, cnodes, onodes)
% UPDATE One step of the forwards frontier pass.
% If ADD is true: build an initial potential over NEWDOM with mk_initial_pot,
% then multiply in OLD_FRONTIER and finally CPD.
% Otherwise: marginalize node J out of OLD_FRONTIER (NEWDOM and CPD are unused).
if add
  fresh = mk_initial_pot(pot_type, newdom, ns, cnodes, onodes);
  with_old = multiply_by_pot(fresh, old_frontier);
  new_frontier = multiply_by_pot(with_old, CPD);
  return;
end
kept = mysetdiff(domain_pot(old_frontier), j);
new_frontier = marginalize_pot(old_frontier, kept);
118 | |
119 | |
120 %%%%%% | |
function new_frontier = rev_update(old_frontier, t, s, j, add, CPD, junk, pot_type, ns, cnodes, onodes)
% REV_UPDATE One step of the backwards frontier pass.
% JUNK is the domain the caller expects OLD_FRONTIER to have; it is only used in
% an assert. T and S are referenced only by the commented-out debug prints.
% If ADD is true: extend the domain by node J by multiplying OLD_FRONTIER into an
% initial potential over the enlarged domain.
% Otherwise: multiply in CPD, then sum node J out. J's parents are expected to be
% in OLD_FRONTIER; otherwise J could not have been added on the forwards pass.
cur_dom = domain_pot(old_frontier);
assert(isequal(junk, cur_dom));

if ~add
  absorbed = multiply_by_pot(old_frontier, CPD);
  new_frontier = marginalize_pot(absorbed, mysetdiff(cur_dom, j));
  %fprintf('t=%d, s=%d, rem %d from %s to make %s\n', t, s, j, num2str(cur_dom), num2str(domain_pot(new_frontier)));
else
  wider_dom = myunion(cur_dom, j);
  unit_pot = mk_initial_pot(pot_type, wider_dom, ns, cnodes, onodes);
  new_frontier = multiply_by_pot(unit_pot, old_frontier);
  %fprintf('t=%d, s=%d, add %d to %s to make %s\n', t, s, j, num2str(cur_dom), num2str(wider_dom));
end
141 | |
142 |