wolffd@0
|
function [marginal, msg, loglik] = smooth_evidence(engine, evidence)
% [marginal, msg, loglik] = smooth_evidence(engine, evidence) (pearl_dbn)
%
% Runs niter (= 1) forward/backward sweeps of Pearl lambda/pi message passing
% over the DBN unrolled for T = size(evidence,2) slices.
%
% evidence{i,t} is the value of node i in slice t (empty if hidden).
% marginal{i,t} = normalise(pi .* lambda) of that node.
% msg{n} is the message struct of unrolled node n = i + (t-1)*ss.
% loglik = sum of log(sum(pi .* lambda)) over all unrolled nodes.
%   NOTE(review): that is a sum of per-node normalisers, not in general
%   log P(evidence) -- confirm before relying on it.
%
% Within a slice, each hidden node updates and then immediately sends its
% messages: increasing index order in the forward pass, decreasing in the
% backward pass. Intra-slice pi/lambda messages thus reach the recipient
% before it updates, assuming per-slice nodes are numbered in topological
% order (TODO confirm for the caller's model).

disp('warning: broken');

[ss T] = size(evidence);
bnet = bnet_from_engine(engine);
bnet2 = dbn_to_bnet(bnet, T);
ns = bnet2.node_sizes;
hnodes = mysetdiff(1:ss, engine.onodes);
hnodes = hnodes(:)';

[engine.parent_index, engine.child_index] = mk_pearl_msg_indices(bnet2);

msg = init_msgs(bnet2.dag, ns, evidence, bnet2.equiv_class, bnet2.CPD);

verbose = 0;

niter = 1;
for iter=1:niter
  % FORWARD
  for t=1:T
    if verbose, fprintf('t=%d\n', t); end
    slice = min(t, 2); % equiv_class column: 1 for the first slice, 2 afterwards

    % observed leaves send lambda to parents
    for i=engine.onodes(:)'
      n = i + (t-1)*ss;
      ps = parents(bnet2.dag, n);
      e = bnet.equiv_class(i, slice);
      for p=ps(:)'
        j = engine.child_index{p}(n); % n is p's j'th child
        lam_msg = normalise(compute_lambda_msg(bnet.CPD{e}, n, ps, msg, p));
        msg{p}.lambda_from_child{j} = lam_msg;
        if verbose, fprintf('%d sends lambda to %d\n', n, p); disp(lam_msg); end
      end
    end

    % each hidden node updates pi, then immediately sends pi to its children,
    % so hidden children later in this slice use the fresh message
    for i=hnodes
      n = i + (t-1)*ss;
      ps = parents(bnet2.dag, n);
      e = bnet.equiv_class(i, slice);
      msg{n}.pi = compute_pi(bnet.CPD{e}, n, ps, msg);
      if verbose, fprintf('%d computes pi\n', n); disp(msg{n}.pi); end

      cs = children(bnet2.dag, n);
      for c=cs(:)'
        j = engine.parent_index{c}(n); % n is c's j'th parent
        pi_msg = normalise(compute_pi_msg(n, cs, msg, c, ns));
        msg{c}.pi_from_parent{j} = pi_msg;
        if verbose, fprintf('%d sends pi to %d\n', n, c); disp(pi_msg); end
      end
    end
  end

  % BACKWARD
  for t=T:-1:1
    if verbose, fprintf('t = %d\n', t); end
    slice = min(t, 2);

    % each hidden node updates lambda, then immediately sends lambda to its
    % parents; decreasing index order lets intra-slice children report first
    for i=fliplr(hnodes)
      n = i + (t-1)*ss;
      cs = children(bnet2.dag, n);
      msg{n}.lambda = compute_lambda(n, cs, msg, ns);
      if verbose, fprintf('%d computes lambda\n', n); disp(msg{n}.lambda); end

      ps = parents(bnet2.dag, n);
      e = bnet.equiv_class(i, slice);
      for p=ps(:)'
        j = engine.child_index{p}(n); % n is p's j'th child
        lam_msg = normalise(compute_lambda_msg(bnet.CPD{e}, n, ps, msg, p));
        msg{p}.lambda_from_child{j} = lam_msg;
        if verbose, fprintf('%d sends lambda to %d\n', n, p); disp(lam_msg); end
      end
    end
  end
end

marginal = cell(ss,T);
lik = zeros(1,ss*T);
for t=1:T
  for i=1:ss
    n = i + (t-1)*ss;
    [bel, lik(n)] = normalise(msg{n}.pi .* msg{n}.lambda);
    marginal{i,t} = bel;
  end
end

loglik = sum(log(lik));
|
wolffd@0
|
117
|
wolffd@0
|
118
|
wolffd@0
|
119
|
wolffd@0
|
120 %%%%%%%
|
wolffd@0
|
121
|
wolffd@0
|
function lambda = compute_lambda(n, cs, msg, ns)
% COMPUTE_LAMBDA Lambda vector of node n: the product of the lambda messages
% n has received from all of its children cs (Pearl p183 eq 4.50).
no_excluded_child = -1; % matches no child id, so every child contributes
lambda = prod_lambda_msgs(n, cs, msg, ns, no_excluded_child);
|
wolffd@0
|
125
|
wolffd@0
|
126 %%%%%%%
|
wolffd@0
|
127
|
wolffd@0
|
function pi_msg = compute_pi_msg(n, cs, msg, c, ns)
% COMPUTE_PI_MSG Unnormalised pi message from node n to its child c:
% n's own pi times the lambda messages n received from every child except c
% (Pearl p183 eqs 4.51 and 4.53).
own_pi = msg{n}.pi;
other_children_lambda = prod_lambda_msgs(n, cs, msg, ns, c);
pi_msg = own_pi .* other_children_lambda;
|
wolffd@0
|
131
|
wolffd@0
|
132 %%%%%%%%%
|
wolffd@0
|
133
|
wolffd@0
|
function lam = prod_lambda_msgs(n, cs, msg, ns, except)
% PROD_LAMBDA_MSGS Elementwise product of the lambda messages node n has
% received from its children.
% lam = prod_lambda_msgs(n, cs, msg, ns)          uses every child in cs
% lam = prod_lambda_msgs(n, cs, msg, ns, except)  skips the child whose id is except
%
% cs must list n's children in the same order as msg{n}.lambda_from_child.
% msg{n}.lambda_from_self is not used.
% Returns an ns(n) x 1 column vector (all ones when no message is used).

if nargin < 5, except = -1; end

lam = ones(ns(n), 1);
for k=1:length(cs)
  if cs(k) ~= except
    % (:) keeps lam a column even if a message was stored as a row vector
    lam = lam .* msg{n}.lambda_from_child{k}(:);
  end
end
|
wolffd@0
|
146
|
wolffd@0
|
147
|
wolffd@0
|
148 %%%%%%%%%
|
wolffd@0
|
149
|
wolffd@0
|
function msg = init_msgs(dag, ns, evidence, eclass, CPD)
% INIT_MSGS Initialize the lambda/pi message and state vectors (pearl_dbn)
% msg = init_msgs(dag, ns, evidence, eclass, CPD)
%
% dag      - adjacency matrix of the network; N = length(dag) nodes
% ns       - ns(n) is the number of states of node n
% evidence - read in linear order: evidence{n} is node n's observed value,
%            empty if node n is hidden
% eclass, CPD - accepted for the caller's convenience but not used here
%
% Every msg{n} gets the fields
%   pi_from_parent{k}    = ones(ns(p),1) for the k'th parent p of n
%   lambda_from_child{k} = ones(ns(n),1) for the k'th child of n
%   lambda, lambda_from_self, pi = ones(ns(n),1)
% except that an observed node's lambda is zero everywhere except at the
% observed value(s), where it is 1.

N = length(dag);
msg = cell(1,N);
observed = ~isemptycell(evidence(:));

for n=1:N
  ps = parents(dag, n);
  msg{n}.pi_from_parent = cell(1, length(ps));
  for k=1:length(ps)
    msg{n}.pi_from_parent{k} = ones(ns(ps(k)), 1);
  end

  cs = children(dag, n);
  msg{n}.lambda_from_child = cell(1, length(cs));
  for k=1:length(cs)
    % a lambda message from a child is a vector over n's own states
    msg{n}.lambda_from_child{k} = ones(ns(n), 1);
  end

  msg{n}.lambda = ones(ns(n), 1);
  msg{n}.lambda_from_self = ones(ns(n), 1);
  msg{n}.pi = ones(ns(n), 1);

  % Initialize the lambdas with any evidence
  if observed(n)
    v = evidence{n};
    msg{n}.lambda = zeros(ns(n), 1);
    msg{n}.lambda(v) = 1; % delta function
  end
end
|
wolffd@0
|
187
|
wolffd@0
|
188
|
wolffd@0
|
189 %%%%%%%%
|
wolffd@0
|
190
|
wolffd@0
|
function msg = init_ev_msgs(engine, evidence, msg)
% INIT_EV_MSGS Set msg{n}.lambda_from_self for every observed node copy n of
% the unrolled DBN from that node's CPD evaluated on the evidence.
% msg = init_ev_msgs(engine, evidence, msg)
%
% Slice 1 uses the slice-1 family and equivalence class with evidence(:,1).
% Slice t > 1 uses the slice-2 family/class of the 2TBN with evidence(:,t-1:t).
% lambda_from_self is the .T field of pot_to_marginal(CPD potential).
% NOTE(review): not called from the visible code.

[ss T] = size(evidence);
bnet = bnet_from_engine(engine);
pot_type = 'd';
node_sizes = bnet.node_sizes(:);
cnodes = bnet.cnodes(:);
for t=1:T
  for i=engine.onodes(:)'
    if t == 1
      fam = family(bnet.dag, i);
      e = bnet.equiv_class(i, 1);
      ev = evidence(:,1);
    else
      fam = family(bnet.dag, i, 2); % extract from slice t
      e = bnet.equiv_class(i, 2);
      ev = evidence(:,t-1:t);
    end
    CPDpot = CPD_to_pot(pot_type, bnet.CPD{e}, fam, node_sizes, cnodes, ev);
    temp = pot_to_marginal(CPDpot);
    n = i + (t-1)*ss;
    msg{n}.lambda_from_self = temp.T;
  end
end
|
wolffd@0
|
215
|
wolffd@0
|
216
|
wolffd@0
|
217 %%%%%%%%%%%
|
wolffd@0
|
218
|
wolffd@0
|
function msg = init_ev_msgs2(engine, evidence, msg)
% INIT_EV_MSGS2 Alias of init_ev_msgs, kept under its own name for any
% existing callers; all work is delegated to init_ev_msgs.
% msg = init_ev_msgs2(engine, evidence, msg)
msg = init_ev_msgs(engine, evidence, msg);
|
wolffd@0
|
243
|
wolffd@0
|
244
|