comparison toolboxes/FullBNT-1.0.7/bnt/inference/dynamic/@pearl_dbn_inf_engine/Old/correct_smooth.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:e9a9cd732c1e
function [marginal, msg, loglik] = smooth_evidence(engine, evidence)
% SMOOTH_EVIDENCE Pearl-style message passing on a DBN unrolled for T slices (pearl_dbn)
% [marginal, msg, loglik] = smooth_evidence(engine, evidence)
%
% evidence is an ss x T cell array; a non-empty cell marks an observed node.
% Node i of slice t is node n = i + (t-1)*ss of the unrolled network.
%
% The DBN is unrolled with dbn_to_bnet and niter sweeps are made. Each sweep is:
%   FORWARD  (t = 1..T): observed nodes send lambda messages to their parents,
%            hidden nodes compute pi, then send pi messages to their children.
%   BACKWARD (t = T..1): hidden nodes compute lambda, then send lambda
%            messages to their parents.
% CPDs come from the 2-slice DBN: equiv_class column 1 for slice 1 and
% column 2 for every later slice.
%
% Outputs:
%   marginal{i,t} - normalise(pi .* lambda) of node i in slice t
%   msg           - final message structs of all ss*T unrolled nodes
%   loglik        - sum(log(.)) of the second output of normalise over all
%                   ss*T beliefs.
%                   NOTE(review): this is not obviously the evidence
%                   log-likelihood; the routine announces itself as broken.

disp('warning: broken');

[ss T] = size(evidence);
bnet = bnet_from_engine(engine);
bnet2 = dbn_to_bnet(bnet, T);
ns = bnet2.node_sizes;
onodes = engine.onodes(:)';
hnodes = mysetdiff(1:ss, engine.onodes);
hnodes = hnodes(:)';

[engine.parent_index, engine.child_index] = mk_pearl_msg_indices(bnet2);

msg = init_msgs(bnet2.dag, ns, evidence, bnet2.equiv_class, bnet2.CPD);

verbose = 0;

niter = 1;
for iter=1:niter
  % FORWARD
  for t=1:T
    if verbose, fprintf('t=%d\n', t); end
    slice_col = 1 + (t > 1);  % equiv_class column for nodes of this slice
    offset = (t-1)*ss;

    % observed leaves send lambda to parents
    msg = send_lambda_to_parents(onodes, t, ss, bnet, bnet2, engine.child_index, msg, verbose);

    % update pi
    for i=hnodes
      n = i + offset;
      ps = parents(bnet2.dag, n);
      e = bnet.equiv_class(i, slice_col);
      msg{n}.pi = compute_pi(bnet.CPD{e}, n, ps, msg);
      if verbose, fprintf('%d computes pi\n', n); disp(msg{n}.pi); end
    end

    % send pi msg to children
    for i=hnodes
      n = i + offset;
      cs = children(bnet2.dag, n);
      for c=cs(:)'
        j = engine.parent_index{c}(n); % n is c's j'th parent
        pi_msg = normalise(compute_pi_msg(n, cs, msg, c, ns));
        msg{c}.pi_from_parent{j} = pi_msg;
        if verbose, fprintf('%d sends pi to %d\n', n, c); disp(pi_msg); end
      end
    end
  end

  % BACKWARD
  for t=T:-1:1
    if verbose, fprintf('t = %d\n', t); end
    offset = (t-1)*ss;

    % update lambda
    for i=hnodes
      n = i + offset;
      cs = children(bnet2.dag, n);
      msg{n}.lambda = compute_lambda(n, cs, msg, ns);
      if verbose, fprintf('%d computes lambda\n', n); disp(msg{n}.lambda); end
    end

    % send lambda msgs to parents
    msg = send_lambda_to_parents(hnodes, t, ss, bnet, bnet2, engine.child_index, msg, verbose);
  end
end

marginal = cell(ss,T);
lik = zeros(1,ss*T);
for t=1:T
  for i=1:ss
    n = i + (t-1)*ss;
    [bel, lik(n)] = normalise(msg{n}.pi .* msg{n}.lambda);
    marginal{i,t} = bel;
  end
end

loglik = sum(log(lik));


function msg = send_lambda_to_parents(nodes, t, ss, bnet, bnet2, child_index, msg, verbose)
% SEND_LAMBDA_TO_PARENTS Each node i in nodes (a row vector), taken in slice t
% of the unrolled network, sends a normalised lambda message to every parent.
% The message is stored in msg{p}.lambda_from_child{j}, where
% j = child_index{p}(n) is the position of n among p's children.
slice_col = 1 + (t > 1);  % equiv_class column: 1 for slice 1, 2 afterwards
for i=nodes
  n = i + (t-1)*ss;
  ps = parents(bnet2.dag, n);
  e = bnet.equiv_class(i, slice_col);
  for p=ps(:)'
    j = child_index{p}(n); % n is p's j'th child
    lam_msg = normalise(compute_lambda_msg(bnet.CPD{e}, n, ps, msg, p));
    msg{p}.lambda_from_child{j} = lam_msg;
    if verbose, fprintf('%d sends lambda to %d\n', n, p); disp(lam_msg); end
  end
