comparison toolboxes/FullBNT-1.0.7/bnt/inference/dynamic/@pearl_dbn_inf_engine/Old/filter_evidence_obj_oriented.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:e9a9cd732c1e
1 function [marginal, msg, loglik] = filter_evidence_old(engine, evidence)
2 % [marginal, msg, loglik] = filter_evidence(engine, evidence) (pearl_dbn)
3
4 [ss T] = size(evidence);
5 bnet = bnet_from_engine(engine);
6 bnet2 = dbn_to_bnet(bnet, T);
7 ns = bnet2.node_sizes;
8 hnodes = mysetdiff(1:ss, engine.onodes);
9 hnodes = hnodes(:)';
10
11 [engine.parent_index, engine.child_index] = mk_pearl_msg_indices(bnet2);
12
13 msg = init_msgs(bnet2.dag, ns, evidence);
14 msg = init_ev_msgs(engine, evidence, msg);
15
16 verbose = 1;
17 if verbose, fprintf('\nold filtering\n'); end
18
19 for t=1:T
20 % update pi
21 for i=hnodes
22 n = i + (t-1)*ss;
23 ps = parents(bnet2.dag, n);
24 if t==1
25 e = bnet.equiv_class(i,1);
26 else
27 e = bnet.equiv_class(i,2);
28 end
29 msg{n}.pi = compute_pi(bnet.CPD{e}, n, ps, msg);
30 %if verbose, fprintf('%d computes pi\n', n); disp(msg{n}.pi); end
31 msg{n}.pi = normalise(msg{n}.pi(:) .* msg{n}.lambda_from_self(:));
32 if verbose, fprintf('%d recomputes pi\n', n); disp(msg{n}.pi); end
33 end
34 % send pi msg to children
35 for i=hnodes
36 n = i + (t-1)*ss;
37 cs = children(bnet2.dag, n);
38 for c=cs(:)'
39 j = engine.parent_index{c}(n); % n is c's j'th parent
40 pi_msg = normalise(compute_pi_msg(n, cs, msg, c, ns));
41 msg{c}.pi_from_parent{j} = pi_msg;
42 if verbose, fprintf('%d sends pi to %d\n', n,c); disp(pi_msg); end
43 end
44 end
45 end
46
47
48 marginal = cell(ss,T);
49 lik = zeros(1,ss*T);
50 for t=1:T
51 for i=1:ss
52 n = i + (t-1)*ss;
53 %[bel, lik(n)] = normalise(msg{n}.pi .* msg{n}.lambda);
54 [bel, lik(n)] = normalise(msg{n}.pi);
55 marginal{i,t} = bel;
56 end
57 end
58
59 loglik = sum(log(lik));
60
61
62
63 %%%%%%%
64
65 function lambda = compute_lambda(n, cs, msg, ns)
66 % Pearl p183 eq 4.50
67 lambda = prod_lambda_msgs(n, cs, msg, ns);
68
69 %%%%%%%
70
71 function pi_msg = compute_pi_msg(n, cs, msg, c, ns)
72 % Pearl p183 eq 4.53 and 4.51
73 pi_msg = msg{n}.pi .* prod_lambda_msgs(n, cs, msg, ns, c);
74
75 %%%%%%%%%
76
77 function lam = prod_lambda_msgs(n, cs, msg, ns, except)
78
79 if nargin < 5, except = -1; end
80
81 %lam = msg{n}.lambda_from_self(:);
82 lam = ones(ns(n), 1);
83 for i=1:length(cs)
84 c = cs(i);
85 if c ~= except
86 lam = lam .* msg{n}.lambda_from_child{i};
87 end
88 end
89
90
91 %%%%%%%%%%%
92
93 function msg = init_msgs(dag, ns, evidence)
94 % INIT_MSGS Initialize the lambda/pi message and state vectors (pearl_dbn)
95 % msg = init_msgs(dag, ns, evidence)
96 %
97 % We assume all the hidden nodes are discrete.
98
99 N = length(dag);
100 msg = cell(1,N);
101 observed = ~isemptycell(evidence(:));
102
103 for n=1:N
104 ps = parents(dag, n);
105 msg{n}.pi_from_parent = cell(1, length(ps));
106 for i=1:length(ps)
107 p = ps(i);
108 msg{n}.pi_from_parent{i} = ones(ns(p), 1);
109 end
110
111 cs = children(dag, n);
112 msg{n}.lambda_from_child = cell(1, length(cs));
113 for i=1:length(cs)
114 c = cs(i);
115 msg{n}.lambda_from_child{i} = ones(ns(n), 1);
116 end
117
118 msg{n}.lambda = ones(ns(n), 1);
119 msg{n}.pi = ones(ns(n), 1);
120
121 msg{n}.lambda_from_self = ones(ns(n), 1);
122 end
123
124
125 %%%%%%%%%
126
127 function msg = init_ev_msgs(engine, evidence, msg)
128 % Initialize the lambdas with any evidence
129
130 [ss T] = size(evidence);
131 bnet = bnet_from_engine(engine);
132 pot_type = 'd';
133 t = 1;
134 hnodes = mysetdiff(1:ss, engine.onodes);
135 for i=hnodes(:)'
136 c = engine.obschild(i);
137 if c > 0
138 fam = family(bnet.dag, c);
139 e = bnet.equiv_class(c, 1);
140 CPDpot = CPD_to_pot(pot_type, bnet.CPD{e}, fam, bnet.node_sizes(:), bnet.cnodes(:), evidence(:,1));
141 temp = pot_to_marginal(CPDpot);
142 n = i;
143 msg{n}.lambda_from_self = temp.T;
144 end
145 end
146 for t=2:T
147 for i=hnodes(:)'
148 c = engine.obschild(i);
149 if c > 0
150 fam = family(bnet.dag, c, 2);
151 e = bnet.equiv_class(c, 2);
152 CPDpot = CPD_to_pot(pot_type, bnet.CPD{e}, fam, bnet.node_sizes(:), bnet.cnodes(:), evidence(:,t-1:t));
153 temp = pot_to_marginal(CPDpot);
154 n = i + (t-1)*ss;
155 msg{n}.lambda_from_self = temp.T;
156 end
157 end
158 end