To check out this repository please hg clone the following URL, or open the URL using EasyMercurial or your preferred Mercurial client.
root / _FullBNT / BNT / general / mk_dbn.m @ 8:b5b38998ef3b
History | View | Annotate | Download (4.14 KB)
| 1 |
function bnet = mk_dbn(intra, inter, node_sizes, varargin) |
|---|---|
| 2 |
% MK_DBN Make a Dynamic Bayesian Network. |
| 3 |
% |
| 4 |
% BNET = MK_DBN(INTRA, INTER, NODE_SIZES, ...) makes a DBN with arcs |
| 5 |
% from i in slice t to j in slice t iff intra(i,j) = 1, and |
| 6 |
% from i in slice t to j in slice t+1 iff inter(i,j) = 1, |
| 7 |
% for i,j in {1, 2, ..., n}, where n = num. nodes per slice, and t >= 1.
|
| 8 |
% node_sizes(i) is the number of values node i can take on. |
| 9 |
% The nodes are assumed to be in topological order. Use TOPOLOGICAL_SORT if necessary. |
| 10 |
% See also mk_bnet. |
| 11 |
% |
| 12 |
% Optional arguments [default in brackets] |
| 13 |
% 'discrete' - list of discrete nodes [1:n] |
| 14 |
% 'observed' - the list of nodes which will definitely be observed in every slice of every case [ [] ] |
| 15 |
% 'eclass1' - equiv class for slice 1 [1:n] |
| 16 |
% 'eclass2' - equiv class for slice 2 [tie nodes with equivalent parents to slice 1] |
| 17 |
% equiv_class1(i) = j means node i in slice 1 gets its parameters from bnet.CPD{j},
|
| 18 |
% i.e., nodes i and j have tied parameters. |
| 19 |
% 'intra1' - topology of first slice, if different from others |
| 20 |
% 'names' - a cell array of strings to be associated with nodes 1:n [{}]
|
| 21 |
% This creates an associative array, so you write e.g. |
| 22 |
% 'evidence(bnet.names{'bar'}) = 42' instead of 'evidence(2} = 42'
|
| 23 |
% assuming names = { 'foo', 'bar', ...}.
|
| 24 |
% |
| 25 |
% For backwards compatibility with BNT2, arguments can also be specified as follows |
| 26 |
% bnet = mk_dbn(intra, inter, node_sizes, dnodes, eclass1, eclass2, intra1) |
| 27 |
% |
| 28 |
% After calling this function, you must specify the parameters (conditional probability |
| 29 |
% distributions) using bnet.CPD{i} = gaussian_CPD(...) or tabular_CPD(...) etc.
|
| 30 |
|
| 31 |
|
| 32 |
n = length(intra); |
| 33 |
ss = n; |
| 34 |
bnet.nnodes_per_slice = ss; |
| 35 |
bnet.intra = intra; |
| 36 |
bnet.inter = inter; |
| 37 |
bnet.intra1 = intra; |
| 38 |
dag = zeros(2*n); |
| 39 |
dag(1:n,1:n) = bnet.intra1; |
| 40 |
dag(1:n,(1:n)+n) = bnet.inter; |
| 41 |
dag((1:n)+n,(1:n)+n) = bnet.intra; |
| 42 |
bnet.dag = dag; |
| 43 |
bnet.names = {};
|
| 44 |
|
| 45 |
directed = 1; |
| 46 |
if ~acyclic(dag,directed) |
| 47 |
error('graph must be acyclic')
|
| 48 |
end |
| 49 |
|
| 50 |
|
| 51 |
bnet.eclass1 = 1:n; |
| 52 |
%bnet.eclass2 = (1:n)+n; |
| 53 |
bnet.eclass2 = bnet.eclass1; |
| 54 |
for i=1:ss |
| 55 |
if isequal(parents(dag, i+ss), parents(dag, i)+ss) |
| 56 |
%fprintf('%d has isomorphic parents, eclass %d\n', i, bnet.eclass2(i))
|
| 57 |
else |
| 58 |
bnet.eclass2(i) = max(bnet.eclass2) + 1; |
| 59 |
%fprintf('%d has non isomorphic parents, eclass %d\n', i, bnet.eclass2(i))
|
| 60 |
end |
| 61 |
end |
| 62 |
|
| 63 |
dnodes = 1:n; |
| 64 |
bnet.observed = []; |
| 65 |
|
| 66 |
if nargin >= 4 |
| 67 |
args = varargin; |
| 68 |
nargs = length(args); |
| 69 |
if ~isstr(args{1})
|
| 70 |
if nargs >= 1, dnodes = args{1}; end
|
| 71 |
if nargs >= 2, bnet.eclass1 = args{2}; end
|
| 72 |
if nargs >= 3, bnet.eclass2 = args{3}; end
|
| 73 |
if nargs >= 4, bnet.intra1 = args{4}; end
|
| 74 |
else |
| 75 |
for i=1:2:nargs |
| 76 |
switch args{i},
|
| 77 |
case 'discrete', dnodes = args{i+1};
|
| 78 |
case 'observed', bnet.observed = args{i+1};
|
| 79 |
case 'eclass1', bnet.eclass1 = args{i+1};
|
| 80 |
case 'eclass2', bnet.eclass2 = args{i+1};
|
| 81 |
case 'intra1', bnet.intra1 = args{i+1};
|
| 82 |
%case 'ar_hmm', bnet.ar_hmm = args{i+1}; % should check topology
|
| 83 |
case 'names', bnet.names = assocarray(args{i+1}, num2cell(1:n));
|
| 84 |
otherwise, |
| 85 |
error(['invalid argument name ' args{i}]);
|
| 86 |
end |
| 87 |
end |
| 88 |
end |
| 89 |
end |
| 90 |
|
| 91 |
|
| 92 |
bnet.observed = sort(bnet.observed); % for comparing sets |
| 93 |
ns = node_sizes; |
| 94 |
bnet.node_sizes_slice = ns(:)'; |
| 95 |
bnet.node_sizes = [ns(:) ns(:)]; |
| 96 |
|
| 97 |
cnodes = mysetdiff(1:n, dnodes); |
| 98 |
bnet.dnodes_slice = dnodes; |
| 99 |
bnet.cnodes_slice = cnodes; |
| 100 |
bnet.dnodes = [dnodes dnodes+n]; |
| 101 |
bnet.cnodes = [cnodes cnodes+n]; |
| 102 |
|
| 103 |
bnet.equiv_class = [bnet.eclass1(:) bnet.eclass2(:)]; |
| 104 |
bnet.CPD = cell(1,max(bnet.equiv_class(:))); |
| 105 |
eclass = bnet.equiv_class(:); |
| 106 |
E = max(eclass); |
| 107 |
bnet.rep_of_eclass = zeros(1,E); |
| 108 |
for e=1:E |
| 109 |
mems = find(eclass==e); |
| 110 |
bnet.rep_of_eclass(e) = mems(1); |
| 111 |
end |
| 112 |
|
| 113 |
ss = n; |
| 114 |
onodes = bnet.observed; |
| 115 |
hnodes = mysetdiff(1:ss, onodes); |
| 116 |
bnet.hidden_bitv = zeros(1,2*ss); |
| 117 |
bnet.hidden_bitv(hnodes) = 1; |
| 118 |
bnet.hidden_bitv(hnodes+ss) = 1; |
| 119 |
|
| 120 |
bnet.parents = cell(1, 2*ss); |
| 121 |
for i=1:ss |
| 122 |
bnet.parents{i} = parents(bnet.dag, i);
|
| 123 |
bnet.parents{i+ss} = parents(bnet.dag, i+ss);
|
| 124 |
end |
| 125 |
|
| 126 |
bnet.auto_regressive = zeros(1,ss); |
| 127 |
% ar(i)=1 means (observed) node i depends on i in the previous slice |
| 128 |
for o=bnet.observed(:)' |
| 129 |
if any(bnet.parents{o+ss} <= ss)
|
| 130 |
bnet.auto_regressive(o) = 1; |
| 131 |
end |
| 132 |
end |
| 133 |
|