Mercurial > hg > camir-aes2014
comparison toolboxes/FullBNT-1.0.7/bnt/inference/static/@pearl_inf_engine/marginal_family.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:e9a9cd732c1e |
---|---|
function m = marginal_family(engine, n, add_ev)
% MARGINAL_FAMILY Compute the marginal on node n's family (loopy)
% m = marginal_family(engine, n, add_ev)
%
% The family domain is dom = [parents(bnet.dag, n) n]. The marginal is built
% by multiplying n's CPD with the pi messages n received from its parents
% and n's lambda message (all read from engine.msg{n}).
%
% Inputs:
%   engine - a pearl_inf_engine whose messages (engine.msg) and evidence
%            (engine.evidence) have already been computed/entered
%   n      - index of the node whose family marginal is requested
%   add_ev - optional, default 0.
%            Discrete ('d') messages: if 0, observed dimensions of m.T are
%            shrunk with shrink_obs_dims_in_table; if 1 they are kept.
%            Gaussian ('g') messages: if 1, the result is passed through
%            add_evidence_to_gmarginal (or add_ev_to_dmarginal for a
%            disconnected node).
%
% Output:
%   m      - marginal structure as returned by pot_to_marginal; for a
%            disconnected Gaussian node, a struct with m.T = 1 and
%            m.domain = dom.

if nargin < 3, add_ev = 0; end

bnet = bnet_from_engine(engine);
ns = bnet.node_sizes;
ps = parents(bnet.dag, n);
dom = [ps n];
CPD = bnet.CPD{bnet.equiv_class(n)};

switch engine.msg_type
 case 'd',
  % The method is similar to the following HMM equation:
  % xi(i,j,t) = normalise( alpha(i,t) * transmat(i,j) * obsmat(j,t+1) * beta(j,t+1) )
  % where xi(i,j,t) = Pr(Q(t)=i, Q(t+1)=j | y(1:T))
  % beta == lambda, alpha == pi, alpha from each parent = pi msg
  % In general, if A,B are parents of C,
  % P(A,B,C) = P(C|A,B) pi_msg(A->C) pi_msg(B->C) lambda(C)
  % where lambda(C) = P(ev below and including C|C) = prod incoming lambda_msg(children->C)
  % and pi_msg(X->C) = P(X|ev above) etc

  T = dpot(dom, ns(dom), CPD_to_CPT(CPD));
  for j=1:length(ps)
    p = ps(j);
    pi_msg = dpot(p, ns(p), engine.msg{n}.pi_from_parent{j});
    T = multiply_by_pot(T, pi_msg);
  end
  lambda = dpot(n, ns(n), engine.msg{n}.lambda);
  T = multiply_by_pot(T, lambda);
  T = normalize_pot(T);
  m = pot_to_marginal(T);
  if ~add_ev
    m.T = shrink_obs_dims_in_table(m.T, dom, engine.evidence);
  end

 case 'g',
  if engine.disconnected_nodes_bitv(n)
    m.T = 1;
    m.domain = dom;
    if add_ev
      % Terminated with ';' so the result is not echoed to the console.
      m = add_ev_to_dmarginal(m, engine.evidence, ns);
    end
    return;
  end

  % Keep the CPD parameters in their own locals rather than reusing the
  % output variable m, which is only assigned at the end of this branch.
  [cpd_mu, cpd_C, cpd_W] = gaussian_CPD_params_given_dps(CPD, dom, engine.evidence);
  cdom = myintersect(dom, bnet.cnodes);
  pot = linear_gaussian_to_cpot(cpd_mu, cpd_C, cpd_W, dom, ns, cdom, engine.evidence);
  % linear_gaussian_to_cpot will set the effective size of observed nodes to 0,
  % so we need to do this explicitly for the messages, too,
  % so they are all the same size.
  obs_bitv = ~isemptycell(engine.evidence);
  % Pi messages are indexed by n's parents in the message graph, which is
  % looked up separately from the family parents ps used to build dom.
  msg_ps = parents(engine.msg_dag, n);
  for j=1:length(msg_ps)
    p = msg_ps(j);
    msg = engine.msg{n}.pi_from_parent{j};
    if obs_bitv(p)
      pi_msg = mpot(p, 0);
    else
      pi_msg = mpot(p, ns(p), 0, msg.mu, msg.Sigma);
    end
    pot = multiply_by_pot(pot, mpot_to_cpot(pi_msg));
  end
  msg = engine.msg{n}.lambda;
  if obs_bitv(n)
    lambda = cpot(n, 0);
  else
    lambda = cpot(n, ns(n), 0, msg.info_state, msg.precision);
  end
  pot = multiply_by_pot(pot, lambda);
  m = pot_to_marginal(pot);
  if add_ev
    m = add_evidence_to_gmarginal(m, engine.evidence, bnet.node_sizes, bnet.cnodes);
  end

 otherwise
  % Without this branch the output m would be left unassigned.
  error(['marginal_family: unrecognized msg_type ''' engine.msg_type '''']);
end