comparison toolboxes/FullBNT-1.0.7/bnt/inference/static/@gibbs_sampling_inf_engine/marginal_nodes.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:e9a9cd732c1e
function [marginal, engine] = marginal_nodes(engine, nodes, varargin)
% MARGINAL_NODES Compute the marginal on the specified query nodes
% (gibbs_sampling_engine)
% [marginal, engine] = marginal_nodes(engine, nodes, ...)
%
% returns Pr(X(nodes) | X(observedNodes)) as an estimate built from
% Gibbs samples; the normalised sample counts are stored in marginal.T.
%
% The engine is also modified (its current state and the unnormalised
% counts are stored in it), and so it is returned as well, since
% Matlab doesn't support passing by reference(!) So
% if you want to, for example, incrementally run gibbs for a few 100
% steps at a time, you should use the returned value.
%
% Optional arguments (name/value pairs):
%
% 'reset_counts' is 1 if you want to reset the counts made in the
% past, and 0 otherwise (if the current query nodes are different
% from the previous query nodes, or if marginal_nodes has not been
% called before, reset_counts should be set to 1).
% By default it is 1.
%
% Any other option name raises an error.
%
% Each call runs engine.T*engine.gap + engine.burnin single-node
% updates. After the first engine.burnin updates, the joint value of
% the query nodes is recorded once every engine.gap updates.

reset_counts = 1;

% Parse the optional name/value pairs.
if (nargin > 3)
    args = varargin;
    nargs = length(args);
    for i = 1:2:nargs
        switch args{i}
            case 'reset_counts'
                reset_counts = args{i+1};
            otherwise
                error(['Incorrect argument to gibbs_sampling_engine/' ...
                       ' marginal_nodes']);
        end
    end
end

% initialization stuff
bnet = bnet_from_engine(engine);
onodes = engine.onodes;
gap = engine.gap;
burnin = engine.burnin;
T_max = engine.T;
ns = bnet.node_sizes(nodes);

% Cache the strides for the marginal table: they map a joint assignment
% of the query nodes to a linear index into marginal_counts.
marg_strides = [1 cumprod(ns(1:end-1))];

if (reset_counts == 1)
    % Start from a fresh sample of the network, clamped to the evidence,
    % and from an all-zero count table.
    state = cell2num(sample_bnet(bnet));
    state(onodes) = engine.evidence(onodes);
    if (length(ns) == 1)
        % zeros(n) would build an n-by-n matrix, so force a column here.
        marginal_counts = zeros(ns(1),1);
    else
        marginal_counts = zeros(ns);
    end
else
    % Otherwise, continue from the state and counts stored in the engine.
    state = engine.state;
    state(onodes, :) = engine.evidence(onodes, :);
    marginal_counts = engine.marginal_counts;
end

% Node selection scheme: either cycle through engine.order, or draw the
% node to update at random from engine.sampling_dist.
deterministic = (engine.deterministic == 1);
if deterministic
    pos = 1;
    order = engine.order;
    orderSize = length(order);
else
    sampling_dist = normalise(engine.sampling_dist);
end

for t = 1:(T_max*gap+burnin)

    % First, select node m to sample
    if deterministic
        m = order(pos);
        pos = pos+1;
        if (pos > orderSize)
            pos = 1;
        end
    else
        m = my_sample_discrete(sampling_dist);
    end

    % Observed nodes keep their evidence value, so only hidden nodes are
    % resampled. Note that we must NOT skip the rest of the iteration
    % here: the current state is still a valid sample and has to be
    % counted below, otherwise samples are silently dropped whenever a
    % recording step happens to select an observed node.
    if (~myismember(m, onodes))
        % Compute the posterior of node m given the rest of the state
        post = compute_posterior(bnet, state, m, engine.strides, ...
                                 engine.families, engine.children, ...
                                 engine.CPT);
        state(m) = my_sample_discrete(post);
    end

    % Now update our monte carlo estimate of the posterior
    % distribution on the query nodes
    if ((t > burnin) && (mod(t-burnin, gap) == 0))
        vals = state(nodes);
        index = 1+marg_strides*(vals-1);
        marginal_counts(index) = marginal_counts(index)+1;
    end
end

% Store results for future computation. Note that we store
% unnormalized counts
engine.state = state;
engine.marginal_counts = marginal_counts;

marginal.T = normalise(marginal_counts);