Mercurial > hg > camir-aes2014
comparison toolboxes/FullBNT-1.0.7/bnt/learning/learn_struct_mcmc.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:e9a9cd732c1e |
---|---|
function [sampled_graphs, accept_ratio, num_edges] = learn_struct_mcmc(data, ns, varargin)
% LEARN_STRUCT_MCMC Monte Carlo Markov Chain search over DAGs assuming fully observed data
% [sampled_graphs, accept_ratio, num_edges] = learn_struct_mcmc(data, ns, ...)
%
% data(i,m) is the value of node i in case m.
% ns(i) is the number of discrete values node i can take on.
%
% sampled_graphs{m} is the m'th sampled graph (collected after burn-in).
% accept_ratio(t) = (1 + number of accepted moves) / (1 + number of rejected moves)
%                   over steps 1..t (burn-in included). Both counters start at 1,
%                   so this is an accept/reject ratio, not a fraction of accepted moves.
% num_edges(t) = number of edges in model at iteration t (burn-in included)
%
% The following optional arguments can be specified in the form of name/value pairs:
% [default value in brackets]
%
% scoring_fn - 'bayesian' or 'bic' [ 'bayesian' ]
%              Currently, only networks with all tabular nodes support Bayesian scoring.
% type       - type{i} is the type of CPD to use for node i, where the type is a string
%              of the form 'tabular', 'noisy_or', 'gaussian', etc. [ all cells contain 'tabular' ]
% params     - params{i} contains optional arguments passed to the CPD constructor for node i,
%              or [] if none.
%              [ all cells contain {'prior_type', 'dirichlet', 'dirichlet_weight', 1} ]
%              Passing [] for params gives every node an empty argument list.
% discrete   - the list of discrete nodes [ 1:N ]
% clamped    - clamped(i,m) = 1 if node i is clamped in case m [ zeros(N, ncases) ]
%              Cases in which a node is clamped are left out when that node's family is scored.
% nsamples   - number of samples to draw from the chain after burn-in [ 100*N ]
% burnin     - number of steps to take before drawing samples [ 5*N ]
% init_dag   - starting point for the search [ zeros(N,N) ]
%
% Unrecognised option names are ignored with a warning; a name without a
% matching value raises an error.
%
% e.g., samples = learn_struct_mcmc(data, ns, 'nsamples', 1000);
%
% Modified by Sonia Leach (SML) 2/4/02, 9/5/03

[n, ncases] = size(data);

% set default params
type = cell(1,n);
params = cell(1,n);
for i=1:n
  type{i} = 'tabular';
  params{i} = { 'prior_type', 'dirichlet', 'dirichlet_weight', 1 };
end
scoring_fn = 'bayesian';
discrete = 1:n;
clamped = zeros(n, ncases);
nsamples = 100*n;
burnin = 5*n;
dag = zeros(n);

% parse the name/value pairs
args = varargin;
nargs = length(args);
if mod(nargs, 2) ~= 0
  error('learn_struct_mcmc:badArgs', ...
        'Optional arguments must be given as name/value pairs (got %d extra arguments).', nargs);
end
for i=1:2:nargs
  switch args{i},
   case 'nsamples',   nsamples = args{i+1};
   case 'burnin',     burnin = args{i+1};
   case 'init_dag',   dag = args{i+1};
   case 'scoring_fn', scoring_fn = args{i+1};
   case 'type',       type = args{i+1};
   case 'discrete',   discrete = args{i+1};
   case 'clamped',    clamped = args{i+1};
   case 'params',     if isempty(args{i+1}), params = cell(1,n); else params = args{i+1}; end
   otherwise,
    % Previously a misspelt option was dropped silently; keep ignoring it
    % (so existing callers still run) but tell the user.
    if ischar(args{i})
      warning('learn_struct_mcmc:unknownOption', ...
              'Ignoring unrecognised option ''%s''.', args{i});
    else
      warning('learn_struct_mcmc:unknownOption', ...
              'Ignoring non-string option name at optional argument %d.', i);
    end
  end
end

% We implement the fast acyclicity check described by P. Giudici and R. Castelo,
% "Improving MCMC model search for data mining", submitted to J. Machine Learning, 2001.

% SML: also keep descendant matrix C
use_giudici = 1;
if use_giudici
  [nbrs, ops, nodes, A] = mk_nbrs_of_digraph(dag);
else
  [nbrs, ops, nodes] = mk_nbrs_of_dag(dag);
  A = [];   % an empty A tells take_step to use the non-Giudici neighbour search
end

