function engine = gibbs_sampling_inf_engine(bnet, varargin)
% GIBBS_SAMPLING_INF_ENGINE Construct a Gibbs sampling inference engine.
%
% engine = gibbs_sampling_inf_engine(bnet, ...)
%
% Optional parameters, given as name/value pairs [default in brackets]:
% 'burnin' - how long before you start using the samples [100].
% 'gap'    - how often you use the samples in the estimate [1].
% 'T'      - number of samples [1000], i.e., number of node flips. For
%            example, if there are 10 nodes in the bnet and T is 1000, each
%            node will get flipped 100 times (assuming a deterministic
%            schedule). The total running time is proportional to
%            burnin + T*gap.
%
% 'order'  - if the sampling schedule is deterministic, use this parameter
%            to specify the order in which nodes are sampled. Order is
%            allowed to include multiple copies of nodes, which is useful if
%            you want to, say, focus sampling on particular nodes.
%            Default is a deterministic schedule that visits nodes
%            1:size(bnet.dag, 2) in order.
%
% 'sampling_dist' - selects a stochastic schedule: at each step the node to
%            sample is chosen according to this distribution (may be
%            unnormalized).
%
% 'order' and 'sampling_dist' are mutually exclusive. Supplying both, or
% supplying either of them more than once, raises an error.
%
%
% Written by "Bhaskara Marthi" Feb 02.


engine.burnin = 100;
engine.gap = 1;
engine.T = 1000;
use_default_order = 1;
engine.deterministic = 1;
engine.order = {};
engine.sampling_dist = {};

args = varargin;
nargs = length(args);

% Options must come in name/value pairs; an odd count would otherwise fail
% with an opaque index-out-of-bounds error when reading args{i+1}.
if mod(nargs, 2) ~= 0
  error('gibbs_sampling_inf_engine: optional arguments must be given as name/value pairs');
end

for i = 1:2:nargs
  switch args{i}
    case 'burnin'
      engine.burnin = args{i+1};
    case 'gap'
      engine.gap = args{i+1};
    case 'T'
      engine.T = args{i+1};
    case 'order'
      % Only one schedule specification ('order' or 'sampling_dist') is allowed.
      if ~use_default_order
        error('gibbs_sampling_inf_engine: ''order'' and ''sampling_dist'' may be given at most once and not together');
      end
      use_default_order = 0;
      engine.order = args{i+1};
    case 'sampling_dist'
      % Only one schedule specification ('order' or 'sampling_dist') is allowed.
      if ~use_default_order
        error('gibbs_sampling_inf_engine: ''order'' and ''sampling_dist'' may be given at most once and not together');
      end
      use_default_order = 0;
      engine.deterministic = 0;
      engine.sampling_dist = args{i+1};
    otherwise
      error(['unrecognized parameter to gibbs_sampling_inf_engine']);
  end
end

engine.slice_size = size(bnet.dag, 2);
if (use_default_order)
  engine.order = 1:engine.slice_size;
end
engine.hnodes = [];
engine.onodes = [];
engine.evidence = [];
engine.state = [];
engine.marginal_counts = {};

% Precompute the strides for each CPT
engine.strides = compute_strides(bnet);

% Precompute graphical information
engine.families = compute_families(bnet);
engine.children = compute_children(bnet);

% For convenience, store the CPTs as tables rather than objects
engine.CPT = get_cpts(bnet);

engine = class(engine, 'gibbs_sampling_inf_engine', inf_engine(bnet));