comparison: toolboxes/FullBNT-1.0.7/bnt/learning/learn_struct_dbn_reveal.m @ 0:e9a9cd732c1e (tip)
commit message: first hg version after svn
author: wolffd
date: Tue, 10 Feb 2015 15:05:51 +0000
function inter = learn_struct_dbn_reveal(seqs, ns, max_fan_in, penalty)
% LEARN_STRUCT_DBN_REVEAL Learn inter-slice adjacency matrix given fully observable discrete time series
% inter = learn_struct_dbn_reveal(seqs, node_sizes, max_fan_in, penalty)
%
% seqs{l}{i,t} = value of node i in slice t of time-series l.
% If you have a single time series in an N*T array D, use
%   seqs = { num2cell(D) }.
% If you have L time series, each of length T, in an N*T*L array D, use
%   seqs = cell(1,L); for l=1:L, seqs{l} = num2cell(D(:,:,l)); end
% or, in vectorized form,
%   seqs = squeeze(num2cell(num2cell(D), [1 2]));
% Currently the data is assumed to be discrete (1,2,...).
%
% node_sizes(i) is the number of possible values for node i
% max_fan_in is the largest number of parents we allow per node (default: N)
% penalty is the weight given to the complexity penalty (default: 0.5)
% A penalty of 0.5 gives the BIC score.
% A penalty of 0 gives the ML score.
% Maximizing likelihood is equivalent to maximizing mutual information between parents and child.
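% (Note on this implementation: penalty only selects which score is maximized
%  below; the ML score is used when penalty == 0, otherwise the BIC score
%  returned by bic_score_family. Intermediate penalty weights are not applied.)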
%
% inter(i,j) = 1 iff node i in slice t connects to node j in slice t+1
%
% The parent set for each node in slice 2 is computed by evaluating all subsets of nodes in slice 1,
% and picking the highest-scoring one. This takes O(n^k) time per node, where n is the number of nodes
% per slice, and k <= n is the max fan-in.
% Since all the nodes are observed, we do not need to use an inference engine.
% And since we are only learning the inter-slice matrix, we do not need to check for cycles.
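% For example (illustrative arithmetic, not from the original comment): with n = 10
% nodes per slice and max_fan_in = 3, each node is scored against
% nchoosek(10,1) + nchoosek(10,2) + nchoosek(10,3) = 10 + 45 + 120 = 175 candidate parent sets.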
%
% This algorithm is described in
% - "REVEAL: A general reverse engineering algorithm for inference of genetic network
%   architectures", Liang et al., PSB 1998
% - "Extended dependency analysis of large systems",
%   Roger Conant, Intl. J. General Systems, 1988, vol 14, pp 97-141
% - "Learning the structure of DBNs", Friedman, Murphy and Russell, UAI 1998.
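%
% Example (an illustrative sketch, not part of the original help; it assumes BNT's
% sample_discrete is on the path and only demonstrates the calling convention):
%   N = 3; T = 50; L = 5;
%   ns = 2*ones(1,N);                         % three binary nodes per slice
%   seqs = cell(1,L);
%   for l=1:L
%     D = sample_discrete([0.5 0.5], N, T);   % random N*T series with values in {1,2}
%     seqs{l} = num2cell(D);
%   end
%   inter = learn_struct_dbn_reveal(seqs, ns, 2);  % allow at most 2 parents per node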

n = length(ns);

if nargin < 3, max_fan_in = n; end
if nargin < 4, penalty = 0.5; end

inter = zeros(n,n);

if ~iscell(seqs)
  % wrap a single numeric N*T array in the cell-of-cells format described in the help above
  seqs = { num2cell(seqs) };
end

nseq = length(seqs);
nslices = 0;
data = cell(1, nseq);
for l=1:nseq
  nslices = nslices + size(seqs{l}, 2);
  data{l} = cell2num(seqs{l})'; % each row is a case
end
ndata = nslices - nseq; % subtract off the initial slice of each sequence
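% ndata is the total number of observed slice-to-slice transitions; it is passed
% to bic_score_family below as the sample size for the BIC penalty.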

% We concatenate the sequences as in the following example.
% Let there be 2 sequences of lengths 4 and 5, with n nodes per slice,
% and let i be the target node.
% Then we construct the following matrix D
%
% s{1}{1,1} ... s{1}{1,3} s{2}{1,1} ... s{2}{1,4}
% ....
% s{1}{n,1} ... s{1}{n,3} s{2}{n,1} ... s{2}{n,4}
% s{1}{i,2} ... s{1}{i,4} s{2}{i,2} ... s{2}{i,5}
%
% D(1:n, j) is the j'th input and D(n+1, j) is the corresponding output
% (the value of target node i one slice later).
%
% We concatenate each sequence separately to avoid treating the transition
% from the end of one sequence to the beginning of another as a "normal" transition.
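% In the example above, D therefore has 3 + 4 = 7 columns, matching
% ndata = nslices - nseq = 9 - 2 = 7 transitions.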
70 | |
71 | |
72 for i=1:n | |
73 D = []; | |
74 for l=1:nseq | |
75 T = size(seqs{l}, 2); | |
76 A = cell2num(seqs{l}(:, 1:T-1)); | |
77 B = cell2num(seqs{l}(i, 2:T)); | |
78 C = [A;B]; | |
79 D = [D C]; | |
80 end | |
81 SS = subsets(1:n, max_fan_in, 1); % skip the empty set | |
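  % SS enumerates the candidate parent sets: the non-empty subsets of the slice-1
  % nodes with at most max_fan_in elements. It does not depend on i, so it could
  % be computed once before this loop.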
  nSS = length(SS);
  bic_score = zeros(1, nSS);
  ll_score = zeros(1, nSS);
  target = n+1;            % row of D holding node i's value in the next slice
  ns2 = [ns ns(i)];        % sizes of the slice-1 nodes plus the target node
  for h=1:nSS
    ps = SS{h};
    dom = [ps target];
    counts = compute_counts(D(dom, :), ns2(dom)); % joint counts over candidate parents and target
    CPT = mk_stochastic(counts);                  % ML estimate of the CPT by normalizing the counts
    [bic_score(h), ll_score(h)] = bic_score_family(counts, CPT, ndata);
  end
  if penalty == 0
    h = argmax(ll_score);  % ML score
  else
    h = argmax(bic_score); % BIC score
  end
  ps = SS{h};
  inter(ps, i) = 1;        % best-scoring slice-1 parent set becomes node i's inter-slice parents
end