% Both counters start at 1 so the ratio below never divides by zero.
num_accepts = 1;
num_rejects = 1;
T = burnin + nsamples;
accept_ratio = zeros(1, T);
num_edges = zeros(1, T);
sampled_graphs = cell(1, nsamples);

for t=1:T
  [dag, nbrs, ops, nodes, A, accept] = take_step(dag, nbrs, ops, ...
						 nodes, ns, data, clamped, A, ...
						 scoring_fn, discrete, type, params);
  num_edges(t) = sum(dag(:));
  num_accepts = num_accepts + accept;
  num_rejects = num_rejects + (1-accept);
  accept_ratio(t) = num_accepts/num_rejects;
  if t > burnin
    % only states visited after burn-in are kept as samples
    sampled_graphs{t-burnin} = dag;
  end
end
99 | |
100 | |
101 %%%%%%%%% | |
102 | |
103 | |
function [new_dag, new_nbrs, new_ops, new_nodes, A, accept] = ...
    take_step(dag, nbrs, ops, nodes, ns, data, clamped, A, ...
	      scoring_fn, discrete, type, params, prior_w)
% TAKE_STEP Propose one neighbouring graph and accept or reject it.
%
% A candidate is drawn uniformly from the neighbours of dag. It is accepted
% when a uniform random number does not exceed min(1,R), where R is the Bayes
% factor times the ratio of neighbourhood sizes (current / candidate).
% On rejection the inputs dag, nbrs, ops, nodes and A are returned unchanged.
% accept is 1 if the candidate was taken and 0 otherwise.
% A non-empty A selects the ancestor-matrix (Giudici) code path.
% prior_w is not used.

track_ancestors = ~isempty(A);

% Draw the candidate and enumerate the candidate's own neighbours.
if track_ancestors
  [cand_dag, op, i, j, cand_A] = pick_digraph_nbr(dag, nbrs, ops, nodes, A);
  [cand_nbrs, cand_ops, cand_nodes] = mk_nbrs_of_digraph(cand_dag, cand_A);
else
  pick = sample_discrete(normalise(ones(1, length(nbrs))));
  cand_dag = nbrs{pick};
  op = ops{pick};
  i = nodes(pick, 1);
  j = nodes(pick, 2);
  [cand_nbrs, cand_ops, cand_nodes] = mk_nbrs_of_dag(cand_dag);
end

bf = bayes_factor(dag, cand_dag, op, i, j, ns, data, clamped, scoring_fn, discrete, type, params);

% No structure-prior term is included in the ratio.
R = bf * (length(nbrs) / length(cand_nbrs));

if rand(1,1) > min(1,R)
  % Rejected: stay where we are (A is left as passed in).
  accept = 0;
  new_dag = dag;
  new_nbrs = nbrs;
  new_ops = ops;
  new_nodes = nodes;
  return;
end

% Accepted: move to the candidate.
accept = 1;
new_dag = cand_dag;
new_nbrs = cand_nbrs;
new_ops = cand_ops;
new_nodes = cand_nodes;
if track_ancestors
  A = cand_A;   % already updated for the candidate by pick_digraph_nbr
end
138 | |
139 | |
140 %%%%%%%%% | |
141 | |
function bfactor = bayes_factor(old_dag, new_dag, op, i, j, ns, data, clamped, scoring_fn, discrete, type, params)
% BAYES_FACTOR Ratio of the score of new_dag to the score of old_dag.
%
% Only the families that differ between the two graphs are scored: node j
% always, and node i as well when op is 'rev'.
% For each scored node, the cases in which that node is clamped are left out.
%
% The log-score differences are added up and exponentiated once. Forming
% exp(a)*exp(b) separately can give Inf*0 = NaN when a and b are large with
% opposite signs, even though exp(a+b) is perfectly representable.

if strcmp(op, 'rev')  % must also account for the changes to i's family
  changed = [j i];
else
  changed = j;
end

log_bf = 0;
for node = changed
  u = find(clamped(node,:)==0);   % cases in which this node is not clamped
  LLnew = score_family(node, parents(new_dag, node), type{node}, scoring_fn, ns, discrete, data(:,u), params{node});
  LLold = score_family(node, parents(old_dag, node), type{node}, scoring_fn, ns, discrete, data(:,u), params{node});
  log_bf = log_bf + (LLnew - LLold);
