wolffd@0
|
function [sampled_graphs, accept_ratio, num_edges] = learn_struct_mcmc(data, ns, varargin)
% LEARN_STRUCT_MCMC Monte Carlo Markov Chain search over DAGs assuming fully observed data
% [sampled_graphs, accept_ratio, num_edges] = learn_struct_mcmc(data, ns, ...)
%
% data(i,m) is the value of node i in case m.
% ns(i) is the number of discrete values node i can take on.
%
% sampled_graphs{m} is the m'th graph sampled after burn-in (1 x nsamples cell array).
% accept_ratio(t) = (1 + #accepted moves) / (1 + #rejected moves) over iterations 1..t.
%    Both counters start at 1, so this is a smoothed accept/reject odds ratio,
%    not the fraction of accepted moves.
% num_edges(t) = number of edges in model at iteration t (sum of the dag entries)
%
% The following optional arguments can be specified in the form of name/value pairs:
% [default value in brackets]
%
% scoring_fn - 'bayesian' or 'bic' [ 'bayesian' ]
%              Currently, only networks with all tabular nodes support Bayesian scoring.
% type       - type{i} is the type of CPD to use for node i, where the type is a string
%              of the form 'tabular', 'noisy_or', 'gaussian', etc. [ all cells contain 'tabular' ]
% params     - params{i} contains optional arguments passed to the CPD constructor for node i,
%              or [] if none.
%              [ all cells contain {'prior_type', 'dirichlet', 'dirichlet_weight', 1} ]
%              Passing an empty value for 'params' resets every cell to [].
% discrete   - the list of discrete nodes [ 1:N ]
% clamped    - clamped(i,m) = 1 if node i is clamped in case m  [ zeros(N, ncases) ]
% nsamples   - number of samples to draw from the chain after burn-in [ 100*N ]
% burnin     - number of steps to take before drawing samples [ 5*N ]
% init_dag   - starting point for the search [ zeros(N,N) ]
%
% Option names that are not listed above are ignored with a warning.
%
% e.g., samples = learn_struct_mcmc(data, ns, 'nsamples', 1000);
%
% Modified by Sonia Leach (SML) 2/4/02, 9/5/03

[n, ncases] = size(data);

% set default params
type = cell(1,n);
params = cell(1,n);
for i=1:n
  type{i} = 'tabular';
  %params{i} = { 'prior', 1};
  params{i} = { 'prior_type', 'dirichlet', 'dirichlet_weight', 1 };
end
scoring_fn = 'bayesian';
discrete = 1:n;
clamped = zeros(n, ncases);
nsamples = 100*n;
burnin = 5*n;
dag = zeros(n);

% override defaults with the caller's name/value pairs
args = varargin;
nargs = length(args);
for i=1:2:nargs
  switch args{i},
   case 'nsamples',   nsamples = args{i+1};
   case 'burnin',     burnin = args{i+1};
   case 'init_dag',   dag = args{i+1};
   case 'scoring_fn', scoring_fn = args{i+1};
   case 'type',       type = args{i+1};
   case 'discrete',   discrete = args{i+1};
   case 'clamped',    clamped = args{i+1};
   case 'params',     if isempty(args{i+1}), params = cell(1,n); else params = args{i+1}; end
   otherwise,
    % A mistyped option name would otherwise silently fall back to the default.
    warning('learn_struct_mcmc:unknownOption', ...
            'Ignoring unrecognised option at optional argument position %d.', i);
  end
end

% We implement the fast acyclicity check described by P. Giudici and R. Castelo,
% "Improving MCMC model search for data mining", submitted to J. Machine Learning, 2001.

% SML: also keep descendant matrix C
use_giudici = 1;
if use_giudici
  [nbrs, ops, nodes, A] = mk_nbrs_of_digraph(dag);
else
  [nbrs, ops, nodes] = mk_nbrs_of_dag(dag);
  A = [];   % an empty A tells take_step not to use the ancestor-matrix code path
end

% Both counters start at 1 so that the ratio below is always defined.
num_accepts = 1;
num_rejects = 1;
T = burnin + nsamples;
accept_ratio = zeros(1, T);
num_edges = zeros(1, T);
sampled_graphs = cell(1, nsamples);
%sampled_bitv = zeros(nsamples, n^2);

for t=1:T
  [dag, nbrs, ops, nodes, A, accept] = take_step(dag, nbrs, ops, ...
						 nodes, ns, data, clamped, A, ...
						 scoring_fn, discrete, type, params);
  num_edges(t) = sum(dag(:));
  num_accepts = num_accepts + accept;
  num_rejects = num_rejects + (1-accept);
  accept_ratio(t) =  num_accepts/num_rejects;
  if t > burnin
    % only graphs visited after the burn-in period are kept
    sampled_graphs{t-burnin} = dag;
    %sampled_bitv(t-burnin, :) = dag(:)';
  end
end
|
wolffd@0
|
99
|
wolffd@0
|
100
|
wolffd@0
|
101 %%%%%%%%%
|
wolffd@0
|
102
|
wolffd@0
|
103
|
wolffd@0
|
function [new_dag, new_nbrs, new_ops, new_nodes, A, accept] = ...
    take_step(dag, nbrs, ops, nodes, ns, data, clamped, A, ...
	      scoring_fn, discrete, type, params, prior_w)
% TAKE_STEP Propose one neighbouring graph and accept or reject it.
%
% A candidate is drawn uniformly from the neighbours of dag.  It is accepted
% when a uniform random number u satisfies u <= min(1, R), where
%   R = bayes_factor * (#neighbours of dag / #neighbours of candidate).
%
% Inputs
%   dag, nbrs, ops, nodes - current graph and its neighbour list, as produced by
%                           mk_nbrs_of_digraph (A non-empty) or mk_nbrs_of_dag (A empty)
%   A                     - ancestor matrix, or [] to disable the ancestor-matrix path
%   ns, data, clamped, scoring_fn, discrete, type, params - forwarded to bayes_factor
%   prior_w               - not used in this function
%
% Outputs
%   new_dag, new_nbrs, new_ops, new_nodes - state after the step; equal to the
%                           inputs when the move is rejected
%   A                     - updated ancestor matrix on acceptance, otherwise unchanged
%   accept                - 1 if the move was accepted, 0 otherwise

if isempty(A)
  % No ancestor matrix: neighbours are stored in a cell array.
  pick = sample_discrete(normalise(ones(1, length(nbrs))));
  cand_dag = nbrs{pick};
  op = ops{pick};
  i = nodes(pick, 1);
  j = nodes(pick, 2);
  [cand_nbrs, cand_ops, cand_nodes] = mk_nbrs_of_dag(cand_dag);
  cand_A = A;
else
  % pick_digraph_nbr also returns the ancestor matrix of the candidate
  [cand_dag, op, i, j, cand_A] = pick_digraph_nbr(dag, nbrs, ops, nodes, A);
  [cand_nbrs, cand_ops, cand_nodes] = mk_nbrs_of_digraph(cand_dag, cand_A);
end

bf = bayes_factor(dag, cand_dag, op, i, j, ns, data, clamped, scoring_fn, discrete, type, params);

%R = bf * (new_prior / prior) * (length(nbrs) / length(new_nbrs));
R = bf * (length(nbrs) / length(cand_nbrs));
u = rand(1,1);

if u > min(1,R)
  % rejected: stay where we are (A is returned untouched)
  accept = 0;
  new_dag = dag;
  new_nbrs = nbrs;
  new_ops = ops;
  new_nodes = nodes;
else
  % accepted: commit the candidate state
  accept = 1;
  new_dag = cand_dag;
  new_nbrs = cand_nbrs;
  new_ops = cand_ops;
  new_nodes = cand_nodes;
  A = cand_A;
end
|
wolffd@0
|
138
|
wolffd@0
|
139
|
wolffd@0
|
140 %%%%%%%%%
|
wolffd@0
|
141
|
wolffd@0
|
function bfactor = bayes_factor(old_dag, new_dag, op, i, j, ns, data, clamped, scoring_fn, discrete, type, params)
% BAYES_FACTOR Score ratio exp(score(new_dag) - score(old_dag)) for a single-edge move.
%
% Only the families whose parent sets differ between the two graphs are
% rescored: node j for every operation, and additionally node i when
% op is 'rev'.
% For each rescored node v, only the cases in which v is not clamped
% (clamped(v,m) == 0) are passed to score_family.
%
% The per-family score differences are summed BEFORE exponentiating.
% Exponentiating each difference separately and multiplying can produce
% Inf * 0 = NaN when the two differences are large and of opposite sign.

if strcmp(op, 'rev')
  changed = [j i];   % a reversal changes the families of both end points
else
  changed = j;
end

log_bf = 0;
for v = changed
  u = find(clamped(v,:)==0);
  LLnew = score_family(v, parents(new_dag, v), type{v}, scoring_fn, ns, discrete, data(:,u), params{v});
  LLold = score_family(v, parents(old_dag, v), type{v}, scoring_fn, ns, discrete, data(:,u), params{v});
  log_bf = log_bf + (LLnew - LLold);
end

bfactor = exp(log_bf);
|
wolffd@0
|
158
|
wolffd@0
|
159
|
wolffd@0
|
160 %%%%%%%% Giudici stuff follows %%%%%%%%%%
|
wolffd@0
|
161
|
wolffd@0
|
162
|
wolffd@0
|
% SML: This now updates A as it goes, based on the digraph it chooses
function [new_dag, op, i, j, new_A] = pick_digraph_nbr(dag, digraph_nbrs, ops, nodes, A)
% PICK_DIGRAPH_NBR Draw one neighbouring graph uniformly at random.
%
% digraph_nbrs(:,:,d) is the d'th neighbour, ops{d} the operation that
% produces it, and nodes(d,:) = [i j] the pair of nodes it acts on.
% new_A is the ancestor matrix A updated for the chosen move.
% The dag argument itself is not used here.

num_choices = length(digraph_nbrs);
uniform_probs = normalise(ones(1, num_choices));
d = sample_discrete(uniform_probs);
%d = myunidrnd(length(digraph_nbrs),1,1);

op = ops{d};
new_dag = digraph_nbrs(:,:,d);
i = nodes(d, 1);
j = nodes(d, 2);
new_A = update_ancestor_matrix(A, op, i, j, new_dag);
|
wolffd@0
|
172
|
wolffd@0
|
173
|
wolffd@0
|
174 %%%%%%%%%%%%%%
|
wolffd@0
|
175
|
wolffd@0
|
176
|
wolffd@0
|
function A = update_ancestor_matrix(A, op, i, j, dag)
% UPDATE_ANCESTOR_MATRIX Bring the ancestor matrix A in line with a single-edge move.
%
% op is 'add', 'del' or 'rev', applied to the node pair (i,j); dag is the
% graph after the move.  A reversal is handled as a removal for (i,j)
% followed by an addition for (j,i).  Any other op leaves A unchanged.

switch op
 case {'del', 'rev'}
  A = do_removal(A, op, i, j, dag);
  if strcmp(op, 'rev')
    A = do_addition(A, op, j, i, dag);
  end
 case 'add'
  A = do_addition(A, op, i, j, dag);
end
|
wolffd@0
|
188
|
wolffd@0
|
189
|
wolffd@0
|
190 %%%%%%%%%%%%
|
wolffd@0
|
191
|
wolffd@0
|
function A = do_addition(A, op, i, j, dag)
% DO_ADDITION Update the ancestor matrix A for a new edge between i and j.
%
% A(r,c) = 1 records that c is an ancestor of r.  After the call, i and all
% ancestors of i are ancestors of j, and every ancestor of j is also an
% ancestor of each descendant of j.  op and dag are not used.

A(j, i) = 1;                     % i is an ancestor of j
A(j, A(i,:) ~= 0) = 1;           % so is every ancestor of i

anc_of_j = (A(j,:) ~= 0);        % complete ancestor set of j
desc_of_j = find(A(:,j));        % rows that list j as an ancestor
if ~isempty(desc_of_j)
  % each descendant of j inherits all of j's ancestors
  A(desc_of_j, anc_of_j) = 1;
end
|
wolffd@0
|
206
|
wolffd@0
|
207 %%%%%%%%%%%
|
wolffd@0
|
function A = do_removal(A, op, i, j, dag)
% DO_REMOVAL Recompute the ancestor rows of j and of its descendants.
%
% The descendants of j are read from the incoming A (rows with a nonzero in
% column j), ordered by their position in topological_sort(dag), and each
% row is then rebuilt with update_row, starting with j itself.
% op and i are not used.

% SML: originally Kevin computed the descendants from reachability_graph(dag);
% reading them from A is believed to be equivalent and much less expensive.
affected = find(A(:,j));
%* R = reachability_graph(dag);
%* descj = find(R(j,:));

topo = topological_sort(dag);

% SML: indexing topo directly with the descendants (order(descj)) extracted the
% wrong things to sort; we need each node's position within topo instead.
[unused, position_of] = sort(topo);     % node v is position_of(v)-th in topo
[unused, by_position] = sort(position_of(affected));
affected = affected(by_position);

% Update j first, then its descendants in topological order, so that every
% row is rebuilt after the rows of the parents it depends on.
for k = [j; affected(:)]'
  A = update_row(A, k, dag);
end
|
wolffd@0
|
236
|
wolffd@0
|
237 %%%%%%%%%%%
|
wolffd@0
|
238
|
wolffd@0
|
function A = old_do_removal(A, op, i, j, dag)
% OLD_DO_REMOVAL Earlier variant of do_removal, kept for reference.
%
% No caller is visible in this file.  It differs from do_removal only in how
% the descendants of j are ranked: it sorts by topo(descendants), i.e. by the
% entries of the topological order found at the descendants' indices.
% op and i are not used.

% SML: descendants are read from A rather than from reachability_graph(dag)
affected = find(A(:,j));
%* R = reachability_graph(dag);
%* descj = find(R(j,:));

topo = topological_sort(dag);
[unused, ranking] = sort(topo(affected));
affected = affected(ranking);

% Update j and all its descendants
for k = [j; affected(:)]'
  A = update_row(A, k, dag);
end
|
wolffd@0
|
258
|
wolffd@0
|
259 %%%%%%%%%
|
wolffd@0
|
260
|
wolffd@0
|
function A = update_row(A, j, dag)
% UPDATE_ROW Rebuild row j of the ancestor matrix A from the parents of j.
%
% Row j is cleared, then set to 1 for every parent of j in dag and for
% every node marked as an ancestor in a parent's row of A.  The parents'
% rows are taken as they currently stand in A.

A(j, :) = 0;
pa = parents(dag, j);
if isempty(pa)
  return;               % no parents: row j stays all zero
end

A(j, pa) = 1;                        % the parents themselves
inherited = any(A(pa, :), 1);        % union of the parents' ancestor rows
A(j, inherited) = 1;
|
wolffd@0
|
275
|
wolffd@0
|
276 %%%%%%%%
|
wolffd@0
|
277
|
wolffd@0
|
function A = init_ancestor_matrix(dag)
% INIT_ANCESTOR_MATRIX Build the ancestor matrix of dag from scratch.
%
% Starts from an all-zero matrix and fills one row per node with update_row,
% visiting the nodes in the order returned by topological_sort(dag).

A = zeros(length(dag));
topo = topological_sort(dag);
topo = topo(:)';
for pos = 1:numel(topo)
  A = update_row(A, topo(pos), dag);
end
|