annotate toolboxes/FullBNT-1.0.7/bnt/inference/dynamic/@pearl_dbn_inf_engine/Old/correct_smooth.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
rev   line source
function [marginal, msg, loglik] = smooth_evidence(engine, evidence)
% SMOOTH_EVIDENCE One forwards/backwards sweep of Pearl message passing on an unrolled DBN (pearl_dbn)
% [marginal, msg, loglik] = smooth_evidence(engine, evidence)
%
% Inputs:
%   engine   - inference engine; bnet_from_engine(engine) must return the DBN and
%              engine.onodes lists the observed nodes within a slice.
%   evidence - ss-by-T cell array (ss = nodes per slice, T = number of slices).
%              Non-empty cells are treated as observed (see init_msgs).
%
% Outputs:
%   marginal - ss-by-T cell array; marginal{i,t} = normalise(pi .* lambda) for node i in slice t.
%   msg      - 1-by-(ss*T) cell array holding the final message state of every unrolled node.
%   loglik   - sum over all unrolled nodes of the log of the normalising constant
%              returned by normalise for that node's belief.
%
% NOTE: the original author marked this routine as broken; it still announces
% that on every call. The algorithm is unchanged here.

disp('warning: broken');

[ss T] = size(evidence);
bnet = bnet_from_engine(engine);
bnet2 = dbn_to_bnet(bnet, T);   % DBN unrolled for T slices
ns = bnet2.node_sizes;
hnodes = mysetdiff(1:ss, engine.onodes);
hnodes = hnodes(:)';
onodes = engine.onodes(:)';

% The engine is not returned, so the indices are kept in plain locals.
[parent_index, child_index] = mk_pearl_msg_indices(bnet2);

msg = init_msgs(bnet2.dag, ns, evidence, bnet2.equiv_class, bnet2.CPD);

verbose = 0;

niter = 1;
for iter=1:niter
  % FORWARD
  for t=1:T
    if verbose, fprintf('t=%d\n', t); end

    % observed leaves send lambda to parents
    msg = send_lambda_msgs(bnet, bnet2.dag, child_index, msg, onodes, t, ss, verbose);

    % update pi
    for i=hnodes
      n = i + (t-1)*ss;
      ps = parents(bnet2.dag, n);
      e = bnet.equiv_class(i, min(t,2));  % column 1 for slice 1, column 2 afterwards
      msg{n}.pi = compute_pi(bnet.CPD{e}, n, ps, msg);
      if verbose, fprintf('%d computes pi\n', n); disp(msg{n}.pi); end
    end

    % send pi msg to children
    for i=hnodes
      n = i + (t-1)*ss;
      cs = children(bnet2.dag, n);
      for c=cs(:)'
        j = parent_index{c}(n); % n is c's j'th parent
        pi_msg = normalise(compute_pi_msg(n, cs, msg, c, ns));
        msg{c}.pi_from_parent{j} = pi_msg;
        if verbose, fprintf('%d sends pi to %d\n', n, c); disp(pi_msg); end
      end
    end
  end

  % BACKWARD
  for t=T:-1:1
    if verbose, fprintf('t = %d\n', t); end

    % update lambda
    for i=hnodes
      n = i + (t-1)*ss;
      cs = children(bnet2.dag, n);
      msg{n}.lambda = compute_lambda(n, cs, msg, ns);
      if verbose, fprintf('%d computes lambda\n', n); disp(msg{n}.lambda); end
    end

    % send lambda msgs to parents
    msg = send_lambda_msgs(bnet, bnet2.dag, child_index, msg, hnodes, t, ss, verbose);
  end
end

% Beliefs: n = i + (t-1)*ss is also the linear index of cell (i,t) in an ss-by-T array.
marginal = cell(ss,T);
lik = zeros(1,ss*T);
for n=1:ss*T
  [marginal{n}, lik(n)] = normalise(msg{n}.pi .* msg{n}.lambda);
end

loglik = sum(log(lik));

%%%%%%%

function msg = send_lambda_msgs(bnet, dag, child_index, msg, nodes, t, ss, verbose)
% SEND_LAMBDA_MSGS Every node in NODES (slice-local ids), taken in slice t, sends a normalised lambda message to each parent.
% Messages are written into msg one at a time, so later messages see earlier updates.

for i=nodes
  n = i + (t-1)*ss;
  ps = parents(dag, n);
  e = bnet.equiv_class(i, min(t,2));  % column 1 for slice 1, column 2 afterwards
  for p=ps(:)'
    j = child_index{p}(n); % n is p's j'th child
    lam_msg = normalise(compute_lambda_msg(bnet.CPD{e}, n, ps, msg, p));
    msg{p}.lambda_from_child{j} = lam_msg;
    if verbose, fprintf('%d sends lambda to %d\n', n, p); disp(lam_msg); end
  end
end
wolffd@0 117
wolffd@0 118
wolffd@0 119
wolffd@0 120 %%%%%%%
wolffd@0 121
function lambda = compute_lambda(n, cs, msg, ns)
% COMPUTE_LAMBDA Lambda vector of node n: the product of the lambda messages from all its children.
% Pearl p183 eq 4.50
%
% n   - node index into msg / ns
% cs  - children of n, in the same order as msg{n}.lambda_from_child
% msg - cell array of per-node message structs
% ns  - node sizes

product_over_children = prod_lambda_msgs(n, cs, msg, ns);
lambda = product_over_children;
wolffd@0 125
wolffd@0 126 %%%%%%%
wolffd@0 127
function pi_msg = compute_pi_msg(n, cs, msg, c, ns)
% COMPUTE_PI_MSG Unnormalised pi message from node n to its child c.
% Pearl p183 eq 4.53 and 4.51
%
% The message is n's pi vector times the lambda messages from every child of n
% other than c.

