Mercurial > hg > camir-aes2014
comparison toolboxes/FullBNT-1.0.7/bnt/general/convert_dbn_CPDs_to_tables1.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:e9a9cd732c1e |
---|---|
function CPDpot = convert_dbn_CPDs_to_tables1(bnet, evidence, debug)
% CONVERT_DBN_CPDS_TO_TABLES1 Convert CPDs of (possibly instantiated) DBN nodes to tables
% CPDpot = convert_dbn_CPDs_to_tables1(bnet, evidence)
% CPDpot = convert_dbn_CPDs_to_tables1(bnet, evidence, debug)
%
% CPDpot{n,t} is a table containing P(n,t|pa(n,t), ev)
% All hidden nodes are assumed to be discrete
% We assume the observed nodes are the same in every slice
%
% Inputs:
%   bnet     - the DBN; this function reads bnet.observed, bnet.dag,
%              bnet.equiv_class and bnet.CPD.
%   evidence - ss x T cell array indexed as evidence{node, slice}. Which
%              nodes count as observed is taken from bnet.observed, not from
%              which cells are non-empty.
%   debug    - optional. If nonzero, the helpers print which conversion path
%              is taken for each node in slices 2..T. Defaults to 1, matching
%              the previously hard-coded setting.
%
% Output:
%   CPDpot   - ss x T cell array of tables.
%
% Evaluating the conditional likelihood of the evidence can be very slow,
% so we take pains to vectorize where possible, i.e., we try to avoid
% calling convert_to_table

if nargin < 3
  debug = 1;
end

[ss T] = size(evidence);

% Observed-node indicator over the 2-slice network: indices 1..ss are the
% slice-1 nodes and ss+1..2*ss their slice-2 copies. Because the observed set
% is the same in every slice, both halves are set from bnet.observed.
%obs_bitv = ~isemptycell(evidence(:));
obs_bitv = zeros(1, 2*ss);
obs_bitv(bnet.observed) = 1;
obs_bitv(bnet.observed+ss) = 1;

CPDpot = cell(ss,T);

for n=1:ss
  % slice 1: uses the slice-1 equivalence class and the first evidence column
  t = 1;
  ps = parents(bnet.dag, n);
  e = bnet.equiv_class(n, 1);
  if ~any(obs_bitv(ps))
    % no observed parents: only the child's own evidence cell is passed on
    CPDpot{n,t} = convert_CPD_to_table_hidden_ps(bnet.CPD{e}, evidence{n,t});
  else
    CPDpot{n,t} = convert_to_table(bnet.CPD{e}, [ps n], evidence(:,1));
  end

  % slices 2..T (nothing to fill when there is only one slice)
  if T > 1
    if ~obs_bitv(n)
      CPDpot = helper_hidden_child(bnet, evidence, n, CPDpot, obs_bitv, debug);
    else
      CPDpot = helper_obs_child(bnet, evidence, n, CPDpot, obs_bitv, debug);
    end
  end
end

% Disabled self-check against the unvectorized implementation.
if 0
  CPDpot2 = convert_dbn_CPDs_to_tables_slow(bnet, evidence);
  for t=1:T
    for n=1:ss
      if ~approxeq(CPDpot{n,t}, CPDpot2{n,t})
        fprintf('CPDpot n=%d, t=%d\n',n,t);
        keyboard
      end
    end
  end
end


% special cases: c=child, p=parents, d=discrete, h=hidden, 1=1slice
% if c=h=1 then c=d=1, since hidden nodes must be discrete
% c=h  c=d  p=h  p=d  p=1   method
% ---------------------------------
%  1    1    1    1    -    replicate CPT
%  0    1    1    1    1    dhmm
%  0    0    1    1    1    ghmm
%  -    1    -    1    -    evaluate CPT on evidence
% other                     loop
%%%%%%%
function CPDpot = helper_hidden_child(bnet, evidence, n, CPDpot, obs_bitv, debug)
% HELPER_HIDDEN_CHILD Fill CPDpot(n,2:T) for a node n that is hidden in every slice.
% The slice-2 copy of node n has index n+ss in bnet.dag, and its CPD is
% bnet.CPD{bnet.equiv_class(n,2)}. Three paths are used:
%  - all parents hidden:  the same table is stored for every slice t >= 2
%  - some parents observed, all parents discrete: select the CPT column
%    matching the observed parent values, vectorized over t
%  - some parents observed, some parents continuous: convert_to_table per slice
% obs_bitv is the 1 x 2*ss observed indicator built by the caller. If debug
% is nonzero, the chosen path is printed.

[ss T] = size(evidence);
self = n+ss;
ps = parents(bnet.dag, self);
e = bnet.equiv_class(n, 2);
ns = bnet.node_sizes(:);
if ~any(obs_bitv(ps)) % all parents are hidden (hence discrete)
  if debug, fprintf('node %d is hidden, all ps are hidden\n', n); end
  if myismember(n, bnet.dnodes)
    %CPT = CPD_to_CPT(bnet.CPD{e});
    %CPT = reshape(CPT, [prod(ns(ps)) ns(self)]);
    CPT = convert_CPD_to_table_hidden_ps(bnet.CPD{e}, []);
    % Store the same table in every slice. Scalar cell expansion keeps the
    % table's shape exactly. The previous repmat(CPT,[1 1 T-1]) would instead
    % tile along an existing 3rd dimension if the table has >= 3 dimensions.
    CPDpot(n,2:T) = {CPT};
  else
    error(['hidden cts node disallowed'])
  end
