comparison toolboxes/FullBNT-1.0.7/bnt/inference/dynamic/@pearl_dbn_inf_engine/Old/smooth_evidence.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:e9a9cd732c1e
function [marginal, msg, loglik] = smooth_evidence(engine, evidence)
% SMOOTH_EVIDENCE Pi/lambda message passing over T slices of a DBN.
% [marginal, msg, loglik] = smooth_evidence(engine, evidence)
%
% Inputs:
%   engine   - inference engine; fields read here: onodes, max_iter and
%              obschild (obschild(i) > 0 is taken to be the index of an
%              observed child of hidden node i).
%   evidence - ss x T array, one column per slice (presumably a cell array
%              as elsewhere in BNT -- TODO confirm).
%
% Outputs:
%   marginal - ss x T cell array; marginal{i,t} = normalise(pi .* lambda)
%              for every hidden node i. Entries for observed nodes stay empty.
%   msg      - struct with fields inter_pi_msg, inter_lambda_msg and
%              intra_lambda_msg (ss x ss x T cell arrays of final messages).
%   loglik   - always 0; the log-likelihood is not computed here.
%
% Each of engine.max_iter iterations performs a forward sweep (update pi,
% then send pi messages to children in slice t+1) followed by a backward
% sweep (update lambda, then send lambda messages to parents in slice t-1).
%
% NOTE(review): slice-1 pi is taken directly from the node's CPT, and later
% slices only combine inter-slice parents, so hidden nodes appear to be
% assumed to have no hidden intra-slice parents -- confirm before reuse.

[ss, T] = size(evidence);
bnet = bnet_from_engine(engine);
onodes = engine.onodes;
hnodes = mysetdiff(1:ss, onodes);
hnodes = hnodes(:)';

% Observed nodes, in both slices of the 2-slice template, get size 1.
ns = bnet.node_sizes(:);
onodes2 = [onodes(:); onodes(:)+ss];
ns(onodes2) = 1;

verbose = 0;
pot_type = 'd';
niter = engine.max_iter;

if verbose, fprintf('new smooth\n'); end

% Message storage (node indices are within-slice, 1..ss):
%   intra_lambda_msg{c,i,t}: from child (c,t)   to parent (i,t), over i's states
%   inter_lambda_msg{c,i,t}: from child (c,t+1) to parent (i,t), over i's states
%   inter_pi_msg{p,i,t}:     from parent (p,t-1) to child (i,t), over p's states
intra_lambda_msg = cell(ss,ss,T);
inter_lambda_msg = cell(ss,ss,T);
inter_pi_msg = cell(ss,ss,T);

lambda = cell(ss,T);
pi_val = cell(ss,T);   % not named "pi" to avoid shadowing the builtin

% Initialise all beliefs and messages to all-ones vectors.
for t=1:T
  for i=1:ss
    lambda{i,t} = ones(ns(i), 1);
    pi_val{i,t} = ones(ns(i), 1);

    cs = children(bnet.intra, i);
    for c=cs(:)'
      intra_lambda_msg{c,i,t} = ones(ns(i),1);
    end

    cs = children(bnet.inter, i);
    for c=cs(:)'
      inter_lambda_msg{c,i,t} = ones(ns(i),1);
    end

    % A pi message from parent p ranges over p's states: it is consumed as
    % dpot(p, ns(p), ...), so it must be sized by ns(p), not ns(i).
    % These initial values are not used for t==1.
    ps = parents(bnet.inter, i);
    for p=ps(:)'
      inter_pi_msg{p,i,t} = ones(ns(p), 1);
    end
  end
end


% Each hidden node absorbs lambda from its observed child (if any): the
% child's CPD evaluated on the evidence, marginalised and normalised.
for t=1:T
  for i=hnodes
    c = engine.obschild(i);
    if c > 0
      if t==1
        fam = family(bnet.dag, c);
        e = bnet.equiv_class(c, 1);
        CPDpot = CPD_to_pot(pot_type, bnet.CPD{e}, fam, bnet.node_sizes(:), bnet.cnodes(:), evidence(:,1));
      else
        fam = family(bnet.dag, c, 2); % within 2 slice network
        e = bnet.equiv_class(c, 2);
        CPDpot = CPD_to_pot(pot_type, bnet.CPD{e}, fam, bnet.node_sizes(:), bnet.cnodes(:), evidence(:,t-1:t));
      end
      temp = pot_to_marginal(CPDpot);
      lam_msg = normalise(temp.T);
      intra_lambda_msg{c,i,t} = lam_msg;
    end
  end
end

