To check out this repository please hg clone the following URL, or open the URL using EasyMercurial or your preferred Mercurial client.

Statistics Download as Zip
| Branch: | Revision:

root / _FullBNT / BNT / general / convert_dbn_CPDs_to_tables1.m @ 8:b5b38998ef3b

History | View | Annotate | Download (5.36 KB)

1
function CPDpot = convert_dbn_CPDs_to_tables1(bnet, evidence)
2
% CONVERT_DBN_CPDS_TO_TABLES Convert CPDs of (possibly instantiated) DBN nodes to tables
3
% CPDpot = convert_dbn_CPDs_to_tables(bnet, evidence)
4
%
5
% CPDpot{n,t} is a table containing P(n,t|pa(n,t), ev)
6
% All hidden nodes are assumed to be discrete
7
% We assume the observed nodes are the same in every slice
8
%
9
% Evaluating the conditional likelihood of the evidence can be very slow,
10
% so we take pains to vectorize where possible, i.e., we try to avoid
11
% calling convert_to_table
12

    
13
[ss T] = size(evidence);
14
%obs_bitv = ~isemptycell(evidence(:));
15
obs_bitv = zeros(1, 2*ss);
16
obs_bitv(bnet.observed) = 1;
17
obs_bitv(bnet.observed+ss) = 1;
18

    
19
ns = bnet.node_sizes(:);
20
CPDpot = cell(ss,T); 
21

    
22
for n=1:ss
23
  % slice 1
24
  t = 1;
25
  ps = parents(bnet.dag, n);
26
  e = bnet.equiv_class(n, 1);
27
  if ~any(obs_bitv(ps))
28
    CPDpot{n,t} = convert_CPD_to_table_hidden_ps(bnet.CPD{e}, evidence{n,t});
29
  else
30
    CPDpot{n,t} = convert_to_table(bnet.CPD{e}, [ps n], evidence(:,1));
31
  end
32
  
33
  % slices 2..T
34
  debug = 1;
35
  if ~obs_bitv(n)
36
    CPDpot = helper_hidden_child(bnet, evidence, n, CPDpot, obs_bitv, debug);
37
  else
38
    CPDpot = helper_obs_child(bnet, evidence, n, CPDpot, obs_bitv, debug);
39
  end
40
end
41

    
42
if 0
43
CPDpot2 = convert_dbn_CPDs_to_tables_slow(bnet, evidence);
44
for t=1:T
45
  for n=1:ss
46
    if ~approxeq(CPDpot{n,t}, CPDpot2{n,t})
47
      fprintf('CPDpot n=%d, t=%d\n',n,t);
48
      keyboard
49
    end
50
  end
51
end
52
end
53

    
54

    
55
% special cases: c=child, p=parents, d=discrete, h=hidden, 1=1slice
56
% if c=h=1 then c=d=1, since hidden nodes must be discrete
57
% c=h c=d p=h p=d p=1 method
58
% ---------------------------
59
% 1   1   1   1   -   replicate CPT
60
% 0   1   1   1   1   dhmm
61
% 0   0   1   1   1   ghmm
62
% -   1   -   1   -   evaluate CPT on evidence
63
% other               loop
64

    
65
%%%%%%%
66
function CPDpot = helper_hidden_child(bnet, evidence, n, CPDpot, obs_bitv, debug)
67

    
68
[ss T] = size(evidence);
69
self = n+ss;
70
ps = parents(bnet.dag, self);
71
e = bnet.equiv_class(n, 2);
72
ns = bnet.node_sizes(:);
73
if ~any(obs_bitv(ps)) % all parents are hidden (hence discrete)
74
  if debug, fprintf('node %d is hidden, all ps are hidden\n', n); end
75
  if myismember(n, bnet.dnodes) 
76
    %CPT = CPD_to_CPT(bnet.CPD{e});
77
    %CPT = reshape(CPT, [prod(ns(ps)) ns(self)]);
78
    CPT = convert_CPD_to_table_hidden_ps(bnet.CPD{e}, []);
79
    CPDpot(n,2:T) = num2cell(repmat(CPT, [1 1 T-1]), [1 2]);
80
  else
81
    error(['hidden cts node disallowed'])
82
  end
83
else % some parents are observed - slow
84
  if mysubset(ps, bnet.dnodes) % all parents are discrete