end
bfactor = exp(log_bf);
158 | |
159 | |
160 %%%%%%%% Giudici stuff follows %%%%%%%%%% | |
161 | |
162 | |
function [new_dag, op, i, j, new_A] = pick_digraph_nbr(dag, digraph_nbrs, ops, nodes, A)
% PICK_DIGRAPH_NBR Draw one neighbouring graph uniformly at random.
%
% Returns the chosen graph, the operation that produces it (op) together with
% the two nodes involved (i, j), and the ancestor matrix updated to match the
% chosen graph (SML: A is updated here, as part of the pick).

num_choices = length(digraph_nbrs);
uniform = normalise(ones(1, num_choices));
choice = sample_discrete(uniform);

i = nodes(choice, 1);
j = nodes(choice, 2);
op = ops{choice};
new_dag = digraph_nbrs(:,:,choice);
new_A = update_ancestor_matrix(A, op, i, j, new_dag);
172 | |
173 | |
174 %%%%%%%%%%%%%% | |
175 | |
176 | |
function A = update_ancestor_matrix(A, op, i, j, dag)
% UPDATE_ANCESTOR_MATRIX Bring A in line with dag after applying op to nodes (i,j).
%
% 'add' and 'del' map to a single addition or removal; 'rev' is handled as a
% removal followed by an addition with i and j swapped. Any other op leaves A
% unchanged.

if strcmp(op, 'add')
  A = do_addition(A, op, i, j, dag);
elseif strcmp(op, 'del')
  A = do_removal(A, op, i, j, dag);
elseif strcmp(op, 'rev')
  A = do_removal(A, op, i, j, dag);
  A = do_addition(A, op, j, i, dag);
end
188 | |
189 | |
190 %%%%%%%%%%%% | |
191 | |
function A = do_addition(A, op, i, j, dag)
% DO_ADDITION Update the ancestor matrix for a new arc from i to j.
%
% A(r,c) is nonzero when c is an ancestor of r. op and dag are not used.

% i itself, and everything above i, becomes an ancestor of j
A(j,i) = 1;
A(j, A(i,:) ~= 0) = 1;

% everything now above j is also above each descendant of j
above_j = (A(j,:) ~= 0);
below_j = (A(:,j) ~= 0);
A(below_j, above_j) = 1;
206 | |
207 %%%%%%%%%%% | |
function A = do_removal(A, op, i, j, dag)
% DO_REMOVAL Recompute the ancestor rows of j and of all its descendants.
%
% The descendants of j are read off column j of A (SML: cheaper than
% rebuilding the reachability graph of dag). They are then visited in the
% order given by topological_sort(dag), after j itself, and each row is
% rebuilt by update_row. op and i are not used.

below_j = find(A(:,j));

% Walking the topological order and keeping only j's descendants yields them
% already sorted by topological position.
topo = topological_sort(dag);
below_j_sorted = topo(ismember(topo, below_j));

% Update j and all its descendants
A = update_row(A, j, dag);
for node = below_j_sorted(:)'
  A = update_row(A, node, dag);
end
236 | |
237 %%%%%%%%%%% | |
238 | |
function A = old_do_removal(A, op, i, j, dag)
% OLD_DO_REMOVAL Earlier version of do_removal.
%
% Takes the descendants of j from column j of A, sorts them by the value of
% order(descendant), where order = topological_sort(dag), and then rebuilds
% the row of j followed by the row of each descendant. op and i are not used.

desc = find(A(:,j));

topo = topological_sort(dag);
[unused, by_key] = sort(topo(desc));
desc = desc(by_key);

% Update j and all its descendants
A = update_row(A, j, dag);
for node = desc(:)'
  A = update_row(A, node, dag);
end
258 | |
259 %%%%%%%%% | |
260 | |
function A = update_row(A, j, dag)
% UPDATE_ROW Rebuild row j of the ancestor matrix from the rows of j's parents.
%
% Row j is cleared, then set to 1 for every parent of j in dag and for every
% node marked in a parent's row of A.

A(j, :) = 0;
pa = parents(dag, j);
if isempty(pa)
  return;   % no parents, so no ancestors
end

A(j, pa) = 1;
for p = pa(:)'
  A(j, A(p,:) ~= 0) = 1;   % inherit the ancestors recorded for parent p
end
275 | |
276 %%%%%%%% | |
277 | |
function A = init_ancestor_matrix(dag)
% INIT_ANCESTOR_MATRIX Build the ancestor matrix of dag from scratch.
%
% Rows are filled in by update_row, visiting the nodes in the order returned
% by topological_sort(dag).

visit_order = topological_sort(dag);
num_nodes = length(dag);
A = zeros(num_nodes);
for node = visit_order(:)'
  A = update_row(A, node, dag);
end