Mercurial > hg > camir-aes2014
comparison toolboxes/FullBNT-1.0.7/bnt/general/convert_dbn_CPDs_to_tables.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:e9a9cd732c1e |
---|---|
function CPDpot = convert_dbn_CPDs_to_tables(bnet, evidence)
% CONVERT_DBN_CPDS_TO_TABLES Convert CPDs of (possibly instantiated) DBN nodes to tables
% CPDpot = convert_dbn_CPDs_to_tables(bnet, evidence)
%
% evidence is an ss x T cell array (ss = nodes per slice, T = number of slices);
% evidence{i,t} holds the value of node i in slice t.
%
% CPDpot{n,t} is a table containing P(n,t|pa(n,t), ev)
% All hidden nodes are assumed to be discrete.
% We assume the observed nodes are the same in every slice: which nodes count as
% observed is taken from bnet.observed, not from which cells of evidence are filled.
%
% Evaluating the conditional likelihood of long evidence sequences can be very slow,
% so we take pains to vectorize where possible.

[ss T] = size(evidence);
% obs_bitv(i) = 1 iff node i of the 2-slice numbering (nodes 1..2*ss) is observed.
obs_bitv = zeros(1, 2*ss);
obs_bitv(bnet.observed) = 1;
obs_bitv(bnet.observed+ss) = 1;

CPDpot = cell(ss,T);
debug = 0; % set to 1 to print which method is chosen for each node

for n=1:ss
  % slice 1: CPD of equivalence class bnet.equiv_class(n,1)
  ps = parents(bnet.dag, n);
  e = bnet.equiv_class(n, 1);
  if ~any(obs_bitv(ps))
    CPDpot{n,1} = convert_CPD_to_table_hidden_ps(bnet.CPD{e}, evidence{n,1});
  else
    CPDpot{n,1} = convert_to_table(bnet.CPD{e}, [ps n], evidence(:,1));
  end

  if T < 2
    continue; % a single slice has no transition CPDs to convert
  end

  % special cases: c=child, p=parents, d=discrete, h=hidden, 1sl=1slice
  % if c=h=1 then c=d=1, since hidden nodes must be discrete
  % c=h c=d p=h p=d 1sl method
  % ---------------------------
  % 1   1   1   1   -   replicate CPT
  % 0   1   1   1   1   dhmm
  % -   1   -   1   -   evaluate CPT on evidence *
  % 0   0   1   1   1   ghmm
  % other               loop
  %
  % * = any subset of the domain may be observed
  %
  % The rows are tested in this order. dhmm is a special case of "evaluate CPT",
  % so it must be tested first or it can never be selected.

  % Example where all of the special cases occur - a hierarchical HMM
  % where the top layer (G) and leaves (Y) are observed and
  % all nodes are discrete except Y.
  % (O turns on if Y is an outlier)

  % G ---------> G
  % |            |
  % v            v
  % S --------> S
  % |            |
  % v            v
  % Y            Y
  % ^            ^
  % |            |
  % O            O

  % Evaluating P(yt|St,Ot) is the ghmm case
  % Evaluating P(St|S(t-1),gt) is the eval CPT case
  % Evaluating P(gt|g(t-1)) is the eval CPT case (hdom = [])
  % Evaluating P(Ot) is the replicated CPT case

  % Continuous parents (e.g., inputs) would require an additional special case for speed

  % slices 2..T: CPD of equivalence class bnet.equiv_class(n,2);
  % node n of slice t is node n+ss in the 2-slice numbering.
  self = n+ss;
  ps = parents(bnet.dag, self);
  e = bnet.equiv_class(n, 2);

  hidden_child = ~obs_bitv(n);
  discrete_child = myismember(n, bnet.dnodes);
  hidden_ps = all(~obs_bitv(ps));
  discrete_ps = mysubset(ps, bnet.dnodes);
  parents_in_same_slice = all(ps > ss);

  % The fast paths are meant to agree with the generic convert_to_table loop below.
  if hidden_child & discrete_child & hidden_ps & discrete_ps
    CPDpot = helper_repl(bnet, evidence, n, CPDpot, obs_bitv, debug);
  elseif discrete_child & hidden_ps & discrete_ps & parents_in_same_slice
    CPDpot = helper_dhmm(bnet, evidence, n, CPDpot, obs_bitv, debug);
  elseif discrete_child & discrete_ps
    CPDpot = helper_eval(bnet, evidence, n, CPDpot, obs_bitv, debug);
  elseif ~discrete_child & hidden_ps & discrete_ps & parents_in_same_slice
    CPDpot = helper_ghmm(bnet, evidence, n, CPDpot, obs_bitv, debug);
  else
    if debug, fprintf('node %d, slow\n', n); end
    for t=2:T
      CPDpot{n,t} = convert_to_table(bnet.CPD{e}, [ps self], evidence(:,t-1:t));
    end
  end
end
110 | |
111 | |
112 | |
113 | |
114 %%%%%%% | |
function CPDpot = helper_repl(bnet, evidence, n, CPDpot, obs_bitv, debug)
% HELPER_REPL Fill CPDpot{n,2:T} with copies of node n's slice-2 table.
% Chosen by the caller when node n and all its parents are hidden and discrete,
% so the table does not depend on the evidence and is the same in every slice.
% obs_bitv is unused; it is accepted so all helpers share one signature.

