wolffd@0
|
function engine = gibbs_sampling_inf_engine(bnet, varargin)
% GIBBS_SAMPLING_INF_ENGINE Construct a Gibbs sampling inference engine.
%
% engine = gibbs_sampling_inf_engine(bnet, ...)
%
% Optional parameters are given as name/value pairs [default in brackets]:
%
% 'burnin' - How long before you start using the samples [100].
% 'gap'    - How often you use the samples in the estimate [1].
% 'T'      - Number of samples [1000], i.e. the number of node flips.
%            For example, if there are 10 nodes in the bnet and T is 1000,
%            each node will get flipped 100 times (assuming a deterministic
%            schedule).
%            The total running time is proportional to burnin + T*gap.
%
% 'order'  - If the sampling schedule is deterministic, use this parameter
%            to specify the order in which nodes are sampled. The order may
%            include multiple copies of a node, which is useful if you want
%            to, say, focus sampling on particular nodes.
%            The default is a deterministic schedule that goes through the
%            nodes in order (1:N, where N = size(bnet.dag, 2)).
%
% 'sampling_dist' - When using a stochastic sampling method, at each step
%            the node to sample is chosen according to this distribution
%            (which may be unnormalized). Supplying it sets
%            engine.deterministic = 0.
%
% At most one of 'order' and 'sampling_dist' may be given, each at most
% once; otherwise an error is raised. An error is also raised for an
% unrecognized parameter name, a non-string parameter name, or a name
% without a matching value.
%
% Written by "Bhaskara Marthi" <bhaskara@cs.berkeley.edu> Feb 02.

% NOTE: the order in which struct fields are first created below is kept
% identical to the original constructor, because old-style MATLAB classes
% built with class() rely on a consistent field layout.
engine.burnin = 100;
engine.gap = 1;
engine.T = 1000;
use_default_order = 1;   % cleared once 'order' or 'sampling_dist' is given
engine.deterministic = 1;
engine.order = {};
engine.sampling_dist = {};

args = varargin;
nargs = length(args);
% Every option name must be followed by a value; otherwise args{i+1}
% below would fail with an unhelpful index-out-of-bounds error.
if mod(nargs, 2) ~= 0
  error('gibbs_sampling_inf_engine: optional arguments must be name/value pairs');
end

for i = 1:2:nargs
  name = args{i};
  if ~ischar(name)
    error('gibbs_sampling_inf_engine: parameter names must be strings');
  end
  switch name
    case 'burnin'
      engine.burnin = args{i+1};
    case 'gap'
      engine.gap = args{i+1};
    case 'T'
      engine.T = args{i+1};
    case 'order'
      % Only one schedule specification is allowed.
      if ~use_default_order
        error('gibbs_sampling_inf_engine: specify at most one of ''order'' and ''sampling_dist'', and each at most once');
      end
      use_default_order = 0;
      engine.order = args{i+1};
    case 'sampling_dist'
      % Only one schedule specification is allowed.
      if ~use_default_order
        error('gibbs_sampling_inf_engine: specify at most one of ''order'' and ''sampling_dist'', and each at most once');
      end
      use_default_order = 0;
      % A sampling distribution implies a stochastic schedule.
      engine.deterministic = 0;
      engine.sampling_dist = args{i+1};
    otherwise
      error(['unrecognized parameter to gibbs_sampling_inf_engine: ' name]);
  end
end

% Number of nodes (columns of bnet.dag); the default schedule visits
% each node once, in index order.
engine.slice_size = size(bnet.dag, 2);
if use_default_order
  engine.order = 1:engine.slice_size;
end

% Inference state: initialized empty here, not set by this constructor
% (presumably filled in when evidence is entered -- confirm in the class's
% other methods).
engine.hnodes = [];
engine.onodes = [];
engine.evidence = [];
engine.state = [];
engine.marginal_counts = {};

% Precompute the strides for each CPT
engine.strides = compute_strides(bnet);

% Precompute graphical information
engine.families = compute_families(bnet);
engine.children = compute_children(bnet);

% For convenience, store the CPTs as tables rather than objects
engine.CPT = get_cpts(bnet);

engine = class(engine, 'gibbs_sampling_inf_engine', inf_engine(bnet));
|
wolffd@0
|
88
|
wolffd@0
|
89
|
wolffd@0
|
90
|
wolffd@0
|
91
|
wolffd@0
|
92
|
wolffd@0
|
93
|
wolffd@0
|
94
|
wolffd@0
|
95
|
wolffd@0
|
96
|
wolffd@0
|
97
|
wolffd@0
|
98
|
wolffd@0
|
99
|
wolffd@0
|
100
|
wolffd@0
|
101
|
wolffd@0
|
102
|
wolffd@0
|
103
|
wolffd@0
|
104
|