annotate toolboxes/FullBNT-1.0.7/bnt/learning/learn_struct_mcmc.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
rev   line source
function [sampled_graphs, accept_ratio, num_edges] = learn_struct_mcmc(data, ns, varargin)
% LEARN_STRUCT_MCMC Monte Carlo Markov Chain search over DAGs assuming fully observed data
% [sampled_graphs, accept_ratio, num_edges] = learn_struct_mcmc(data, ns, ...)
%
% data(i,m) is the value of node i in case m.
% ns(i) is the number of discrete values node i can take on.
%
% sampled_graphs{m} is the m'th graph sampled after burn-in.
% accept_ratio(t) = (1 + #accepted moves) / (1 + #rejected moves) up to and
%                   including iteration t (both counters start at 1, which
%                   avoids a division by zero).
% num_edges(t) = number of edges in the model at iteration t
% Both accept_ratio and num_edges have burnin + nsamples entries.
%
% The following optional arguments can be specified in the form of name/value pairs:
% [default value in brackets]
%
% scoring_fn - 'bayesian' or 'bic' [ 'bayesian' ]
%              Currently, only networks with all tabular nodes support Bayesian scoring.
% type       - type{i} is the type of CPD to use for node i, where the type is a string
%              of the form 'tabular', 'noisy_or', 'gaussian', etc. [ all cells contain 'tabular' ]
% params     - params{i} contains optional arguments passed to the CPD constructor for node i,
%              or [] if none.
%              [ all cells contain {'prior_type', 'dirichlet', 'dirichlet_weight', 1} ]
%              Passing [] for params makes every params{i} empty.
% discrete   - the list of discrete nodes [ 1:N ]
% clamped    - clamped(i,m) = 1 if node i is clamped in case m [ zeros(N, ncases) ]
% nsamples   - number of samples to draw from the chain after burn-in [ 100*N ]
% burnin     - number of steps to take before drawing samples [ 5*N ]
% init_dag   - starting point for the search [ zeros(N,N) ]
%
% Unrecognised option names are ignored, but a warning is issued so that
% misspelt names do not silently fall back to the default.
%
% e.g., samples = learn_struct_mcmc(data, ns, 'nsamples', 1000);
%
% Modified by Sonia Leach (SML) 2/4/02, 9/5/03

[n, ncases] = size(data);

% set default params
type = cell(1,n);
params = cell(1,n);
for i=1:n
  type{i} = 'tabular';
  params{i} = { 'prior_type', 'dirichlet', 'dirichlet_weight', 1 };
end
scoring_fn = 'bayesian';
discrete = 1:n;
clamped = zeros(n, ncases);
nsamples = 100*n;
burnin = 5*n;
dag = zeros(n);

% override defaults with the caller's name/value pairs
args = varargin;
nargs = length(args);
for i=1:2:nargs
  switch args{i},
   case 'nsamples',   nsamples = args{i+1};
   case 'burnin',     burnin = args{i+1};
   case 'init_dag',   dag = args{i+1};
   case 'scoring_fn', scoring_fn = args{i+1};
   case 'type',       type = args{i+1};
   case 'discrete',   discrete = args{i+1};
   case 'clamped',    clamped = args{i+1};
   case 'params',
    if isempty(args{i+1})
      params = cell(1,n);
    else
      params = args{i+1};
    end
   otherwise,
    % Results are unchanged (the option is still ignored); we only make
    % the problem visible to the caller.
    if ischar(args{i})
      warning('learn_struct_mcmc:unknownOption', ...
              'ignoring unrecognized option ''%s''', args{i});
    else
      warning('learn_struct_mcmc:unknownOption', ...
              'ignoring non-string option name at argument position %d', i);
    end
  end
end

% We implement the fast acyclicity check described by P. Giudici and R. Castelo,
% "Improving MCMC model search for data mining", submitted to J. Machine Learning, 2001.

% SML: also keep descendant matrix C
use_giudici = 1;
if use_giudici
  [nbrs, ops, nodes, A] = mk_nbrs_of_digraph(dag);
else
  [nbrs, ops, nodes] = mk_nbrs_of_dag(dag);
  A = []; % an empty A tells take_step to use the non-Giudici code path
end

% Both counters start at 1 so that the ratio below is always defined.
num_accepts = 1;
num_rejects = 1;
T = burnin + nsamples;
accept_ratio = zeros(1, T);
num_edges = zeros(1, T);
sampled_graphs = cell(1, nsamples);

for t=1:T
  [dag, nbrs, ops, nodes, A, accept] = take_step(dag, nbrs, ops, ...
                                                 nodes, ns, data, clamped, A, ...
                                                 scoring_fn, discrete, type, params);
  num_edges(t) = sum(dag(:));
  num_accepts = num_accepts + accept;
  num_rejects = num_rejects + (1-accept);
  accept_ratio(t) = num_accepts/num_rejects;
  if t > burnin
    % only the states visited after burn-in are kept as samples
    sampled_graphs{t-burnin} = dag;
  end
end
wolffd@0 99
wolffd@0 100
wolffd@0 101 %%%%%%%%%
wolffd@0 102
wolffd@0 103
function [new_dag, new_nbrs, new_ops, new_nodes, A, accept] = ...
    take_step(dag, nbrs, ops, nodes, ns, data, clamped, A, ...
              scoring_fn, discrete, type, params, prior_w)
% TAKE_STEP Propose one single-edge change to the DAG and accept or reject it (Metropolis-Hastings)
%
% dag, nbrs, ops, nodes describe the current graph and its neighbours:
%   ops{d} is the operation ('add', 'del' or 'rev') that produces neighbour d and
%   nodes(d,1:2) is the arc it acts on.
% A is the ancestor matrix of dag, or [] to use the non-Giudici code path
%   (nbrs is then a cell array of graphs rather than a 3-D array).
% prior_w is accepted for interface compatibility only; it is not used.
%
% On return, new_dag/new_nbrs/new_ops/new_nodes/A describe the state of the
% chain after the step (unchanged from the inputs if the move was rejected),
% and accept is 1 if the proposal was accepted, 0 otherwise.

use_giudici = ~isempty(A);
if use_giudici
  [new_dag, op, i, j, new_A] = pick_digraph_nbr(dag, nbrs, ops, nodes, A); % updates A
  [new_nbrs, new_ops, new_nodes] = mk_nbrs_of_digraph(new_dag, new_A);
else
  d = sample_discrete(normalise(ones(1, length(nbrs))));
  new_dag = nbrs{d};
  op = ops{d};
  i = nodes(d, 1); j = nodes(d, 2);
  [new_nbrs, new_ops, new_nodes] = mk_nbrs_of_dag(new_dag);
end

bf = bayes_factor(dag, new_dag, op, i, j, ns, data, clamped, scoring_fn, discrete, type, params);

% Hastings correction: ratio of the neighbourhood sizes of the old and new graph.
% The sizes are taken from the ops lists (one entry per neighbour) because in
% Giudici mode nbrs is an n-by-n-by-K array, for which length() returns
% max(n,K) rather than K.
%R = bf * (new_prior / prior) * (length(nbrs) / length(new_nbrs));
R = bf * (length(ops) / length(new_ops));
u = rand(1,1);
if u > min(1,R) % reject the move
  accept = 0;
  new_dag = dag;
  new_nbrs = nbrs;
  new_ops = ops;
  new_nodes = nodes;
  % A is deliberately left untouched: it still describes dag
else
  accept = 1;
  if use_giudici
    A = new_A; % new_A already updated in pick_digraph_nbr
  end
end
wolffd@0 138
wolffd@0 139
wolffd@0 140 %%%%%%%%%
wolffd@0 141
function bfactor = bayes_factor(old_dag, new_dag, op, i, j, ns, data, clamped, scoring_fn, discrete, type, params)
% BAYES_FACTOR Ratio of the scores of new_dag and old_dag, which differ by one operation on arc (i,j)
%
% Only the families whose parent sets changed are rescored: always j's, and
% for a reversal (op = 'rev') also i's. Cases in which the rescored node is
% clamped are excluded from its score.
%
% The log-score differences are summed and exponentiated once. Multiplying
% exp(dj)*exp(di) instead can produce Inf*0 = NaN (or a spurious Inf) when
% the two differences are large and of opposite sign, even though their sum
% is moderate; since min(1,NaN) is 1, such a proposal would always be accepted.

u = find(clamped(j,:)==0);
LLnew = score_family(j, parents(new_dag, j), type{j}, scoring_fn, ns, discrete, data(:,u), params{j});
LLold = score_family(j, parents(old_dag, j), type{j}, scoring_fn, ns, discrete, data(:,u), params{j});
log_bf = LLnew - LLold;

if strcmp(op, 'rev') % must also include the changes to i's family
  u = find(clamped(i,:)==0);
  LLnew = score_family(i, parents(new_dag, i), type{i}, scoring_fn, ns, discrete, data(:,u), params{i});
  LLold = score_family(i, parents(old_dag, i), type{i}, scoring_fn, ns, discrete, data(:,u), params{i});
  log_bf = log_bf + (LLnew - LLold);
end

bfactor = exp(log_bf);
wolffd@0 158
wolffd@0 159
wolffd@0 160 %%%%%%%% Giudici stuff follows %%%%%%%%%%
wolffd@0 161
wolffd@0 162
wolffd@0 163 % SML: This now updates A as it goes from digraph it choses
function [new_dag, op, i, j, new_A] = pick_digraph_nbr(dag, digraph_nbrs, ops, nodes, A)
% PICK_DIGRAPH_NBR Pick one neighbouring graph uniformly at random and update the ancestor matrix
%
% digraph_nbrs(:,:,d) is the d'th neighbour of dag, ops{d} the operation that
% produces it and nodes(d,1:2) the arc (i,j) it acts on.
% new_A is A updated to describe new_dag.

% Count neighbours along the third dimension: length() would return the
% largest dimension of the n-by-n-by-K array, i.e. max(n,K), not K.
num_nbrs = size(digraph_nbrs, 3);
d = sample_discrete(normalise(ones(1, num_nbrs)));
%d = myunidrnd(length(digraph_nbrs),1,1);
i = nodes(d, 1); j = nodes(d, 2);
new_dag = digraph_nbrs(:,:,d);
op = ops{d};
new_A = update_ancestor_matrix(A, op, i, j, new_dag);
wolffd@0 172
wolffd@0 173
wolffd@0 174 %%%%%%%%%%%%%%
wolffd@0 175
wolffd@0 176
function A = update_ancestor_matrix(A, op, i, j, dag)
% UPDATE_ANCESTOR_MATRIX Bring the ancestor matrix A up to date after operation op on arc i->j
%
% op is 'add', 'del' or 'rev'; dag is the graph after the operation.
% A reversal is handled as a removal of i->j followed by an addition of j->i.
% Any other value of op leaves A unchanged.

if strcmp(op, 'add')
  A = do_addition(A, op, i, j, dag);
elseif strcmp(op, 'del')
  A = do_removal(A, op, i, j, dag);
elseif strcmp(op, 'rev')
  A = do_removal(A, op, i, j, dag);
  A = do_addition(A, op, j, i, dag);
end
wolffd@0 188
wolffd@0 189
wolffd@0 190 %%%%%%%%%%%%
wolffd@0 191
function A = do_addition(A, op, i, j, dag)
% DO_ADDITION Update the ancestor matrix A for a newly added arc i->j
%
% A(r,c) nonzero means c is an ancestor of r.
% op and dag are not used; they are kept so that the signature matches do_removal.

% i, and everything above i, is now an ancestor of j
A(j, i) = 1;
A(j, A(i,:) ~= 0) = 1;

% every descendant of j inherits all of j's ancestors
anc_of_j = find(A(j,:));
desc_of_j = find(A(:,j));
A(desc_of_j, anc_of_j) = 1;
wolffd@0 206
wolffd@0 207 %%%%%%%%%%%
function A = do_removal(A, op, i, j, dag)
% DO_REMOVAL Recompute the ancestor rows of j and of all its descendants
%
% dag is the graph after the removal. The descendants of j are read from
% column j of A and processed in the topological order of dag, so each row
% is rebuilt only after the rows of its parents. op and i are not used.

% SML: reading the descendants from A replaces Kevin's original
%   R = reachability_graph(dag); descj = find(R(j,:));
% which SML believed to be equivalent but much more expensive.
desc = find(A(:,j));

topo = topological_sort(dag);

% SML: Kevin originally used order(descj) here, which extracted the wrong
% things to sort; what is needed is the position of each node in the order.
[unused, position] = sort(topo);          % node v is position(v)-th in topo
[unused, by_rank] = sort(position(desc)); % rank the descendants by that position
desc = desc(by_rank);

% Update j first, then its descendants in topological order
A = update_row(A, j, dag);
for idx = 1:length(desc)
  A = update_row(A, desc(idx), dag);
end
wolffd@0 236
wolffd@0 237 %%%%%%%%%%%
wolffd@0 238
function A = old_do_removal(A, op, i, j, dag)
% OLD_DO_REMOVAL Superseded by do_removal; kept only so that existing references still resolve
%
% The previous body ranked the descendants of j with order(descj), where
% order = topological_sort(dag). That picks the nodes found at those
% positions of the order rather than the positions of those nodes, so the
% descendants could be updated before their parents. do_removal contains the
% corrected ranking, so we delegate to it instead of keeping a second,
% incorrect copy.

A = do_removal(A, op, i, j, dag);
wolffd@0 258
wolffd@0 259 %%%%%%%%%
wolffd@0 260
function A = update_row(A, j, dag)
% UPDATE_ROW Recompute row j of the ancestor matrix A from the parents of j in dag
%
% Row j is cleared and then set to the parents of j together with everything
% already marked in the parents' rows of A, so those rows must be up to date.

A(j, :) = 0;
pa = parents(dag, j);
A(j, pa) = 1;                 % the parents themselves
for p = pa(:)'
  A(j, A(p,:) ~= 0) = 1;      % plus whatever row p of A marks
end
wolffd@0 275
wolffd@0 276 %%%%%%%%
wolffd@0 277
function A = init_ancestor_matrix(dag)
% INIT_ANCESTOR_MATRIX Build the ancestor matrix of dag from scratch
%
% Rows are filled in by update_row in the order returned by topological_sort.

A = zeros(length(dag));
topo = topological_sort(dag);
for idx = 1:length(topo)
  A = update_row(A, topo(idx), dag);
end
wolffd@0 284 end