T = size(evidence, 2);
if debug, fprintf('node %d, repl\n', n); end
e = bnet.equiv_class(n, 2);
CPT = convert_CPD_to_table_hidden_ps(bnet.CPD{e}, []);
% Replicate the cell, not the numeric array. Every slot receives the whole table
% whatever its shape, and no (T-1)-fold numeric copy is built.
CPDpot(n,2:T) = repmat({CPT}, 1, T-1);
122 | |
123 | |
124 | |
125 %%%%%%% | |
function CPDpot = helper_eval(bnet, evidence, n, CPDpot, obs_bitv, debug)
% HELPER_EVAL Evaluate node n's slice-2 CPT on the evidence for slices 2..T.
% Any subset of the family [ps self] may be observed (per obs_bitv).
%
% Example: given CPT(p1, p2, p3, p4, c), where p1,p3 are observed
% we create CPT([p2 p4 c], [p1 p3]).
% We then convert all observed p1,p3 into indices ndx
% and return CPT(:, ndx).
% If the whole family is observed, each CPDpot{n,t} is a scalar; otherwise it is
% a column over the joint states of the hidden family members.

[ss T] = size(evidence);
if T < 2
  return; % no slices 2..T to fill
end
self = n+ss;
ps = parents(bnet.dag, self);
e = bnet.equiv_class(n, 2);
ns = bnet.node_sizes(:);
CPT = CPD_to_CPT(bnet.CPD{e});
domain = [ps self];
% if dom is [3 7 8] and 3,8 are observed, odom_rel = [1 3], hdom_rel = 2,
% odom = [3 8], hdom = 7
odom_rel = find(obs_bitv(domain));
hdom_rel = find(~obs_bitv(domain));
odom = domain(odom_rel);
hdom = domain(hdom_rel);
if isempty(hdom)
  CPT = CPT(:);
else
  % rows: joint states of hidden family members; columns: joint observed states
  CPT = permute(CPT, [hdom_rel odom_rel]);
  CPT = reshape(CPT, prod(ns(hdom)), prod(ns(odom)));
end

% data(i,t-1) = observed value of family member odom(i) for the window (t-1,t)
if all(ps > ss)
  if debug, fprintf('node %d eval 1 slice\n', n); end
  data = cell2num(evidence(odom-ss,2:T));
else
  if debug, fprintf('node %d eval 2 slice\n', n); end
  % Vectorized over time: loop over the (few) observed family members instead of T.
  % In the 2-slice numbering, node i <= ss is read from slice t-1 and i > ss from slice t.
  data = zeros(length(odom), T-1);
  for i=1:length(odom)
    if odom(i) <= ss
      data(i,:) = cell2num(evidence(odom(i), 1:T-1));
    else
      data(i,:) = cell2num(evidence(odom(i)-ss, 2:T));
    end
  end
end
ndx = subv2ind(ns(odom), data'); % ndx(k) encodes data(:,k)
if isempty(hdom)
  CPDpot(n,2:T) = num2cell(CPT(ndx)); % a cell array of floats
else
  CPDpot(n,2:T) = num2cell(CPT(:, ndx), 1); % a cell array of column vectors
end
173 | |
174 %%%%%%% | |
function CPDpot = helper_dhmm(bnet, evidence, n, CPDpot, obs_bitv, debug)
% HELPER_DHMM Tables for an observed discrete node n whose parents are hidden,
% discrete and in the same slice (the "dhmm" row of the caller's method table).
% The CPT is flattened to (joint parent state) x (child value) and handed, with the
% observed values of node n for slices 2..T, to eval_pdf_cond_multinomial.
% Column t-1 of its result becomes CPDpot{n,t}.
% obs_bitv is not used here.

if debug
  fprintf('node %d, dhmm\n', n);
end
[ss, T] = size(evidence);
fam = parents(bnet.dag, n + ss);
sizes = bnet.node_sizes(:);
cpd = bnet.CPD{bnet.equiv_class(n, 2)};
% With no parents prod(sizes([])) is 1, so the matrix has a single row.
obsmat = reshape(CPD_to_CPT(cpd), [prod(sizes(fam)) sizes(n + ss)]);
lik = eval_pdf_cond_multinomial(cell2num(evidence(n, 2:T)), obsmat);
CPDpot(n, 2:T) = num2cell(lik, 1);
188 | |
189 | |
190 %%%%%%% | |
function CPDpot = helper_ghmm(bnet, evidence, n, CPDpot, obs_bitv, debug)
% HELPER_GHMM Tables for a continuous (hence observed) node n whose parents are
% hidden, discrete and in the same slice (the "ghmm" row of the caller's method table).
% The CPD's mean and cov fields, with the values of node n for slices 2..T, are
% passed to eval_pdf_cond_gauss; column t-1 of its result becomes CPDpot{n,t}.
% obs_bitv is not used here.

if debug
  fprintf('node %d, ghmm\n', n);
end
nslices = size(evidence, 2);
params = struct(bnet.CPD{bnet.equiv_class(n, 2)});
y = cell2num(evidence(n, 2:nslices));
lik = eval_pdf_cond_gauss(y, params.mean, params.cov);
CPDpot(n, 2:nslices) = num2cell(lik, 1);
201 |