end
117
118
119
120 %%%%%%%
121
function lambda = compute_lambda(n, cs, msg, ns)
% COMPUTE_LAMBDA Lambda vector of node n: the product of the lambda messages
% received from all of its children cs (Pearl p183 eq 4.50).
% No child is left out of the product.
no_child_excluded = -1;
lambda = prod_lambda_msgs(n, cs, msg, ns, no_child_excluded);
125
126 %%%%%%%
127
function pi_msg = compute_pi_msg(n, cs, msg, c, ns)
% COMPUTE_PI_MSG Unnormalised pi message from node n to its child c
% (Pearl p183 eq 4.53 and 4.51): n's own pi multiplied by the lambda
% messages from every child of n other than c.
own_pi = msg{n}.pi;
lambda_from_others = prod_lambda_msgs(n, cs, msg, ns, c);
pi_msg = own_pi .* lambda_from_others;
131
132 %%%%%%%%%
133
function lam = prod_lambda_msgs(n, cs, msg, ns, except)
% PROD_LAMBDA_MSGS Elementwise product of the lambda messages node n has
% received from its children, optionally leaving one child out.
% lam = prod_lambda_msgs(n, cs, msg, ns)
% lam = prod_lambda_msgs(n, cs, msg, ns, except)
%
% cs(k) is taken to be the child whose message is stored in
% msg{n}.lambda_from_child{k} (assumed to match the indexing built by
% mk_pearl_msg_indices - TODO confirm). The message from child `except`
% is skipped; the default -1 skips nothing.
% The product starts from ones(ns(n),1), so with no (remaining) children
% the result is all ones. msg{n}.lambda_from_self is not included.

if nargin < 5, except = -1; end

lam = ones(ns(n), 1);
for k=1:length(cs)
  if cs(k) ~= except
    lam = lam .* msg{n}.lambda_from_child{k};
  end
end
146
147
148 %%%%%%%%%
149
function msg = init_msgs(dag, ns, evidence, eclass, CPD)
% INIT_MSGS Initialize the lambda/pi message and state vectors (pearl_dbn)
% msg = init_msgs(dag, ns, evidence, eclass, CPD)
%
% For every node n of dag, msg{n} is a struct with fields
%   pi_from_parent{k}    - ones(ns(p),1) for the k'th parent p of n
%   lambda_from_child{k} - ones(ns(n),1), one per child of n
%   lambda, lambda_from_self, pi - ones(ns(n),1)
% If evidence{n} is non-empty, msg{n}.lambda is replaced by an indicator
% vector that is 1 at the observed value v = evidence{n} and 0 elsewhere.
% lambda_from_self stays all-ones. evidence is indexed linearly, so its
% column-major order must match the node numbering of dag.
% eclass and CPD are accepted but not used.

N = length(dag);
msg = cell(1,N);
observed = ~isemptycell(evidence(:));

for n=1:N
  ps = parents(dag, n);
  msg{n}.pi_from_parent = cell(1, length(ps));
  for k=1:length(ps)
    msg{n}.pi_from_parent{k} = ones(ns(ps(k)), 1);
  end

  % Lambda messages to n live in n's own state space, whichever child sends them.
  cs = children(dag, n);
  msg{n}.lambda_from_child = repmat({ones(ns(n), 1)}, 1, length(cs));

  msg{n}.lambda = ones(ns(n), 1);
  msg{n}.lambda_from_self = ones(ns(n), 1);
  msg{n}.pi = ones(ns(n), 1);

  % Initialize the lambdas with any evidence
  if observed(n)
    v = evidence{n};
    msg{n}.lambda = zeros(ns(n), 1);
    msg{n}.lambda(v) = 1; % delta function
  end
end
187
188
189 %%%%%%%%
190
function msg = init_ev_msgs(engine, evidence, msg)
% INIT_EV_MSGS Set lambda_from_self of every observed node from its CPD
% conditioned on the evidence (pearl_dbn)
% msg = init_ev_msgs(engine, evidence, msg)
%
% For each observed node i and slice t, a 'd' potential is built with
% CPD_to_pot, and pot_to_marginal(...).T is stored in
% msg{n}.lambda_from_self, with n = i + (t-1)*ss:
%   t == 1: family of i in slice 1, equiv_class column 1, evidence(:,1)
%   t > 1 : family of i in slice 2, equiv_class column 2, evidence(:,t-1:t)
% NOTE(review): not called from anywhere in this file.

[ss T] = size(evidence);
bnet = bnet_from_engine(engine);
pot_type = 'd';
ns = bnet.node_sizes(:);
cnodes = bnet.cnodes(:);
for t=1:T
  for i=engine.onodes(:)'
    if t == 1
      fam = family(bnet.dag, i);
      e = bnet.equiv_class(i, 1);
      ev = evidence(:,1);
    else
      fam = family(bnet.dag, i, 2); % extract from slice t
      e = bnet.equiv_class(i, 2);
      ev = evidence(:,t-1:t);
    end
    CPDpot = CPD_to_pot(pot_type, bnet.CPD{e}, fam, ns, cnodes, ev);
    temp = pot_to_marginal(CPDpot);
    n = i + (t-1)*ss;
    msg{n}.lambda_from_self = temp.T;
  end
end
215
216
217 %%%%%%%%%%%
218
function msg = init_ev_msgs2(engine, evidence, msg)
% INIT_EV_MSGS2 Alias of init_ev_msgs (pearl_dbn)
% msg = init_ev_msgs2(engine, evidence, msg)
%
% This computed exactly the same thing as init_ev_msgs with a duplicated body.
% It now delegates so the two cannot drift apart.
msg = init_ev_msgs(engine, evidence, msg);
243
244