To check out this repository please hg clone the following URL, or open the URL using EasyMercurial or your preferred Mercurial client.

Statistics Download as Zip
| Branch: | Revision:

root / _FullBNT / BNT / general / mk_higher_order_dbn.m @ 8:b5b38998ef3b

History | View | Annotate | Download (5.29 KB)

1
function bnet = mk_higher_order_dbn(intra, inter, node_sizes, varargin)
2
% MK_DBN Make a Dynamic Bayesian Network.
3
%
4
% BNET = MK_DBN(INTRA, INTER, NODE_SIZES, ...) makes a DBN with arcs
5
% from i in slice t to j in slice t iff intra(i,j) = 1, and 
6
% from i in slice t to j in slice t+1 iff inter(i,j) = 1,
7
% for i,j in {1, 2, ..., n}, where n = num. nodes per slice, and t >= 1.
8
% node_sizes(i) is the number of values node i can take on.
9
% The nodes are assumed to be in topological order. Use TOPOLOGICAL_SORT if necessary.
10
% See also mk_bnet.
11
%
12
% Optional arguments [default in brackets]
13
% 'discrete' - list of discrete nodes [1:n]
14
% 'observed' - the list of nodes which will definitely be observed in every slice of every case [ [] ]
15
% 'eclass1' - equiv class for slice 1 [1:n]
16
% 'eclass2' - equiv class for slice 2 [tie nodes with equivalent parents to slice 1]
17
%    equiv_class1(i) = j means node i in slice 1 gets its parameters from bnet.CPD{j},
18
%    i.e., nodes i and j have tied parameters.
19
% 'intra1' - topology of first slice, if different from others
20
% 'names' - a cell array of strings to be associated with nodes 1:n [{}]
21
%    This creates an associative array, so you write e.g.
22
%     'evidence(bnet.names{'bar'}) = 42' instead of  'evidence(2} = 42' 
23
%     assuming names = { 'foo', 'bar', ...}.
24
%    
25
% For backwards compatibility with BNT2, arguments can also be specified as follows
26
%   bnet = mk_dbn(intra, inter, node_sizes, dnodes, eclass1, eclass2, intra1)
27
%
28
% After calling this function, you must specify the parameters (conditional probability
29
% distributions) using bnet.CPD{i} = gaussian_CPD(...) or tabular_CPD(...) etc.
30

    
31

    
32
n = length(intra);
33
ss = n;
34
bnet.nnodes_per_slice = ss;
35
bnet.intra = intra;
36
bnet.inter = inter;
37
bnet.intra1 = intra;
38

    
39
% As this method is used to generate a higher order Markov Model
40
% also connect from time slice t - i -> t with i > 1 has to be 
41
% taken into account.
42

    
43
%inter should be a three dimensional array where inter(:,:,i)
44
%describes the connections from time-slice t - i to t.  
45
[rows,columns,order] = size(inter);
46
assert(rows    == n);
47
assert(columns == n);
48
dag = zeros((order + 1)*n);
49

    
50
i = 0;
51
while i <= order
52
    j = i;
53
    while j <= order
54
        if j == i
55
            dag(1 + i*n:(i+1)*n,1+i*n:(i+1)*n) = intra;
56
        else
57
            dag(1+i*n:(i+1)*n,1+j*n:(j+1)*n) = inter(:,:,j - i);
58
        end
59
        j = j + 1;
60
    end;
61
    i = i + 1;
62
end;
63

    
64
bnet.dag = dag;
65
bnet.names = {};
66

    
67
directed = 1;
68
if ~acyclic(dag,directed)
69
  error('graph must be acyclic')
70
end
71

    
72
% Calculation of the equivalence classes
73
bnet.eclass1 = 1:n;
74
bnet.eclass = zeros(order + 1,ss);
75
bnet.eclass(1,:) = 1:n;
76
for i = 1:order
77
    bnet.eclass(i+1,:) = bnet.eclass(i,:);
78
    for j = 1:ss 
79
        if(isequal(parents(dag,(i-1)*n+j)+ss,parents(dag,(i*n + j))))
80
	   %fprintf('%d has isomorphic parents, eclass %d \n',j,bnet.eclass(i,j))
