To check out this repository please hg clone the following URL, or open the URL using EasyMercurial or your preferred Mercurial client.

Statistics Download as Zip
| Branch: | Revision:

root / _FullBNT / BNT / general / mk_dbn.m @ 8:b5b38998ef3b

History | View | Annotate | Download (4.14 KB)

1
function bnet = mk_dbn(intra, inter, node_sizes, varargin)
2
% MK_DBN Make a Dynamic Bayesian Network.
3
%
4
% BNET = MK_DBN(INTRA, INTER, NODE_SIZES, ...) makes a DBN with arcs
5
% from i in slice t to j in slice t iff intra(i,j) = 1, and 
6
% from i in slice t to j in slice t+1 iff inter(i,j) = 1,
7
% for i,j in {1, 2, ..., n}, where n = num. nodes per slice, and t >= 1.
8
% node_sizes(i) is the number of values node i can take on.
9
% The nodes are assumed to be in topological order. Use TOPOLOGICAL_SORT if necessary.
10
% See also mk_bnet.
11
%
12
% Optional arguments [default in brackets]
13
% 'discrete' - list of discrete nodes [1:n]
14
% 'observed' - the list of nodes which will definitely be observed in every slice of every case [ [] ]
15
% 'eclass1' - equiv class for slice 1 [1:n]
16
% 'eclass2' - equiv class for slice 2 [tie nodes with equivalent parents to slice 1]
17
%    equiv_class1(i) = j means node i in slice 1 gets its parameters from bnet.CPD{j},
18
%    i.e., nodes i and j have tied parameters.
19
% 'intra1' - topology of first slice, if different from others
20
% 'names' - a cell array of strings to be associated with nodes 1:n [{}]
21
%    This creates an associative array, so you write e.g.
22
%     'evidence(bnet.names{'bar'}) = 42' instead of  'evidence(2} = 42' 
23
%     assuming names = { 'foo', 'bar', ...}.
24
%    
25
% For backwards compatibility with BNT2, arguments can also be specified as follows
26
%   bnet = mk_dbn(intra, inter, node_sizes, dnodes, eclass1, eclass2, intra1)
27
%
28
% After calling this function, you must specify the parameters (conditional probability
29
% distributions) using bnet.CPD{i} = gaussian_CPD(...) or tabular_CPD(...) etc.
30

    
31

    
32
n = length(intra);
33
ss = n;
34
bnet.nnodes_per_slice = ss;
35
bnet.intra = intra;
36
bnet.inter = inter;
37
bnet.intra1 = intra;
38
dag = zeros(2*n);
39
dag(1:n,1:n) = bnet.intra1;
40
dag(1:n,(1:n)+n) = bnet.inter;
41
dag((1:n)+n,(1:n)+n) = bnet.intra;
42
bnet.dag = dag;
43
bnet.names = {};
44

    
45
directed = 1;
46
if ~acyclic(dag,directed)
47
  error('graph must be acyclic')
48
end
49

    
50

    
51
bnet.eclass1 = 1:n;
52
%bnet.eclass2 = (1:n)+n;
53
bnet.eclass2 = bnet.eclass1;
54
for i=1:ss
55
  if isequal(parents(dag, i+ss), parents(dag, i)+ss)
56
    %fprintf('%d has isomorphic parents, eclass %d\n', i, bnet.eclass2(i))
57
  else
58
    bnet.eclass2(i) = max(bnet.eclass2) + 1;
59
    %fprintf('%d has non isomorphic parents, eclass %d\n', i, bnet.eclass2(i))
60
  end
61
end
62

    
63
dnodes = 1:n;
64
bnet.observed = [];
65

    
66
if nargin >= 4
67
  args = varargin;
68
  nargs = length(args);
69
  if ~isstr(args{1})
70
    if nargs >= 1, dnodes = args{1}; end
71
    if nargs >= 2, bnet.eclass1 = args{2}; end
72
    if nargs >= 3, bnet.eclass2 = args{3}; end
73
    if nargs >= 4, bnet.intra1 = args{4}; end
74
  else
75
    for i=1:2:nargs
76
      switch args{i},
77
       case 'discrete', dnodes = args{i+1}; 
78
       case 'observed', bnet.observed = args{i+1}; 
79
       case 'eclass1',  bnet.eclass1 = args{i+1}; 
80
       case 'eclass2',  bnet.eclass2 = args{i+1}; 
81
       case 'intra1',  bnet.intra1 = args{i+1}; 
82
       %case 'ar_hmm',  bnet.ar_hmm = args{i+1};  % should check topology
83
       case 'names',  bnet.names = assocarray(args{i+1}, num2cell(1:n)); 
84
       otherwise,  
85
	error(['invalid argument name ' args{i}]);       
86
      end
87
    end
88
  end
89
end
90

    
91

    
92
bnet.observed = sort(bnet.observed); % for comparing sets
93
ns = node_sizes;
94
bnet.node_sizes_slice = ns(:)';
95
bnet.node_sizes = [ns(:) ns(:)];
96

    
97
cnodes = mysetdiff(1:n, dnodes);
98
bnet.dnodes_slice = dnodes;
99
bnet.cnodes_slice = cnodes;
100
bnet.dnodes = [dnodes dnodes+n];
101
bnet.cnodes = [cnodes cnodes+n];
102

    
103
bnet.equiv_class = [bnet.eclass1(:) bnet.eclass2(:)];
104
bnet.CPD = cell(1,max(bnet.equiv_class(:)));
105
eclass = bnet.equiv_class(:);
106
E = max(eclass);
107
bnet.rep_of_eclass = zeros(1,E);
108
for e=1:E
109
  mems = find(eclass==e);
110
  bnet.rep_of_eclass(e) = mems(1);
111
end
112

    
113
ss = n;
114
onodes = bnet.observed;
115
hnodes = mysetdiff(1:ss, onodes);
116
bnet.hidden_bitv = zeros(1,2*ss);
117
bnet.hidden_bitv(hnodes) = 1;
118
bnet.hidden_bitv(hnodes+ss) = 1;
119

    
120
bnet.parents = cell(1, 2*ss);
121
for i=1:ss
122
  bnet.parents{i} = parents(bnet.dag, i);
123
  bnet.parents{i+ss} = parents(bnet.dag, i+ss);
124
end
125

    
126
bnet.auto_regressive = zeros(1,ss);
127
% ar(i)=1 means (observed) node i depends on i in the  previous slice
128
for o=bnet.observed(:)'
129
  if any(bnet.parents{o+ss} <= ss)
130
    bnet.auto_regressive(o) = 1;
131
  end
132
end
133