comparison toolboxes/FullBNT-1.0.7/bnt/learning/kpm_learn_struct_mcmc.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:e9a9cd732c1e
function [sampled_graphs, accept_ratio, num_edges] = learn_struct_mcmc(data, ns, varargin)
% LEARN_STRUCT_MCMC Markov Chain Monte Carlo search over DAGs assuming fully observed data
% [sampled_graphs, accept_ratio, num_edges] = learn_struct_mcmc(data, ns, ...)
%
% data(i,m) is the value of node i in case m.
% ns(i) is the number of discrete values node i can take on.
%
% sampled_graphs{m} is the m'th graph recorded after burn-in (m = 1..nsamples).
% accept_ratio(t) = (1 + #accepted moves) / (1 + #rejected moves) after iteration t.
%   Note this is accepts relative to rejects, not accepts relative to all moves;
%   both counters start at 1, so the ratio is always finite.
% num_edges(t) = number of edges in model at iteration t
%
% The following optional arguments can be specified in the form of name/value pairs:
% [default value in brackets]
%
% scoring_fn - 'bayesian' or 'bic' [ 'bayesian' ]
%              Currently, only networks with all tabular nodes support Bayesian scoring.
% type       - type{i} is the type of CPD to use for node i, where the type is a string
%              of the form 'tabular', 'noisy_or', 'gaussian', etc. [ all cells contain 'tabular' ]
% params     - params{i} contains optional arguments passed to the CPD constructor for node i,
%              or [] if none.
%              [ all cells contain {'prior_type', 'dirichlet', 'dirichlet_weight', 1} ]
%              Passing [] for params replaces the default with a cell array of empties.
% discrete   - the list of discrete nodes [ 1:N ]
% clamped    - clamped(i,m) = 1 if node i is clamped in case m [ zeros(N, ncases) ]
% nsamples   - number of samples to draw from the chain after burn-in [ 100*N ]
% burnin     - number of steps to take before drawing samples [ 5*N ]
% init_dag   - starting point for the search [ zeros(N,N) ]
%
% e.g., samples = learn_struct_mcmc(data, ns, 'nsamples', 1000);
%
% Unrecognised argument names produce a warning and are otherwise ignored.
% An odd number of optional arguments (a name without a value) is an error.
%
% This interface is not backwards compatible with BNT2,
% but is designed to be compatible with the other learn_struct_xxx routines.
%
% Note: We currently assume a uniform structural prior.

[n ncases] = size(data);


% set default params
type = cell(1,n);
params = cell(1,n);
for i=1:n
  type{i} = 'tabular';
  %params{i} = { 'prior', 1 };
  params{i} = { 'prior_type', 'dirichlet', 'dirichlet_weight', 1 };
end
scoring_fn = 'bayesian';
discrete = 1:n;
clamped = zeros(n, ncases);
nsamples = 100*n;
burnin = 5*n;
dag = zeros(n);

% override the defaults with any name/value pairs supplied by the caller
args = varargin;
nargs = length(args);
if mod(nargs, 2) ~= 0
  error('learn_struct_mcmc: optional arguments must be given as name/value pairs');
end
for i=1:2:nargs
  switch args{i},
   case 'nsamples',   nsamples = args{i+1};
   case 'burnin',     burnin = args{i+1};
   case 'init_dag',   dag = args{i+1};
   case 'scoring_fn', scoring_fn = args{i+1};
   case 'type',       type = args{i+1};
   case 'discrete',   discrete = args{i+1};
   case 'clamped',    clamped = args{i+1};
   case 'params',     if isempty(args{i+1}), params = cell(1,n); else params = args{i+1}; end
   otherwise,
    % A misspelt option used to be dropped without any notice.
    if ischar(args{i})
      bad_name = args{i};
    else
      bad_name = '<non-string>';
    end
    warning(['learn_struct_mcmc: ignoring unrecognised argument name ' bad_name]);
  end
end

% We implement the fast acyclicity check described by P. Giudici and R. Castelo,
% "Improving MCMC model search for data mining", submitted to J. Machine Learning, 2001.
% take_step decides which scheme is in use by testing whether A is empty.
use_giudici = 1;
if use_giudici
  [nbrs, ops, nodes] = mk_nbrs_of_digraph(dag);
  A = init_ancestor_matrix(dag);
else
  [nbrs, ops, nodes] = mk_nbrs_of_dag(dag);
  A = [];
end

% Both counters start at 1 so that accept_ratio never divides by zero.
num_accepts = 1;
num_rejects = 1;
T = burnin + nsamples;
accept_ratio = zeros(1, T);
num_edges = zeros(1, T);
sampled_graphs = cell(1, nsamples);
%sampled_bitv = zeros(nsamples, n^2);

for t=1:T
  [dag, nbrs, ops, nodes, A, accept] = take_step(dag, nbrs, ops, nodes, ns, data, clamped, A, ...
						 scoring_fn, discrete, type, params);
  num_edges(t) = sum(dag(:));
  num_accepts = num_accepts + accept;
  num_rejects = num_rejects + (1-accept);
  accept_ratio(t) = num_accepts/num_rejects;
  if t > burnin
    % only graphs visited after the burn-in period are returned
    sampled_graphs{t-burnin} = dag;
    %sampled_bitv(t-burnin, :) = dag(:)';
  end
end
98
99
100 %%%%%%%%%
101
102
function [new_dag, new_nbrs, new_ops, new_nodes, A, accept] = ...
    take_step(dag, nbrs, ops, nodes, ns, data, clamped, A, ...
	      scoring_fn, discrete, type, params)
% TAKE_STEP Propose one single-edge change to dag and accept or reject it.
%
% dag               - current graph
% nbrs, ops, nodes  - neighbours of dag, the operation that produces each one,
%                     and the pair of nodes (nodes(d,1), nodes(d,2)) it acts on
% A                 - ancestor matrix, or [] to use the mk_nbrs_of_dag scheme
% remaining inputs  - passed through unchanged to bayes_factor
%
% On acceptance the outputs describe the proposed graph (and A is updated);
% on rejection they are the inputs, unchanged. accept is 1 or 0 accordingly.

use_giudici = ~isempty(A);
if use_giudici
  [new_dag, op, i, j] = pick_digraph_nbr(dag, nbrs, ops, nodes, A);
  %assert(acyclic(new_dag));
  [new_nbrs, new_ops, new_nodes] = mk_nbrs_of_digraph(new_dag);
else
  d = sample_discrete(normalise(ones(1, length(nbrs))));
  new_dag = nbrs{d};
  op = ops{d};
  i = nodes(d, 1); j = nodes(d, 2);
  [new_nbrs, new_ops, new_nodes] = mk_nbrs_of_dag(new_dag);
end

bf = bayes_factor(dag, new_dag, op, i, j, ns, data, clamped, scoring_fn, discrete, type, params);

% Acceptance ratio: score ratio times the ratio of neighbourhood sizes.
% The neighbour count must not be taken with length(): for the 3-D array
% used in the Giudici scheme that returns the largest dimension, not the
% number of neighbours.
%R = bf * (new_prior / prior) * (length(nbrs) / length(new_nbrs));
R = bf * (count_nbrs(nbrs) / count_nbrs(new_nbrs));
u = rand(1,1);
if u > min(1,R) % reject the move
  accept = 0;
  new_dag = dag;
  new_nbrs = nbrs;
  new_ops = ops;
  new_nodes = nodes;
else
  accept = 1;
  if use_giudici
    A = update_ancestor_matrix(A, op, i, j, new_dag);
  end
end


%%%%%%%%%

function K = count_nbrs(nbrs)
% COUNT_NBRS Number of neighbouring graphs held in nbrs.
% nbrs is either a cell array with one graph per cell, or a numeric array
% whose third dimension indexes the graphs.

if iscell(nbrs)
  K = length(nbrs);
else
  K = size(nbrs, 3);
end
138
139
140 %%%%%%%%%
141
function bfactor = bayes_factor(old_dag, new_dag, op, i, j, ns, data, clamped, scoring_fn, discrete, type, params)
% BAYES_FACTOR Ratio of the score of new_dag to the score of old_dag.
%
% Only the families that differ between the two graphs are scored: node j
% always, and node i as well when op is 'rev'. Cases in which the scored
% node is clamped are left out of the data passed to score_family.
% The value returned by score_family is treated as a log-score: differences
% are accumulated and exponentiated once at the end, so that a large positive
% change in one family and a large negative change in the other cannot
% produce Inf*0 = NaN.

u = find(clamped(j,:)==0);
LLnew = score_family(j, parents(new_dag, j), type{j}, scoring_fn, ns, discrete, data(:,u), params{j});
LLold = score_family(j, parents(old_dag, j), type{j}, scoring_fn, ns, discrete, data(:,u), params{j});
log_bf = LLnew - LLold;

if strcmp(op, 'rev')  % must also include the change to i's family
  u = find(clamped(i,:)==0);
  LLnew = score_family(i, parents(new_dag, i), type{i}, scoring_fn, ns, discrete, data(:,u), params{i});
  LLold = score_family(i, parents(old_dag, i), type{i}, scoring_fn, ns, discrete, data(:,u), params{i});
  log_bf = log_bf + (LLnew - LLold);
end
bfactor = exp(log_bf);
158
159
160 %%%%%%%% Giudici stuff follows %%%%%%%%%%
161
162
function [new_dag, op, i, j] = pick_digraph_nbr(dag, digraph_nbrs, ops, nodes, A)
% PICK_DIGRAPH_NBR Draw a neighbour of dag uniformly at random until an acyclic one is found.
%
% digraph_nbrs(:,:,d) is the d'th neighbouring digraph, ops{d} is the
% operation ('add', 'del' or 'rev') that produces it and nodes(d,:) = [i j]
% is the edge i->j it acts on. A is the ancestor matrix of dag:
% A(r,c) = 1 means c is an ancestor of r.
%
% Returns the chosen graph together with its operation and edge.
% NOTE(review): an entry of ops other than 'add'/'del'/'rev' is never legal,
% so the loop does not terminate if no legal neighbour exists.

% Neighbours are stacked along the third dimension; length() would return
% the largest dimension instead of the number of neighbours.
num_nbrs = size(digraph_nbrs, 3);
proposal = normalise(ones(1, num_nbrs)); % uniform; same on every draw

legal = 0;
while ~legal
  d = sample_discrete(proposal);
  i = nodes(d, 1); j = nodes(d, 2);
  switch ops{d}
   case 'add',
    % adding i->j is safe unless j is already an ancestor of i
    if A(i,j)==0
      legal = 1;
    end
   case 'del',
    legal = 1;
   case 'rev',
    ps = mysetdiff(parents(dag, j), i);
    % if any(A(ps,i)) then there is a path i -> parent of j -> j
    % so reversing i->j would create a cycle
    legal = ~any(A(ps, i));
  end
end
%new_dag = digraph_nbrs{d};
new_dag = digraph_nbrs(:,:,d);
op = ops{d};
% i and j already hold nodes(d,:) from the last pass through the loop
187
188
189 %%%%%%%%%%%%%%
190
191
function A = update_ancestor_matrix(A, op, i, j, dag)
% UPDATE_ANCESTOR_MATRIX Bring A up to date after applying op to the edge i->j.
% dag is the graph after the change. A reversal is handled as a removal
% followed by the addition of the opposite edge j->i. Any other value of op
% leaves A untouched.

switch op
 case {'del', 'rev'},
  A = do_removal(A, op, i, j, dag);
  if strcmp(op, 'rev')
    A = do_addition(A, op, j, i, dag);
  end
 case 'add',
  A = do_addition(A, op, i, j, dag);
end
203
204
205 %%%%%%%%%%%%
206
function A = do_addition(A, op, i, j, dag)
% DO_ADDITION Update the ancestor matrix for a newly added edge i->j.
% A(r,c) = 1 means c is an ancestor of r. op and dag are not used.

% j gains i, and everything above i, as ancestors
A(j, i) = 1;
A(j, A(i,:) ~= 0) = 1;

% every descendant of j inherits the full ancestor set of j
above_j = find(A(j,:));
below_j = find(A(:,j));
A(below_j, above_j) = 1;
221
222 %%%%%%%%%%%
223
function A = do_removal(A, op, i, j, dag)
% DO_REMOVAL Update the ancestor matrix after the edge i->j has been removed.
% dag is the graph after the removal. Row j of A and the rows of all nodes
% reachable from j are rebuilt from their parents' rows. op and i are not used.

% find all the descendants of j, and put them in topological order
%descj = find(A(:,j));
R = reachability_graph(dag); % assumes R(j,k) ~= 0 iff k is reachable from j - TODO confirm
descj = find(R(j,:));
order = topological_sort(dag);
% order lists node ids, parents before children, so order(descj) is NOT the
% topological position of the descendants. Filtering order by membership
% keeps exactly the descendants, already in topological order; each row is
% then rebuilt only after the rows of its parents.
descj = order(ismember(order, descj));
% Update j and all its descendants
A = update_row(A, j, dag);
for k = descj(:)'
  A = update_row(A, k, dag);
end
239
240 %%%%%%%%%
241
function A = update_row(A, j, dag)
% UPDATE_ROW Rebuild row j of the ancestor matrix from the parents of j.
% The ancestors of j are its parents in dag together with everything
% recorded in A as an ancestor of one of those parents.

A(j, :) = 0;
pa = parents(dag, j);
A(j, pa) = 1;            % no-op when j has no parents
for p = pa(:)'
  A(j, A(p,:) ~= 0) = 1; % inherit the ancestors of parent p
end
256
257 %%%%%%%%
258
function A = init_ancestor_matrix(dag)
% INIT_ANCESTOR_MATRIX Build the ancestor matrix of dag from scratch.
% Rows are filled in topological order, so every parent's row is complete
% before it is read by update_row for a child.

A = zeros(length(dag));
topo = topological_sort(dag);
for idx = 1:numel(topo)
  A = update_row(A, topo(idx), dag);
end