lambda_from_siblings = prod_lambda_msgs(n, cs, msg, ns, c);
pi_msg = msg{n}.pi .* lambda_from_siblings;
wolffd@0 131
wolffd@0 132 %%%%%%%%%
wolffd@0 133
function lam = prod_lambda_msgs(n, cs, msg, ns, except)
% PROD_LAMBDA_MSGS Elementwise product of the lambda messages node n has received from its children.
% lam = prod_lambda_msgs(n, cs, msg, ns, except)
%
% n      - node index into msg / ns
% cs     - children of n; cs(i) is the sender of msg{n}.lambda_from_child{i}
% msg    - cell array of per-node message structs
% ns     - node sizes; the result is ns(n)-by-1
% except - optional child whose message is left out of the product (default: none)
%
% The product starts from a vector of ones. msg{n}.lambda_from_self is NOT
% included: the previous code loaded it and then overwrote it with ones on the
% very next line, so it never contributed. That dead store is removed here.

if nargin < 5, except = -1; end

lam = ones(ns(n), 1);
for i=1:length(cs)
  if cs(i) ~= except
    lam = lam .* msg{n}.lambda_from_child{i};
  end
end
wolffd@0 146
wolffd@0 147
wolffd@0 148 %%%%%%%%%
wolffd@0 149
function msg = init_msgs(dag, ns, evidence, eclass, CPD)
% INIT_MSGS Build the initial lambda/pi message state for every node (pearl_dbn)
% msg = init_msgs(dag, ns, evidence, eclass, CPD)
%
% dag      - adjacency matrix of the (unrolled) network
% ns       - node sizes
% evidence - cell array with one entry per node when read in linear order;
%            a non-empty entry marks the node as observed and is used as an
%            index into that node's lambda vector
% eclass, CPD - accepted for the caller's convenience; not used here
%
% msg{n} gets the fields pi_from_parent, lambda_from_child, lambda,
% lambda_from_self and pi, all filled with ones, except that an observed
% node's lambda is a delta function on its observed value.

num_nodes = length(dag);
msg = cell(1, num_nodes);
has_evidence = ~isemptycell(evidence(:));

for n=1:num_nodes
  % one all-ones pi message per parent, sized by the parent
  pa = parents(dag, n);
  from_parents = cell(1, length(pa));
  for k=1:length(pa)
    from_parents{k} = ones(ns(pa(k)), 1);
  end
  msg{n}.pi_from_parent = from_parents;

  % one all-ones lambda message per child, sized by this node
  ch = children(dag, n);
  from_children = cell(1, length(ch));
  for k=1:length(ch)
    from_children{k} = ones(ns(n), 1);
  end
  msg{n}.lambda_from_child = from_children;

  % evidence goes straight into lambda; lambda_from_self stays uniform
  if has_evidence(n)
    lam = zeros(ns(n), 1);
    lam(evidence{n}) = 1; % delta function
  else
    lam = ones(ns(n), 1);
  end
  msg{n}.lambda = lam;
  msg{n}.lambda_from_self = ones(ns(n), 1);
  msg{n}.pi = ones(ns(n), 1);
end
wolffd@0 187
wolffd@0 188
wolffd@0 189 %%%%%%%%
wolffd@0 190
function msg = init_ev_msgs(engine, evidence, msg)
% INIT_EV_MSGS Set lambda_from_self of every observed node from its CPD and the evidence (pearl_dbn)
% msg = init_ev_msgs(engine, evidence, msg)
%
% engine   - inference engine; supplies the DBN via bnet_from_engine and the
%            observed nodes via engine.onodes
% evidence - ss-by-T cell array of evidence (ss nodes per slice, T slices)
% msg      - message state indexed by unrolled node number n = i + (t-1)*ss
%
% For each observed node the CPD is converted to a discrete potential on the
% node's family given the evidence, and the table of its marginal is stored in
% msg{n}.lambda_from_self. Slice 1 uses the slice-1 family and equivalence
% class with evidence(:,1); later slices use the slice-2 family and class of
% the two-slice network with evidence(:,t-1:t).

[ss T] = size(evidence);
bnet = bnet_from_engine(engine);
pot_type = 'd';

for t=1:T
  for i=engine.onodes(:)'
    if t == 1
      fam = family(bnet.dag, i);
      e = bnet.equiv_class(i, 1);
      ev = evidence(:,1);
    else
      fam = family(bnet.dag, i, 2); % extract from slice t
      e = bnet.equiv_class(i, 2);
      ev = evidence(:,t-1:t);
    end
    CPDpot = CPD_to_pot(pot_type, bnet.CPD{e}, fam, bnet.node_sizes(:), bnet.cnodes(:), ev);
    temp = pot_to_marginal(CPDpot);
    n = i + (t-1)*ss;
    msg{n}.lambda_from_self = temp.T;
  end
end
wolffd@0 215
wolffd@0 216
wolffd@0 217 %%%%%%%%%%%
wolffd@0 218
function msg = init_ev_msgs2(engine, evidence, msg)
% INIT_EV_MSGS2 Same contract as init_ev_msgs (pearl_dbn)
% msg = init_ev_msgs2(engine, evidence, msg)
%
% This function was a line-for-line copy of init_ev_msgs, so it now simply
% delegates to it: for every observed node in every slice, lambda_from_self is
% set from the node's CPD and the evidence.

msg = init_ev_msgs(engine, evidence, msg);
wolffd@0 243
wolffd@0 244