wolffd@0
|
function inter = learn_struct_dbn_reveal(seqs, ns, max_fan_in, penalty)
% LEARN_STRUCT_DBN_REVEAL Learn inter-slice adjacency matrix given fully observable discrete time series
% inter = learn_struct_dbn_reveal(seqs, node_sizes, max_fan_in, penalty)
%
% seqs{l}{i,t} = value of node i in slice t of time-series l.
% If you have a single time series in an N*T array D, use
%    seqs = { num2cell(D) }.
% If you have L time series, each of length T, in an N*T*L array D, use
%    seqs = cell(1,L); for l=1:L, seqs{l} = num2cell(D(:,:,l)); end
% or, in vectorized form,
%    seqs = squeeze(num2cell(num2cell(D),[1 2]));
% As a convenience, a numeric N*T or N*T*L array may also be passed directly;
% it is converted to the cell layout above.
% Currently the data is assumed to be discrete (1,2,...)
%
% node_sizes(i) is the number of possible values for node i
% max_fan_in is the largest number of parents we allow per node (default: N)
% penalty selects the scoring function (default: 0.5)
%  A penalty of 0.5 gives the BIC score.
%  A penalty of 0 gives the ML score.
%  NOTE: only the test "penalty == 0" is performed; every non-zero value
%  selects the BIC score returned by bic_score_family, i.e. the value is
%  not applied as a continuous weight.
%  Maximizing likelihood is equivalent to maximizing mutual information between parents and child.
%
% inter(i,j) = 1 iff node i in slice t connects to node j in slice t+1
%
% The parent set for each node in slice 2 is computed by evaluating all non-empty subsets of nodes
% in slice 1 (up to size max_fan_in), and picking the highest scoring one. Because the empty set is
% skipped, every node is assigned at least one parent.
% This takes O(n^k) time per node, where n is the num. nodes per slice, and k <= n is the max fan in.
% Since all the nodes are observed, we do not need to use an inference engine.
% And since we are only learning the inter-slice matrix, we do not need to check for cycles.
%
% This algorithm is described in
% - "REVEAL: A general reverse engineering algorithm for inference of genetic network
%      architectures", Liang et al. PSB 1998
% - "Extended dependency analysis of large systems",
%      Roger Conant, Intl. J. General Systems, 1988, vol 14, pp 97-141
% - "Learning the structure of DBNs", Friedman, Murphy and Russell, UAI 1998.

ns = ns(:)';   % force a row vector so that [ns ns(i)] below is always valid
n = length(ns);

if nargin < 3, max_fan_in = n; end
if nargin < 4, penalty = 0.5; end

inter = zeros(n,n);

% Accept a raw numeric array (N*T or N*T*L) by converting it to the
% documented cell-of-cells layout, one cell per time series.
if ~iscell(seqs)
  raw = seqs;
  nseries = size(raw, 3);
  seqs = cell(1, nseries);
  for l=1:nseries
    seqs{l} = num2cell(raw(:,:,l));
  end
end

nseq = length(seqs);

% We concatenate the sequences as in the following example.
% Let there be 2 sequences of lengths 4 and 5, with n nodes per slice,
% and let i be the target node.
% Then we construct following matrix D
%
% s{1}{1,1} ... s{1}{1,3}  s{2}{1,1} ... s{2}{1,4}
% ....
% s{1}{n,1} ... s{1}{n,3}  s{2}{n,1} ... s{2}{n,4}
% s{1}{i,2} ... s{1}{i,4}  s{2}{i,2} ... s{2}{i,5}
%
% D(1:n, c) is the c'th input and D(n+1, c) is the c'th output.
%
% We concatenate each sequence separately to avoid treating the transition
% from the end of one sequence to the beginning of another as a "normal" transition.
%
% The input rows (slices 1..T-1) are identical for every target node, so both
% the inputs and the candidate outputs (slices 2..T) are assembled once here.
inputs = zeros(n, 0);
outputs = zeros(n, 0);
for l=1:nseq
  T = size(seqs{l}, 2);
  if T < 2, continue; end   % a sequence with fewer than 2 slices has no transition
  vals = cell2num(seqs{l}); % n x T matrix of node values
  inputs = [inputs vals(:, 1:T-1)];
  outputs = [outputs vals(:, 2:T)];
end
ndata = size(inputs, 2);    % number of transitions (initial slice of each sequence excluded)

if ndata < 1
  error('learn_struct_dbn_reveal:noTransitions', ...
	'need at least one sequence with two or more slices to learn the inter-slice structure');
end

% The candidate parent sets do not depend on the target node.
SS = subsets(1:n, max_fan_in, 1); % skip the empty set
nSS = length(SS);
if nSS == 0
  return;                   % no candidate parent sets (e.g. max_fan_in = 0)
end

target = n+1;               % row of D that holds the child node's value
use_ml = (penalty == 0);

for i=1:n
  D = [inputs; outputs(i,:)];
  ns2 = [ns ns(i)];
  bic_score = zeros(1, nSS);
  ll_score = zeros(1, nSS);
  for h=1:nSS
    dom = [SS{h} target];
    counts = compute_counts(D(dom, :), ns2(dom));
    CPT = mk_stochastic(counts);
    [bic_score(h), ll_score(h)] = bic_score_family(counts, CPT, ndata);
  end
  if use_ml
    best = argmax(ll_score);
  else
    best = argmax(bic_score);
  end
  inter(SS{best}, i) = 1;
end
|