Mercurial > hg > camir-aes2014
comparison toolboxes/FullBNT-1.0.7/bnt/inference/dynamic/@pearl_dbn_inf_engine/Old/correct_smooth.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:e9a9cd732c1e |
---|---|
function [marginal, msg, loglik] = smooth_evidence(engine, evidence)
% [marginal, msg, loglik] = smooth_evidence(engine, evidence) (pearl_dbn)
%
% Unroll the DBN for T slices and run niter forward/backward sweeps of Pearl's
% lambda/pi message passing over the unrolled network.
% This legacy version prints 'warning: broken' every time it is called.
%
% Inputs:
%   engine   - pearl_dbn engine; engine.onodes lists the observed nodes of one slice
%              (all other nodes 1:ss are treated as hidden)
%   evidence - ss x T cell array; evidence{i,t} is the value of node i in slice t
%
% Outputs:
%   marginal - ss x T cell array; marginal{i,t} = normalise(pi .* lambda) of node i, slice t
%   msg      - per-node message structs of the unrolled network after the final sweep
%   loglik   - sum over all unrolled nodes of log(normaliser of pi .* lambda)
%              NOTE(review): this is not obviously the data log-likelihood; confirm before use

disp('warning: broken');

[ss T] = size(evidence);
bnet = bnet_from_engine(engine);
bnet2 = dbn_to_bnet(bnet, T);
ns = bnet2.node_sizes;
onodes = engine.onodes(:)';
hnodes = mysetdiff(1:ss, engine.onodes);
hnodes = hnodes(:)';

[engine.parent_index, engine.child_index] = mk_pearl_msg_indices(bnet2);

msg = init_msgs(bnet2.dag, ns, evidence, bnet2.equiv_class, bnet2.CPD);

verbose = 0;

