function [fwdback, loglik, fwd_frontier, back_frontier] = enter_soft_evidence(engine, CPD, onodes, pot_type, filter)
% ENTER_SOFT_EVIDENCE Add soft evidence to network (frontier)
% [fwdback, loglik, fwd_frontier, back_frontier] = enter_soft_evidence(engine, CPD, onodes, pot_type, filter)
%
% Runs a forwards pass of the frontier algorithm over all T slices and,
% unless FILTER is set, a backwards pass, then combines the two per node.
%
% Inputs:
%   engine   - frontier engine; reads engine.fdom1, engine.fdom, engine.ops and
%              the network returned by bnet_from_engine(engine)
%   CPD      - ss-by-T cell array of potentials; CPD{j} (linear index) is the
%              potential multiplied in when node j is added
%   onodes   - observed nodes, passed through to mk_initial_pot
%   pot_type - potential type, passed through to mk_initial_pot
%   filter   - if nonzero, skip the backwards pass (default 0)
%
% Outputs:
%   fwdback       - ss-by-T cell array. Normally the normalized product of the
%                   forward and backward frontiers saved for each node; when
%                   FILTER is set, the (unnormalized) forward frontiers.
%   loglik        - sum over slices of the second output of normalize_pot on
%                   the final forward frontier of each slice
%   fwd_frontier  - S-by-T cell array of forward frontiers, S = 2*ss
%   back_frontier - S-by-(T+1) cell array of backward frontiers, or [] when
%                   FILTER is set

% filter is the 5th argument, so it is missing whenever nargin < 5.
% (Left unassigned, "if filter" would call MATLAB's builtin filter() and error.)
if nargin < 5, filter = 0; end

[ss T] = size(CPD);
bnet = bnet_from_engine(engine);
ns = repmat(bnet.node_sizes_slice(:), 1, T);
cnodes = unroll_set(bnet.cnodes(:), ss, T);

% FORWARDS
fwd = cell(ss,T);
ll = zeros(1,T);
S = 2*ss; % num. intermediate frontiers to get from t to t+1
frontier = cell(S,T);

% Slice 1: start with an empty frontier and add each node in turn;
% node s is added at step s with domain engine.fdom1{s}.
t = 1;
prev = mk_initial_pot(pot_type, [], ns, cnodes, onodes);
for s=1:ss
  j = s; % add node j at step s
  frontier{s,t} = update(prev, j, 1, CPD{j}, engine.fdom1{s}, pot_type, ns, cnodes, onodes);
  fwd{j} = frontier{s,t};
  prev = frontier{s,t};
end
frontier{S,t} = frontier{ss,t};
[frontier{S,t}, ll(1)] = normalize_pot(frontier{S,t});

% Slices 2..T: apply the S operations in engine.ops. A positive op adds a
% node, any other op removes one; nodes(s,t) is the node that op s touches
% in slice t (column 1 is a zero placeholder for slice 1).
OPS = engine.ops;
add = OPS>0;
nodes = [zeros(S,1) unroll_set(abs(OPS(:)), ss, T-1)];
for t=2:T
  offset = (t-2)*ss;
  for s=1:S
    if s==1
      prev_ndx = (t-2)*S + S; % linear index of frontier{S,t-1}
    else
      prev_ndx = (t-1)*S + s-1; % linear index of frontier{s-1,t}
    end
    j = nodes(s,t);
    frontier{s,t} = update(frontier{prev_ndx}, j, add(s), CPD{j}, engine.fdom{s}+offset, pot_type, ns, cnodes, onodes);
    if add(s)
      fwd{j} = frontier{s,t}; % frontier right after node j was added
    end
  end
  [frontier{S,t}, ll(t)] = normalize_pot(frontier{S,t});
end
loglik = sum(ll);

fwd_frontier = frontier;

if filter
  fwdback = fwd;
  % No backwards pass was run; assign the output so callers that request
  % all four outputs do not fail with "output argument not assigned".
  back_frontier = [];
  return;
end

% BACKWARDS
back = cell(ss,T);
add = ~add; % forwards add = backwards remove
frontier = cell(S,T+1);
dom = (1:ss) + (T-1)*ss;
frontier{1,T+1} = mk_initial_pot(pot_type, dom, ns, cnodes, onodes); % all 1s for last slice
for t=T:-1:2
  offset = (t-2)*ss;
  for s=S:-1:1 % undo the forward ops in reverse order
    if s==S
      prev_ndx = t*S + 1; % linear index of frontier{1,t+1}
    else
      prev_ndx = (t-1)*S + (s+1); % linear index of frontier{s+1,t}
    end
    j = nodes(s,t);
    if ~add(s)
      back{j} = frontier{prev_ndx}; % save frontier before removing
    end
    frontier{s,t} = rev_update(frontier{prev_ndx}, t, s, j, add(s), CPD{j}, engine.fdom{s}+offset, pot_type, ns, cnodes, onodes);
  end
  frontier{1,t} = normalize_pot(frontier{1,t});
end

% Slice 1: remove each node in turn until left with the empty set
t = 1;
frontier{ss+1,t} = frontier{1,2};
for s=ss:-1:1
  j = s; % remove node j at step s
  back{j} = frontier{s+1,t};
  frontier{s,t} = rev_update(frontier{s+1,t}, t, s, j, 0, CPD{j}, 1:s, pot_type, ns, cnodes, onodes);
end

% COMBINE: for each node, multiply the saved forward and backward frontiers
% and normalize the result.
fwdback = cell(ss,T);
for t=1:T
  for i=1:ss
    fwdback{i,t} = normalize_pot(multiply_pots(fwd{i,t}, back{i,t}));
  end
end

back_frontier = frontier;

%%%%%%%%%%
function new_frontier = update(old_frontier, j, add, CPD, newdom, pot_type, ns, cnodes, onodes)
% UPDATE Take one step of the forwards frontier pass.
% When ADD is true, a fresh potential over NEWDOM is built with mk_initial_pot,
% then OLD_FRONTIER and node J's potential CPD are multiplied into it.
% When ADD is false, node J is marginalized out of OLD_FRONTIER and NEWDOM is
% not used.

if add
  blank = mk_initial_pot(pot_type, newdom, ns, cnodes, onodes);
  with_old = multiply_by_pot(blank, old_frontier);
  new_frontier = multiply_by_pot(with_old, CPD);
else
  remaining = mysetdiff(domain_pot(old_frontier), j);
  new_frontier = marginalize_pot(old_frontier, remaining);
end

%%%%%%
function new_frontier = rev_update(old_frontier, t, s, j, add, CPD, junk, pot_type, ns, cnodes, onodes)
% REV_UPDATE Take one step of the backwards frontier pass (reverse of UPDATE).
% JUNK is the domain OLD_FRONTIER is expected to have; it is only checked by
% an assertion. T and S (step identifiers) are accepted but not used.
% When ADD is true, J is added to the domain by multiplying OLD_FRONTIER into
% a fresh mk_initial_pot over the enlarged domain ("multiplying by 1").
% When ADD is false, node J's potential CPD is multiplied in and J is then
% marginalized out.

olddom = domain_pot(old_frontier);
assert(isequal(junk, olddom));

if add
  ones_pot = mk_initial_pot(pot_type, myunion(olddom, j), ns, cnodes, onodes);
  new_frontier = multiply_by_pot(ones_pot, old_frontier);
else
  % the parents of j must already be in old_frontier, otherwise j could not
  % have been added on the forwards pass
  absorbed = multiply_by_pot(old_frontier, CPD);
  new_frontier = marginalize_pot(absorbed, mysetdiff(olddom, j));
end