85
    % given CPT(p1, p2, p3, p4, c), where p1,p3 are observed
86
    % we create CPT([p2 p4 c], [p1 p3]).
87
    % We then convert all observed p1,p3 into indices ndx
88
    % and return CPT(:, ndx)
89
    CPT = CPD_to_CPT(bnet.CPD{e});
90
    domain = [ps self];
91
    % if dom is [3 7 8] and 3,8 are observed, odom_rel = [1 3], hdom_rel = 2,
92
    % odom = [3 8], hdom = 7
93
    odom_rel = find(obs_bitv(domain));
94
    hdom_rel = find(~obs_bitv(domain));
95
    odom = domain(odom_rel);
96
    hdom = domain(hdom_rel);
97
    CPT = permute(CPT, [hdom_rel odom_rel]);
98
    CPT = reshape(CPT, prod(ns(hdom)), prod(ns(odom)));
99
    parents_in_same_slice = all(ps > ss);
100
    if parents_in_same_slice
101
      if debug, fprintf('node %d is hidden, some ps are obs, all ps discrete, 1 slice\n', n); end
102
      data = cell2num(evidence(odom-ss,2:T)); %data(i,t) = val of i'th obs parent at t+1
103
    else
104
      if debug, fprintf('node %d is hidden, some ps are obs, all ps discrete, 2 slice\n', n); end
105
      data = zeros(length(odom), T-1);
106
      for t=2:T
107
	ev = evidence(:,t-1:t);
108
	data(:,t-1) = cell2num(ev(odom));
109
      end
110
    end
111
    ndx = subv2ind(ns(odom), data'); % ndx(t) encodes data(:,t)
112
    CPDpot(n,2:T) = num2cell(CPT(:, ndx), [1 2]);
113
  else % some parents are cts - v slow
114
    if debug, fprintf('node %d is hidden, some ps are obs, some ps cts\n', n); end
115
    for t=2:T
116
      CPDpot{n,t} = convert_to_table(bnet.CPD{e}, [ps self], evidence(:,t-1:t));
117
    end
118
  end
119
end
120
  
121
%%%%%%%
122
function CPDpot = helper_obs_child(bnet, evidence, n, CPDpot, obs_bitv, debug)
123

    
124
[ss T] = size(evidence);
125
self = n+ss;
126
ps = parents(bnet.dag, self);
127
e = bnet.equiv_class(n, 2);
128
ns = bnet.node_sizes(:);
129
if ~any(obs_bitv(ps)) % all parents are hidden
130
  parents_in_same_slice = all(ps > ss);
131
  if parents_in_same_slice
132
    if debug, fprintf('node %d is obs, all ps are hidden, 1 slice\n', n); end
133
    ps1 = ps - ss;
134
    if myismember(n, bnet.dnodes) 
135
      CPT = CPD_to_CPT(bnet.CPD{e});
136
      CPT = reshape(CPT, [prod(ns(ps)) ns(self)]); % what if no parents?
137
      obslik = eval_pdf_cond_multinomial(cell2num(evidence(n,2:T)), CPT);
138
      CPDpot(n,2:T) = num2cell(obslik, 1);
139
    else
140
      S = struct(bnet.CPD{e}); 
141
      obslik = eval_pdf_cond_gauss(cell2num(evidence(n,2:T)), S.mean, S.cov);
142
      CPDpot(n,2:T) = num2cell(obslik, 1);
143
    end
144
  else % parents span 2 slices - slow
145
    if debug, fprintf('node %d is obs, all ps are hidden , 2 slice\n', n); end
146
    for t=2:T
147
      CPDpot{n,t} = convert_to_table(bnet.CPD{e}, [ps self], evidence(:,t-1:t));
148
    end
149
  end
150
else 
151
  if isempty(ps) % observed root
152
    if debug, fprintf('node %d is obs, no ps\n', n); end
153
    CPT = CPD_to_CPT(bnet.CPD{e});
154
    data = cell2num(evidence(n,2:T));
155
    CPDpot(n,2:T) = CPT(data);
156
  else  % some parents are observed  - slow
157
    if debug, fprintf('node %d is obs, some ps are obs\n', n); end
158
    for t=2:T
159
      CPDpot{n,t} = convert_to_table(bnet.CPD{e}, [ps self], evidence(:,t-1:t));
160
    end
161
  end
162
end