niter = 1;
for iter=1:niter
  % FORWARD
  for t=1:T
    if verbose, fprintf('t=%d\n', t); end
    % observed leaves send lambda to parents
    for i=onodes
      msg = send_lambda_to_parents(engine, bnet, bnet2, msg, i, t, ss, verbose);
    end

    % update pi of every hidden node in this slice
    for i=hnodes
      n = i + (t-1)*ss;
      ps = parents(bnet2.dag, n);
      e = slice_eclass(bnet, i, t);
      msg{n}.pi = compute_pi(bnet.CPD{e}, n, ps, msg);
      if verbose, fprintf('%d computes pi\n', n); disp(msg{n}.pi); end
    end

    % send pi msg to children (this also feeds the next slice's pi update)
    for i=hnodes
      n = i + (t-1)*ss;
      cs = children(bnet2.dag, n);
      for c=cs(:)'
        j = engine.parent_index{c}(n); % n is c's j'th parent
        pi_msg = normalise(compute_pi_msg(n, cs, msg, c, ns));
        msg{c}.pi_from_parent{j} = pi_msg;
        if verbose, fprintf('%d sends pi to %d\n', n, c); disp(pi_msg); end
      end
    end
  end

  % BACKWARD
  for t=T:-1:1
    if verbose, fprintf('t = %d\n', t); end
    % update lambda of every hidden node from its children's messages
    for i=hnodes
      n = i + (t-1)*ss;
      cs = children(bnet2.dag, n);
      msg{n}.lambda = compute_lambda(n, cs, msg, ns);
      if verbose, fprintf('%d computes lambda\n', n); disp(msg{n}.lambda); end
    end
    % send lambda msgs to parents (this feeds the previous slice's lambda update)
    for i=hnodes
      msg = send_lambda_to_parents(engine, bnet, bnet2, msg, i, t, ss, verbose);
    end
  end
end

marginal = cell(ss,T);
lik = zeros(1,ss*T);
for t=1:T
  for i=1:ss
    n = i + (t-1)*ss;
    [bel, lik(n)] = normalise(msg{n}.pi .* msg{n}.lambda);
    marginal{i,t} = bel;
  end
end

loglik = sum(log(lik));


function msg = send_lambda_to_parents(engine, bnet, bnet2, msg, i, t, ss, verbose)
% Send a normalised lambda message from node i of slice t (unrolled node n) to each
% of n's parents in bnet2, storing it in the parent's lambda_from_child slot for n.
n = i + (t-1)*ss;
ps = parents(bnet2.dag, n);
e = slice_eclass(bnet, i, t);
for p=ps(:)'
  j = engine.child_index{p}(n); % n is p's j'th child
  lam_msg = normalise(compute_lambda_msg(bnet.CPD{e}, n, ps, msg, p));
  msg{p}.lambda_from_child{j} = lam_msg;
  if verbose, fprintf('%d sends lambda to %d\n', n, p); disp(lam_msg); end
end


function e = slice_eclass(bnet, i, t)
% Equivalence class (index into bnet.CPD) of node i in slice t: column 1 of
% bnet.equiv_class for the first slice, column 2 for every later slice.
if t == 1
  e = bnet.equiv_class(i, 1);
else
  e = bnet.equiv_class(i, 2);
end
117 | |
118 | |
119 | |
120 %%%%%%% | |
121 | |
function lambda = compute_lambda(n, cs, msg, ns)
% COMPUTE_LAMBDA Lambda vector of node n (Pearl p183 eq 4.50).
% It is the product of the lambda messages n has received from every child in cs;
% -1 is passed as the excluded child, which matches no node.
lambda = prod_lambda_msgs(n, cs, msg, ns, -1);
125 | |
126 %%%%%%% | |
127 | |
function pi_msg = compute_pi_msg(n, cs, msg, c, ns)
% COMPUTE_PI_MSG Unnormalised pi message from node n to its child c
% (Pearl p183 eqs 4.51 and 4.53): n's pi vector multiplied elementwise by the
% lambda messages n received from all of its children except c.
from_other_children = prod_lambda_msgs(n, cs, msg, ns, c);
pi_msg = msg{n}.pi .* from_other_children;
131 | |
132 %%%%%%%%% | |
133 | |
function lam = prod_lambda_msgs(n, cs, msg, ns, except)
% PROD_LAMBDA_MSGS Elementwise product of the lambda messages node n received from its children.
% lam = prod_lambda_msgs(n, cs, msg, ns)          % include every child in cs
% lam = prod_lambda_msgs(n, cs, msg, ns, except)  % skip the child whose index is 'except'
%
% cs must list n's children in the same order as msg{n}.lambda_from_child, because
% the k'th stored message is attributed to child cs(k). The product starts from
% ones(ns(n),1), so with no (remaining) children the result is an all-ones column.
%
% NOTE(review): msg{n}.lambda_from_self is not folded in. An earlier revision read it
% and immediately overwrote it with ones, so it never contributed to the result.

if nargin < 5, except = -1; end  % -1 matches no node, so nothing is skipped

lam = ones(ns(n), 1);
for k=1:length(cs)
  if cs(k) ~= except
    lam = lam .* msg{n}.lambda_from_child{k};
  end
end
146 | |
147 | |
148 %%%%%%%%% | |
149 | |
function msg = init_msgs(dag, ns, evidence, eclass, CPD)
% INIT_MSGS Initialize the lambda/pi message and state vectors (pearl_dbn)
% msg = init_msgs(dag, ns, evidence, eclass, CPD)
%
% Inputs:
%   dag      - DAG of the (unrolled) network; length(dag) is taken as the number of nodes N
%   ns       - node sizes, indexed by node number
%   evidence - cell array of observations; evidence(:) is indexed by node number and an
%              empty cell marks a hidden node
%   eclass, CPD - accepted for interface compatibility; not used here
%
% Output:
%   msg - 1 x N cell array of structs with fields
%         pi_from_parent{k}    : ones(ns(ps(k)),1) for the k'th parent ps(k)
%         lambda_from_child{k} : ones(ns(n),1) for each child
%         lambda               : ones(ns(n),1), or a delta on the observed value
%         lambda_from_self     : ones(ns(n),1)
%         pi                   : ones(ns(n),1)

N = length(dag);
msg = cell(1,N);
observed = ~isemptycell(evidence(:));

for n=1:N
  % One uninformative pi message per parent, sized by that parent.
  ps = parents(dag, n);
  msg{n}.pi_from_parent = cell(1, length(ps));
  for k=1:length(ps)
    msg{n}.pi_from_parent{k} = ones(ns(ps(k)), 1);
  end

  % One uninformative lambda message per child, all sized by n itself.
  cs = children(dag, n);
  msg{n}.lambda_from_child = repmat({ones(ns(n), 1)}, 1, length(cs));

  msg{n}.lambda = ones(ns(n), 1);
  msg{n}.lambda_from_self = ones(ns(n), 1);
  msg{n}.pi = ones(ns(n), 1);

  % Initialize the lambdas with any evidence: a delta function on the observed value.
  if observed(n)
    v = evidence{n};
    msg{n}.lambda = zeros(ns(n), 1);
    msg{n}.lambda(v) = 1;
  end
end
187 | |
188 | |
189 %%%%%%%% | |
190 | |
function msg = init_ev_msgs(engine, evidence, msg)
% INIT_EV_MSGS Set lambda_from_self of every observed unrolled node from its CPD with evidence.
% msg = init_ev_msgs(engine, evidence, msg)
%
% For each observed node i (engine.onodes) in each slice t of the ss x T evidence
% array, converts the node's CPD into a discrete ('d') potential with the evidence
% entered and stores pot_to_marginal(...).T in msg{i + (t-1)*ss}.lambda_from_self.
% NOTE(review): not called from anywhere in this file.

[ss T] = size(evidence);
bnet = bnet_from_engine(engine);
pot_type = 'd';
node_sizes = bnet.node_sizes(:);
cnodes = bnet.cnodes(:);
onodes = engine.onodes(:)';

% Slice 1: family and CPD come from the first slice; only evidence(:,1) is needed.
for i=onodes
  fam = family(bnet.dag, i);
  e = bnet.equiv_class(i, 1);
  CPDpot = CPD_to_pot(pot_type, bnet.CPD{e}, fam, node_sizes, cnodes, evidence(:,1));
  temp = pot_to_marginal(CPDpot);
  msg{i}.lambda_from_self = temp.T;
end

% Slices 2..T: the family may span slices t-1 and t, so both evidence columns are passed.
for t=2:T
  for i=onodes
    fam = family(bnet.dag, i, 2); % extract from slice t
    e = bnet.equiv_class(i, 2);
    CPDpot = CPD_to_pot(pot_type, bnet.CPD{e}, fam, node_sizes, cnodes, evidence(:,t-1:t));
    temp = pot_to_marginal(CPDpot);
    n = i + (t-1)*ss;
    msg{n}.lambda_from_self = temp.T;
  end
end
215 | |
216 | |
217 %%%%%%%%%%% | |
218 | |
function msg = init_ev_msgs2(engine, evidence, msg)
% INIT_EV_MSGS2 Same computation as init_ev_msgs (it was a verbatim copy of it).
% msg = init_ev_msgs2(engine, evidence, msg)
%
% Delegates so the two entry points cannot drift apart; kept for any caller that
% still uses this name.
msg = init_ev_msgs(engine, evidence, msg);
243 | |
244 |