for iter=1:niter
  % FORWARD
  for t=1:T
    % Update pi: slice 1 uses the prior CPT directly; later slices combine
    % the transition CPT with the incoming pi messages from slice t-1.
    for i=hnodes
      if t==1
        e = bnet.equiv_class(i,1);
        CPD = struct(bnet.CPD{e});
        pi_val{i,t} = CPD.CPT;
      else
        e = bnet.equiv_class(i,2);
        CPD = struct(bnet.CPD{e});
        ps = parents(bnet.inter, i);
        dom = [ps i+ss];
        pot = dpot(dom, ns(dom), CPD.CPT);
        for p=ps(:)'
          temp = dpot(p, ns(p), inter_pi_msg{p,i,t});
          pot = multiply_by_pot(pot, temp);
        end
        pot = marginalize_pot(pot, i+ss);
        temp = pot_to_marginal(pot);
        pi_val{i,t} = temp.T;
      end
      if verbose, fprintf('%d updates pi\n', i+(t-1)*ss); disp(pi_val{i,t}); end
    end

    % Send pi messages to children in slice t+1. The last slice has no
    % successor, so nothing is sent there (this also keeps inter_pi_msg at
    % its preallocated ss x ss x T size).
    if t < T
      for i=hnodes
        cs = children(bnet.inter, i);
        cs2 = children(bnet.intra, i);
        for c=cs(:)'
          % pi times lambda from every child except the recipient c.
          pot = pi_val{i,t};
          for k=cs(:)'
            if k ~= c
              pot = pot .* inter_lambda_msg{k,i,t};
            end
          end
          for k=cs2(:)'
            pot = pot .* intra_lambda_msg{k,i,t};
          end
          inter_pi_msg{i,c,t+1} = normalise(pot);
          if verbose, fprintf('%d sends pi to %d\n', i+(t-1)*ss, c+t*ss); disp(inter_pi_msg{i,c,t+1}); end
        end
      end
    end
  end

  if verbose, fprintf('backwards\n'); end
  % BACKWARD
  for t=T:-1:1
    % Update lambda as the normalised product of all incoming lambda msgs.
    for i=hnodes
      pot = ones(ns(i), 1);
      cs = children(bnet.inter, i);
      for c=cs(:)'
        pot = pot .* inter_lambda_msg{c,i,t};
      end
      cs = children(bnet.intra, i);
      for c=cs(:)'
        pot = pot .* intra_lambda_msg{c,i,t};
      end
      lambda{i,t} = normalise(pot);
      if verbose, fprintf('%d computes lambda\n', i+(t-1)*ss); disp(lambda{i,t}); end
    end

    % Send lambda messages to parents in the previous slice.
    if t > 1
      for i=hnodes
        ps = parents(bnet.inter, i);
        if isempty(ps), continue; end
        e = bnet.equiv_class(i, 2);
        CPD = struct(bnet.CPD{e});
        fam = [ps i+ss];
        % CPT times this node's lambda; the same for every recipient p.
        base = multiply_by_pot(dpot(fam, ns(fam), CPD.CPT), dpot(i+ss, ns(i), lambda{i,t}));
        for p=ps(:)'
          % Fold in pi messages from all other parents, then sum onto p.
          pot = base;
          for k=ps(:)'
            if k ~= p
              temp = dpot(k, ns(k), inter_pi_msg{k,i,t});
              pot = multiply_by_pot(pot, temp);
            end
          end
          pot = marginalize_pot(pot, p);
          temp = pot_to_marginal(pot);
          inter_lambda_msg{i,p,t-1} = normalise(temp.T);
          if verbose, fprintf('%d sends lambda to %d\n', i+(t-1)*ss, p+(t-2)*ss); disp(inter_lambda_msg{i,p,t-1}); end
        end
      end
    end
  end
end



% Belief of each hidden node = normalised product of pi and lambda.
marginal = cell(ss,T);
for t=1:T
  for i=hnodes
    marginal{i,t} = normalise(pi_val{i,t} .* lambda{i,t});
  end
end

% Placeholder: the log-likelihood is not computed by this routine.
loglik = 0;

msg.inter_pi_msg = inter_pi_msg;
msg.inter_lambda_msg = inter_lambda_msg;
msg.intra_lambda_msg = intra_lambda_msg;
172 for i=hnodes
173 marginal{i,t} = normalise(pi{i,t} .* lambda{i,t});
174 end
175 end
176
177 loglik = 0;
178
179 msg.inter_pi_msg = inter_pi_msg;
180 msg.inter_lambda_msg = inter_lambda_msg;
181 msg.intra_lambda_msg = intra_lambda_msg;