function [fwdback, loglik, fwd_frontier, back_frontier] = enter_soft_evidence(engine, CPD, onodes, pot_type, filter)
% ENTER_SOFT_EVIDENCE Add soft evidence to network (frontier)
% [fwdback, loglik, fwd_frontier, back_frontier] = ...
%     enter_soft_evidence(engine, CPD, onodes, pot_type, filter)
%
% Sweeps a frontier forwards over T slices (and, unless FILTER is set,
% backwards again), adding or removing one node per step in the order given
% by the engine, then combines the forward and backward frontiers.
%
% Inputs:
%   engine   - frontier engine; this function reads engine.fdom1, engine.fdom
%              and engine.ops (via bnet_from_engine for the network itself)
%   CPD      - ss x T cell array; CPD{i,t} is the potential multiplied into
%              the frontier for node i of slice t
%   onodes   - passed through to mk_initial_pot (presumably the observed
%              nodes -- confirm against mk_initial_pot)
%   pot_type - potential type, passed through to mk_initial_pot
%   filter   - if nonzero, only the forwards pass is run (default 0)
%
% Outputs:
%   fwdback       - ss x T cell. With FILTER set: the forward frontier
%                   captured right after each node was added. Otherwise: the
%                   normalized product of that forward frontier and the
%                   backward frontier captured right before the node was
%                   removed on the backwards pass.
%   loglik        - sum of the per-slice normalization terms returned by
%                   normalize_pot on the forwards pass
%   fwd_frontier  - S x T cell of forward frontiers, S = 2*ss
%   back_frontier - S x (T+1) cell of backward frontiers ({} if FILTER set)

% filter is the 5th argument, so it must be defaulted whenever fewer than 5
% arguments are given (the short-circuit || keeps isempty from touching an
% undefined variable).
if nargin < 5 || isempty(filter), filter = 0; end

[ss, T] = size(CPD);
bnet = bnet_from_engine(engine);
ns = repmat(bnet.node_sizes_slice(:), 1, T);
cnodes = unroll_set(bnet.cnodes(:), ss, T);

% FORWARDS
fwd = cell(ss,T);
ll = zeros(1,T);
S = 2*ss; % num. intermediate frontiers to get from t to t+1
frontier = cell(S,T);

% Start with empty frontier, and add each node in slice 1
prev = mk_initial_pot(pot_type, [], ns, cnodes, onodes);
t = 1;
for s=1:ss
  j = s; % add node j at step s
  frontier{s,t} = update(prev, j, 1, CPD{j}, engine.fdom1{s}, pot_type, ns, cnodes, onodes);
  fwd{j} = frontier{s,t};
  prev = frontier{s,t};
end
% Slice 1 only needs ss steps; copy the result into row S so that slice 2
% can uniformly start from frontier{S,t-1}.
frontier{S,t} = frontier{ss,t};
[frontier{S,t}, ll(1)] = normalize_pot(frontier{S,t});

% Now move frontier from slice to slice.
% engine.ops holds signed node indices: positive entries are added to the
% frontier, negative entries are removed (marginalized out).
OPS = engine.ops;
add = OPS>0;
nodes = [zeros(S,1) unroll_set(abs(OPS(:)), ss, T-1)];
for t=2:T
  offset = (t-2)*ss;
  for s=1:S
    % Linear indices into the S x T cell array 'frontier'.
    if s==1
      prev_ndx = (t-2)*S + S; % S,t-1
    else
      prev_ndx = (t-1)*S + s-1; % s-1,t
    end
    j = nodes(s,t);
    frontier{s,t} = update(frontier{prev_ndx}, j, add(s), CPD{j}, engine.fdom{s}+offset, pot_type, ns, cnodes, onodes);
    if add(s)
      fwd{j} = frontier{s,t};
    end
  end
  [frontier{S,t}, ll(t)] = normalize_pot(frontier{S,t});
end
loglik = sum(ll);

fwd_frontier = frontier;

if filter
  fwdback = fwd;
  back_frontier = {}; % no backwards pass was run
  return;
end