81
        else
82
	   bnet.eclass(i + 1,j) = max(bnet.eclass(i+1,:))+1;
83
	   %fprintf('%d has non isomorphic parents, eclass %d \n',j,bnet.eclass(i,j))  
84
	end;
85
    end;
86
end;
87
bnet.eclass1 = 1:n;
88

    
89
% To be compatible with whe rest of the code 
90
bnet.eclass2 = bnet.eclass(2,:);
91

    
92
dnodes = 1:n;
93
bnet.observed = [];
94

    
95
if nargin >= 4
96
  args = varargin;
97
  nargs = length(args);
98
  if ~isstr(args{1})
99
    if nargs >= 1 dnodes = args{1}; end
100
    if nargs >= 2 bnet.eclass1 = args{2}; bnet.eclass(1,:) = args{2}; end
101
    if nargs >= 3 bnet.eclass2 = args{3}; bnet.eclass(2,:) = args{2}; end
102
    if nargs >= 4 bnet.intra1 = args{4}; end
103
  else
104
    for i=1:2:nargs
105
      switch args{i},
106
       case 'discrete', dnodes = args{i+1}; 
107
       case 'observed', bnet.observed = args{i+1}; 
108
       case 'eclass1',  bnet.eclass1 = args{i+1}; bnet.eclass(1,:) = args{i+1}; 
109
       case 'eclass2',  bnet.eclass2 = args{i+1}; bnet.eclass(2,:) = args{i+1};
110
       case 'eclass',   bnet.eclass = args{i+1};  
111
       case 'intra1',  bnet.intra1 = args{i+1}; 
112
       %case 'ar_hmm',  bnet.ar_hmm = args{i+1};  % should check topology
113
       case 'names',  bnet.names = assocarray(args{i+1}, num2cell(1:n)); 
114
       otherwise,  
115
	error(['invalid argument name ' args{i}]);       
116
      end
117
    end
118
  end
119
end
120

    
121
bnet.observed = sort(bnet.observed); % for comparing sets
122
ns = node_sizes;
123
bnet.node_sizes_slice = ns(:)';
124
bnet.node_sizes = repmat(ns(:),1,order + 1);
125

    
126
cnodes = mysetdiff(1:n, dnodes);
127
bnet.dnodes_slice = dnodes;
128
bnet.cnodes_slice = cnodes;
129
bnet.dnodes = dnodes;
130
bnet.cnodes = cnodes;
131
% To adapt the function to higher order Markov models include dnodes for more 
132
% time slices
133
for i = 1:order
134
    bnet.dnodes = [bnet.dnodes dnodes+i*n];
135
    bnet.cnodes = [bnet.cnodes cnodes+i*n];
136
end
137

    
138
% Generieren einer Matrix, deren i-te Spalte die Aequivalenzklassen
139
% der i-ten Zeitscheibe enthaelt. 
140
bnet.equiv_class = [bnet.eclass(1,:)]';
141
for i = 2:(order + 1)
142
    bnet.equiv_class = [bnet.equiv_class   bnet.eclass(i,:)'];
143
end
144

    
145
bnet.CPD = cell(1,max(bnet.equiv_class(:)));
146

    
147
ss = n;
148
onodes = bnet.observed;
149
hnodes = mysetdiff(1:ss, onodes);
150
bnet.hidden_bitv = zeros(1,(order + 1)*ss);
151
for i = 0:order
152
    bnet.hidden_bitv(hnodes +i*ss) = 1;
153
end;
154

    
155
bnet.parents = cell(1, (order + 1)*ss);
156
for i=1:(order + 1)*ss
157
  bnet.parents{i} = parents(bnet.dag, i);
158
end
159

    
160
bnet.auto_regressive = zeros(1,ss);
161
% ar(i)=1 means (observed) node i depends on i in the  previous slice
162
for o=bnet.observed(:)'
163
  if any(bnet.parents{o+ss} <= ss)
164
    bnet.auto_regressive(o) = 1;
165
  end
166
end
167

    
168

    
169

    
170

    
171

    
172

    
173

    
174

    
175

    
176

    
177

    
178

    
179

    
180

    
181