else % some parents are observed - slow
  if mysubset(ps, bnet.dnodes) % all parents are discrete
    % given CPT(p1, p2, p3, p4, c), where p1,p3 are observed
    % we create CPT([p2 p4 c], [p1 p3]).
    % We then convert all observed p1,p3 into indices ndx
    % and return CPT(:, ndx)
    CPT = CPD_to_CPT(bnet.CPD{e});
    domain = [ps self];
    % if dom is [3 7 8] and 3,8 are observed, odom_rel = [1 3], hdom_rel = 2,
    % odom = [3 8], hdom = 7
    % (self is hidden here, so odom contains only parents)
    odom_rel = find(obs_bitv(domain));
    hdom_rel = find(~obs_bitv(domain));
    odom = domain(odom_rel);
    hdom = domain(hdom_rel);
    CPT = permute(CPT, [hdom_rel odom_rel]);
    CPT = reshape(CPT, prod(ns(hdom)), prod(ns(odom)));
    parents_in_same_slice = all(ps > ss);
    if parents_in_same_slice
      if debug, fprintf('node %d is hidden, some ps are obs, all ps discrete, 1 slice\n', n); end
      data = cell2num(evidence(odom-ss,2:T)); %data(i,t) = val of i'th obs parent at t+1
    else
      if debug, fprintf('node %d is hidden, some ps are obs, all ps discrete, 2 slice\n', n); end
      data = zeros(length(odom), T-1);
      for t=2:T
        % ev is ss x 2, so linear indices 1..ss address slice t-1 and
        % ss+1..2*ss address slice t, matching the 2-slice node numbering
        ev = evidence(:,t-1:t);
        data(:,t-1) = cell2num(ev(odom));
      end
    end
    ndx = subv2ind(ns(odom), data'); % ndx(t) encodes data(:,t)
    % One column per slice: slice t+1 gets CPT(:, ndx(t)). num2cell(X, 1)
    % splits X into its columns. The previous num2cell(X, [1 2]) put the whole
    % matrix into a single cell, which was then copied into every slice.
    CPDpot(n,2:T) = num2cell(CPT(:, ndx), 1);
  else % some parents are cts - v slow
    if debug, fprintf('node %d is hidden, some ps are obs, some ps cts\n', n); end
    for t=2:T
      CPDpot{n,t} = convert_to_table(bnet.CPD{e}, [ps self], evidence(:,t-1:t));
    end
  end
end
%%%%%%%
function CPDpot = helper_obs_child(bnet, evidence, n, CPDpot, obs_bitv, debug)
% HELPER_OBS_CHILD Fill CPDpot(n,2:T) for a node n that is observed in every slice.
% The slice-2 copy of node n has index n+ss in bnet.dag, and its CPD is
% bnet.CPD{bnet.equiv_class(n,2)}. Paths:
%  - no observed parents, all parents in slice 2: evaluate the likelihood of
%    evidence(n,2:T) vectorized over t (discrete: eval_pdf_cond_multinomial,
%    otherwise: eval_pdf_cond_gauss on the CPD's mean/cov)
%  - no observed parents, parents span both slices: convert_to_table per slice
%  - some observed parents: convert_to_table per slice
% A parentless node takes the first path, because any([]) is false and
% all([] > ss) is true. The former separate "observed root" branch could
% never be reached and has been removed.

[ss T] = size(evidence);
self = n+ss;
ps = parents(bnet.dag, self);
e = bnet.equiv_class(n, 2);
ns = bnet.node_sizes(:);
if ~any(obs_bitv(ps)) % all parents are hidden (or there are none)
  parents_in_same_slice = all(ps > ss);
  if parents_in_same_slice
    if debug, fprintf('node %d is obs, all ps are hidden, 1 slice\n', n); end
    if myismember(n, bnet.dnodes)
      CPT = CPD_to_CPT(bnet.CPD{e});
      % one row per joint parent configuration; with no parents
      % prod(ns([])) is 1, giving a single-row CPT
      CPT = reshape(CPT, [prod(ns(ps)) ns(self)]);
      obslik = eval_pdf_cond_multinomial(cell2num(evidence(n,2:T)), CPT);
    else
      S = struct(bnet.CPD{e});
      obslik = eval_pdf_cond_gauss(cell2num(evidence(n,2:T)), S.mean, S.cov);
    end
    % column t-1 of obslik belongs to slice t
    CPDpot(n,2:T) = num2cell(obslik, 1);
  else % parents span 2 slices - slow
    if debug, fprintf('node %d is obs, all ps are hidden , 2 slice\n', n); end
    for t=2:T
      CPDpot{n,t} = convert_to_table(bnet.CPD{e}, [ps self], evidence(:,t-1:t));
    end
  end
else % some parents are observed (so ps is non-empty) - slow
  if debug, fprintf('node %d is obs, some ps are obs\n', n); end
  for t=2:T
    CPDpot{n,t} = convert_to_table(bnet.CPD{e}, [ps self], evidence(:,t-1:t));
  end
end