To check out this repository please hg clone the following URL, or open the URL using EasyMercurial or your preferred Mercurial client.
root / _FullBNT / BNT / general / convert_dbn_CPDs_to_tables1.m @ 8:b5b38998ef3b
History | View | Annotate | Download (5.36 KB)
| 1 |
function CPDpot = convert_dbn_CPDs_to_tables1(bnet, evidence) |
|---|---|
| 2 |
% CONVERT_DBN_CPDS_TO_TABLES Convert CPDs of (possibly instantiated) DBN nodes to tables |
| 3 |
% CPDpot = convert_dbn_CPDs_to_tables(bnet, evidence) |
| 4 |
% |
| 5 |
% CPDpot{n,t} is a table containing P(n,t|pa(n,t), ev)
|
| 6 |
% All hidden nodes are assumed to be discrete |
| 7 |
% We assume the observed nodes are the same in every slice |
| 8 |
% |
| 9 |
% Evaluating the conditional likelihood of the evidence can be very slow, |
| 10 |
% so we take pains to vectorize where possible, i.e., we try to avoid |
| 11 |
% calling convert_to_table |
| 12 |
|
| 13 |
[ss T] = size(evidence); |
| 14 |
%obs_bitv = ~isemptycell(evidence(:)); |
| 15 |
obs_bitv = zeros(1, 2*ss); |
| 16 |
obs_bitv(bnet.observed) = 1; |
| 17 |
obs_bitv(bnet.observed+ss) = 1; |
| 18 |
|
| 19 |
ns = bnet.node_sizes(:); |
| 20 |
CPDpot = cell(ss,T); |
| 21 |
|
| 22 |
for n=1:ss |
| 23 |
% slice 1 |
| 24 |
t = 1; |
| 25 |
ps = parents(bnet.dag, n); |
| 26 |
e = bnet.equiv_class(n, 1); |
| 27 |
if ~any(obs_bitv(ps)) |
| 28 |
CPDpot{n,t} = convert_CPD_to_table_hidden_ps(bnet.CPD{e}, evidence{n,t});
|
| 29 |
else |
| 30 |
CPDpot{n,t} = convert_to_table(bnet.CPD{e}, [ps n], evidence(:,1));
|
| 31 |
end |
| 32 |
|
| 33 |
% slices 2..T |
| 34 |
debug = 1; |
| 35 |
if ~obs_bitv(n) |
| 36 |
CPDpot = helper_hidden_child(bnet, evidence, n, CPDpot, obs_bitv, debug); |
| 37 |
else |
| 38 |
CPDpot = helper_obs_child(bnet, evidence, n, CPDpot, obs_bitv, debug); |
| 39 |
end |
| 40 |
end |
| 41 |
|
| 42 |
if 0 |
| 43 |
CPDpot2 = convert_dbn_CPDs_to_tables_slow(bnet, evidence); |
| 44 |
for t=1:T |
| 45 |
for n=1:ss |
| 46 |
if ~approxeq(CPDpot{n,t}, CPDpot2{n,t})
|
| 47 |
fprintf('CPDpot n=%d, t=%d\n',n,t);
|
| 48 |
keyboard |
| 49 |
end |
| 50 |
end |
| 51 |
end |
| 52 |
end |
| 53 |
|
| 54 |
|
| 55 |
% special cases: c=child, p=parents, d=discrete, h=hidden, 1=1slice |
| 56 |
% if c=h=1 then c=d=1, since hidden nodes must be discrete |
| 57 |
% c=h c=d p=h p=d p=1 method |
| 58 |
% --------------------------- |
| 59 |
% 1 1 1 1 - replicate CPT |
| 60 |
% 0 1 1 1 1 dhmm |
| 61 |
% 0 0 1 1 1 ghmm |
| 62 |
% - 1 - 1 - evaluate CPT on evidence |
| 63 |
% other loop |
| 64 |
|
| 65 |
%%%%%%% |
| 66 |
function CPDpot = helper_hidden_child(bnet, evidence, n, CPDpot, obs_bitv, debug) |
| 67 |
|
| 68 |
[ss T] = size(evidence); |
| 69 |
self = n+ss; |
| 70 |
ps = parents(bnet.dag, self); |
| 71 |
e = bnet.equiv_class(n, 2); |
| 72 |
ns = bnet.node_sizes(:); |
| 73 |
if ~any(obs_bitv(ps)) % all parents are hidden (hence discrete) |
| 74 |
if debug, fprintf('node %d is hidden, all ps are hidden\n', n); end
|
| 75 |
if myismember(n, bnet.dnodes) |
| 76 |
%CPT = CPD_to_CPT(bnet.CPD{e});
|
| 77 |
%CPT = reshape(CPT, [prod(ns(ps)) ns(self)]); |
| 78 |
CPT = convert_CPD_to_table_hidden_ps(bnet.CPD{e}, []);
|
| 79 |
CPDpot(n,2:T) = num2cell(repmat(CPT, [1 1 T-1]), [1 2]); |
| 80 |
else |
| 81 |
error(['hidden cts node disallowed']) |
| 82 |
end |
| 83 |
else % some parents are observed - slow |
| 84 |
if mysubset(ps, bnet.dnodes) % all parents are discrete |
| 85 |
% given CPT(p1, p2, p3, p4, c), where p1,p3 are observed |
| 86 |
% we create CPT([p2 p4 c], [p1 p3]). |
| 87 |
% We then convert all observed p1,p3 into indices ndx |
| 88 |
% and return CPT(:, ndx) |
| 89 |
CPT = CPD_to_CPT(bnet.CPD{e});
|
| 90 |
domain = [ps self]; |
| 91 |
% if dom is [3 7 8] and 3,8 are observed, odom_rel = [1 3], hdom_rel = 2, |
| 92 |
% odom = [3 8], hdom = 7 |
| 93 |
odom_rel = find(obs_bitv(domain)); |
| 94 |
hdom_rel = find(~obs_bitv(domain)); |
| 95 |
odom = domain(odom_rel); |
| 96 |
hdom = domain(hdom_rel); |
| 97 |
CPT = permute(CPT, [hdom_rel odom_rel]); |
| 98 |
CPT = reshape(CPT, prod(ns(hdom)), prod(ns(odom))); |
| 99 |
parents_in_same_slice = all(ps > ss); |
| 100 |
if parents_in_same_slice |
| 101 |
if debug, fprintf('node %d is hidden, some ps are obs, all ps discrete, 1 slice\n', n); end
|
| 102 |
data = cell2num(evidence(odom-ss,2:T)); %data(i,t) = val of i'th obs parent at t+1 |
| 103 |
else |
| 104 |
if debug, fprintf('node %d is hidden, some ps are obs, all ps discrete, 2 slice\n', n); end
|
| 105 |
data = zeros(length(odom), T-1); |
| 106 |
for t=2:T |
| 107 |
ev = evidence(:,t-1:t); |
| 108 |
data(:,t-1) = cell2num(ev(odom)); |
| 109 |
end |
| 110 |
end |
| 111 |
ndx = subv2ind(ns(odom), data'); % ndx(t) encodes data(:,t) |
| 112 |
CPDpot(n,2:T) = num2cell(CPT(:, ndx), [1 2]); |
| 113 |
else % some parents are cts - v slow |
| 114 |
if debug, fprintf('node %d is hidden, some ps are obs, some ps cts\n', n); end
|
| 115 |
for t=2:T |
| 116 |
CPDpot{n,t} = convert_to_table(bnet.CPD{e}, [ps self], evidence(:,t-1:t));
|
| 117 |
end |
| 118 |
end |
| 119 |
end |
| 120 |
|
| 121 |
%%%%%%% |
| 122 |
function CPDpot = helper_obs_child(bnet, evidence, n, CPDpot, obs_bitv, debug) |
| 123 |
|
| 124 |
[ss T] = size(evidence); |
| 125 |
self = n+ss; |
| 126 |
ps = parents(bnet.dag, self); |
| 127 |
e = bnet.equiv_class(n, 2); |
| 128 |
ns = bnet.node_sizes(:); |
| 129 |
if ~any(obs_bitv(ps)) % all parents are hidden |
| 130 |
parents_in_same_slice = all(ps > ss); |
| 131 |
if parents_in_same_slice |
| 132 |
if debug, fprintf('node %d is obs, all ps are hidden, 1 slice\n', n); end
|
| 133 |
ps1 = ps - ss; |
| 134 |
if myismember(n, bnet.dnodes) |
| 135 |
CPT = CPD_to_CPT(bnet.CPD{e});
|
| 136 |
CPT = reshape(CPT, [prod(ns(ps)) ns(self)]); % what if no parents? |
| 137 |
obslik = eval_pdf_cond_multinomial(cell2num(evidence(n,2:T)), CPT); |
| 138 |
CPDpot(n,2:T) = num2cell(obslik, 1); |
| 139 |
else |
| 140 |
S = struct(bnet.CPD{e});
|
| 141 |
obslik = eval_pdf_cond_gauss(cell2num(evidence(n,2:T)), S.mean, S.cov); |
| 142 |
CPDpot(n,2:T) = num2cell(obslik, 1); |
| 143 |
end |
| 144 |
else % parents span 2 slices - slow |
| 145 |
if debug, fprintf('node %d is obs, all ps are hidden , 2 slice\n', n); end
|
| 146 |
for t=2:T |
| 147 |
CPDpot{n,t} = convert_to_table(bnet.CPD{e}, [ps self], evidence(:,t-1:t));
|
| 148 |
end |
| 149 |
end |
| 150 |
else |
| 151 |
if isempty(ps) % observed root |
| 152 |
if debug, fprintf('node %d is obs, no ps\n', n); end
|
| 153 |
CPT = CPD_to_CPT(bnet.CPD{e});
|
| 154 |
data = cell2num(evidence(n,2:T)); |
| 155 |
CPDpot(n,2:T) = CPT(data); |
| 156 |
else % some parents are observed - slow |
| 157 |
if debug, fprintf('node %d is obs, some ps are obs\n', n); end
|
| 158 |
for t=2:T |
| 159 |
CPDpot{n,t} = convert_to_table(bnet.CPD{e}, [ps self], evidence(:,t-1:t));
|
| 160 |
end |
| 161 |
end |
| 162 |
end |