% BACKWARDS
back = cell(ss,T);
add = ~add; % forwards add = backwards remove
frontier = cell(S,T+1);
t = T;
dom = (1:ss) + (t-1)*ss;
frontier{1,T+1} = mk_initial_pot(pot_type, dom, ns, cnodes, onodes); % all 1s for last slice
for t=T:-1:2
  offset = (t-2)*ss;
  for s=S:-1:1 % reverse order
    % Linear indices into the S x (T+1) cell array 'frontier'.
    if s==S
      prev_ndx = t*S + 1; % 1,t+1
    else
      prev_ndx = (t-1)*S + (s+1); % s+1,t
    end
    j = nodes(s,t);
    if ~add(s)
      back{j} = frontier{prev_ndx}; % save frontier before removing
    end
    frontier{s,t} = rev_update(frontier{prev_ndx}, t, s, j, add(s), CPD{j}, engine.fdom{s}+offset, pot_type, ns, cnodes, onodes);
  end
  frontier{1,t} = normalize_pot(frontier{1,t});
end
% Remove each node in first slice until left with empty set
t = 1;
frontier{ss+1,t} = frontier{1,2};
for s=ss:-1:1
  j = s; % remove node j at step s
  back{j} = frontier{s+1,t};
  % 0 = remove; the remaining domain before removing node s is 1:s.
  frontier{s,t} = rev_update(frontier{s+1,t}, t, s, j, 0, CPD{j}, 1:s, pot_type, ns, cnodes, onodes);
end

% COMBINE
fwdback = cell(ss,T);
for t=1:T
  for i=1:ss
    fwdback{i,t} = normalize_pot(multiply_pots(fwd{i,t}, back{i,t}));
  end
end

back_frontier = frontier;

%%%%%%%%%%
function new_frontier = update(old_frontier, j, add, CPD, newdom, pot_type, ns, cnodes, onodes)
% UPDATE One forwards frontier step.
% If ADD is true, build a fresh potential over NEWDOM and multiply in both
% OLD_FRONTIER and the potential CPD (so node j enters the frontier).
% Otherwise marginalize node j out of OLD_FRONTIER; CPD and NEWDOM are unused.

if add
  new_frontier = mk_initial_pot(pot_type, newdom, ns, cnodes, onodes);
  new_frontier = multiply_by_pot(new_frontier, old_frontier);
  new_frontier = multiply_by_pot(new_frontier, CPD);
else
  new_frontier = marginalize_pot(old_frontier, mysetdiff(domain_pot(old_frontier), j));
end


%%%%%%
function new_frontier = rev_update(old_frontier, t, s, j, add, CPD, junk, pot_type, ns, cnodes, onodes)
% REV_UPDATE One backwards frontier step (the reverse of UPDATE).
% JUNK is the domain OLD_FRONTIER is expected to have; it is only used as a
% consistency check. T and S are only used by the commented-out debug output.
% If ADD is true, extend the domain with node j by multiplying by a fresh
% potential from mk_initial_pot. Otherwise multiply in the potential CPD and
% marginalize node j out.

olddom = domain_pot(old_frontier);
assert(isequal(junk, olddom));

if add
  % add: extend domain to include j by multiplying by 1
  newdom = myunion(olddom, j);
  new_frontier = mk_initial_pot(pot_type, newdom, ns, cnodes, onodes);
  new_frontier = multiply_by_pot(new_frontier, old_frontier);
  %fprintf('t=%d, s=%d, add %d to %s to make %s\n', t, s, j, num2str(olddom), num2str(newdom));
else
  % remove: multiply in CPT and then marginalize out j
  % parents of j are guaranteed to be in old_frontier, else couldn't have added j on fwds pass
  old_frontier = multiply_by_pot(old_frontier, CPD);
  newdom = mysetdiff(olddom, j);
  new_frontier = marginalize_pot(old_frontier, newdom);
  %newdom2 = domain_pot(new_frontier);
  %fprintf('t=%d, s=%d, rem %d from %s to make %s\n', t, s, j, num2str(olddom), num2str(newdom2));
end