annotate _FullBNT/BNT/learning/learn_struct_mcmc.m @ 9:4ea6619cb3f5 tip

removed log files
author matthiasm
date Fri, 11 Apr 2014 15:55:11 +0100
parents b5b38998ef3b
children
rev   line source
function [sampled_graphs, accept_ratio, num_edges] = learn_struct_mcmc(data, ns, varargin)
% LEARN_STRUCT_MCMC Markov chain Monte Carlo search over DAGs assuming fully observed data
% [sampled_graphs, accept_ratio, num_edges] = learn_struct_mcmc(data, ns, ...)
%
% data(i,m) is the value of node i in case m.
% ns(i) is the number of discrete values node i can take on.
%
% sampled_graphs{m} is the m'th graph sampled after burn-in.
% accept_ratio(t) = (1 + number of accepted moves) / (1 + number of rejected moves)
%                   over iterations 1..t. Note this is accepts-to-rejects, not the
%                   fraction of moves accepted; both counters start at 1, which
%                   avoids a division by zero.
% num_edges(t) = number of edges in the model at iteration t
%
% The following optional arguments can be specified in the form of name/value pairs:
% [default value in brackets]
%
% scoring_fn - 'bayesian' or 'bic' [ 'bayesian' ]
%              Currently, only networks with all tabular nodes support Bayesian scoring.
% type       - type{i} is the type of CPD to use for node i, where the type is a string
%              of the form 'tabular', 'noisy_or', 'gaussian', etc. [ all cells contain 'tabular' ]
% params     - params{i} contains optional arguments passed to the CPD constructor for node i,
%              or [] if none. Passing [] for params gives every node an empty argument list.
%              [ all cells contain {'prior_type', 'dirichlet', 'dirichlet_weight', 1} ]
% discrete   - the list of discrete nodes [ 1:N ]
% clamped    - clamped(i,m) = 1 if node i is clamped in case m [ zeros(N, ncases) ]
% nsamples   - number of samples to draw from the chain after burn-in [ 100*N ]
% burnin     - number of steps to take before drawing samples [ 5*N ]
% init_dag   - starting point for the search, an N-by-N adjacency matrix [ zeros(N,N) ]
%
% Unrecognised option names are ignored with a warning. A recognised option
% name that is not followed by a value raises an error.
%
% e.g., samples = learn_struct_mcmc(data, ns, 'nsamples', 1000);
%
% Modified by Sonia Leach (SML) 2/4/02, 9/5/03

[n, ncases] = size(data);

% set default params
type = cell(1,n);
params = cell(1,n);
for i=1:n
  type{i} = 'tabular';
  params{i} = { 'prior_type', 'dirichlet', 'dirichlet_weight', 1 };
end
scoring_fn = 'bayesian';
discrete = 1:n;
clamped = zeros(n, ncases);
nsamples = 100*n;
burnin = 5*n;
dag = zeros(n);

% parse the name/value pairs
known_names = {'nsamples', 'burnin', 'init_dag', 'scoring_fn', ...
               'type', 'discrete', 'clamped', 'params'};
args = varargin;
nargs = length(args);
for i=1:2:nargs
  name = args{i};
  if ~(ischar(name) && any(strcmp(name, known_names)))
    % Historically such entries were skipped silently; keep skipping them,
    % but tell the user so that typos do not go unnoticed.
    if ischar(name)
      shown = name;
    else
      shown = '<non-string>';
    end
    warning('learn_struct_mcmc:unknownOption', ...
            'learn_struct_mcmc: ignoring unrecognised option ''%s''', shown);
    continue
  end
  if i == nargs
    error('learn_struct_mcmc:missingValue', ...
          'learn_struct_mcmc: option ''%s'' has no value', name);
  end
  value = args{i+1};
  switch name
   case 'nsamples',   nsamples = value;
   case 'burnin',     burnin = value;
   case 'init_dag',   dag = value;
   case 'scoring_fn', scoring_fn = value;
   case 'type',       type = value;
   case 'discrete',   discrete = value;
   case 'clamped',    clamped = value;
   case 'params'
    if isempty(value)
      params = cell(1,n);
    else
      params = value;
    end
  end
end

% The chain indexes dag by node number, so it must have one row/column per node.
if ~isequal(size(dag), [n n])
  error('learn_struct_mcmc:badInitDag', ...
        'learn_struct_mcmc: init_dag must be %d-by-%d', n, n);
end

% We implement the fast acyclicity check described by P. Giudici and R. Castelo,
% "Improving MCMC model search for data mining", submitted to J. Machine Learning, 2001.

% SML: also keep descendant matrix C
use_giudici = 1;
if use_giudici
  [nbrs, ops, nodes, A] = mk_nbrs_of_digraph(dag);
else
  [nbrs, ops, nodes] = mk_nbrs_of_dag(dag);
  A = [];  % an empty A tells take_step to use the non-Giudici path
end

% Both counters start at 1 so the ratio below is always defined.
num_accepts = 1;
num_rejects = 1;
T = burnin + nsamples;
accept_ratio = zeros(1, T);
num_edges = zeros(1, T);
sampled_graphs = cell(1, nsamples);

for t=1:T
  [dag, nbrs, ops, nodes, A, accept] = take_step(dag, nbrs, ops, ...
                                                 nodes, ns, data, clamped, A, ...
                                                 scoring_fn, discrete, type, params);
  num_edges(t) = sum(dag(:));
  num_accepts = num_accepts + accept;
  num_rejects = num_rejects + (1-accept);
  accept_ratio(t) = num_accepts/num_rejects;
  if t > burnin
    % only graphs visited after the burn-in period are kept
    sampled_graphs{t-burnin} = dag;
  end
end
matthiasm@8 99
matthiasm@8 100
matthiasm@8 101 %%%%%%%%%
matthiasm@8 102
matthiasm@8 103
function [new_dag, new_nbrs, new_ops, new_nodes, A, accept] = ...
    take_step(dag, nbrs, ops, nodes, ns, data, clamped, A, ...
              scoring_fn, discrete, type, params, prior_w)
% TAKE_STEP Propose one neighbouring graph and accept or reject it.
%
% A neighbour of dag is drawn uniformly at random, scored against dag with
% bayes_factor, and accepted with probability min(1, R), where
%   R = bayes_factor * (number of neighbours of dag) / (number of neighbours of the proposal).
%
% On acceptance the outputs describe the proposed graph (and A becomes its
% updated ancestor matrix when one is in use). On rejection the outputs are
% the inputs, unchanged. accept is 1 or 0 accordingly.
%
% A non-empty A selects the ancestor-matrix (Giudici) neighbour functions;
% an empty A selects mk_nbrs_of_dag, where nbrs is a cell array of graphs.
% prior_w is not used by this function.

track_ancestors = ~isempty(A);

% --- draw the proposal and enumerate its own neighbours ---
if ~track_ancestors
  pick = sample_discrete(normalise(ones(1, length(nbrs))));
  new_dag = nbrs{pick};
  op = ops{pick};
  i = nodes(pick, 1);
  j = nodes(pick, 2);
  [new_nbrs, new_ops, new_nodes] = mk_nbrs_of_dag(new_dag);
else
  % pick_digraph_nbr also returns the ancestor matrix of the proposal
  [new_dag, op, i, j, new_A] = pick_digraph_nbr(dag, nbrs, ops, nodes, A);
  [new_nbrs, new_ops, new_nodes] = mk_nbrs_of_digraph(new_dag, new_A);
end

% --- acceptance ratio ---
bf = bayes_factor(dag, new_dag, op, i, j, ns, data, clamped, scoring_fn, discrete, type, params);
R = bf * (length(nbrs) / length(new_nbrs));
threshold = min(1, R);

% --- accept: keep the proposal ---
if rand(1,1) <= threshold
  accept = 1;
  if track_ancestors
    A = new_A;
  end
  return
end

% --- reject: hand the current state back unchanged ---
accept = 0;
new_dag = dag;
new_nbrs = nbrs;
new_ops = ops;
new_nodes = nodes;
matthiasm@8 138
matthiasm@8 139
matthiasm@8 140 %%%%%%%%%
matthiasm@8 141
function bfactor = bayes_factor(old_dag, new_dag, op, i, j, ns, data, clamped, scoring_fn, discrete, type, params)
% BAYES_FACTOR Score ratio of new_dag to old_dag for a single-edge move.
%
% op is the move name; i and j are the two nodes the move involves. Node j's
% family is always rescored. When op is 'rev', node i's family is rescored too
% and its contribution is combined with j's.
%
% Family scores come from score_family and are treated as log scores. The
% log differences are summed and exponentiated once. Exponentiating each
% family's difference separately and multiplying can give Inf * 0 = NaN when
% the two differences are large and of opposite sign.

log_bf = log_family_change(j, old_dag, new_dag, ns, data, clamped, scoring_fn, discrete, type, params);

if strcmp(op, 'rev') % must also add in the changes to i's family
  log_bf = log_bf + ...
      log_family_change(i, old_dag, new_dag, ns, data, clamped, scoring_fn, discrete, type, params);
end

bfactor = exp(log_bf);


function delta = log_family_change(node, old_dag, new_dag, ns, data, clamped, scoring_fn, discrete, type, params)
% LOG_FAMILY_CHANGE Score of node's family under new_dag minus its score under old_dag.
% Only the cases in which node is not clamped are passed to score_family.

u = find(clamped(node,:)==0);
LLnew = score_family(node, parents(new_dag, node), type{node}, scoring_fn, ns, discrete, data(:,u), params{node});
LLold = score_family(node, parents(old_dag, node), type{node}, scoring_fn, ns, discrete, data(:,u), params{node});
delta = LLnew - LLold;
matthiasm@8 158
matthiasm@8 159
matthiasm@8 160 %%%%%%%% Giudici stuff follows %%%%%%%%%%
matthiasm@8 161
matthiasm@8 162
function [new_dag, op, i, j, new_A] = pick_digraph_nbr(dag, digraph_nbrs, ops, nodes, A)
% PICK_DIGRAPH_NBR Draw one neighbouring graph uniformly at random.
%
% digraph_nbrs(:,:,d) is the d'th candidate graph, ops{d} the move that
% produces it, and nodes(d,:) = [i j] the two nodes that move involves.
% new_A is the ancestor matrix A updated for the chosen move (SML).

num_candidates = length(digraph_nbrs);
weights = normalise(ones(1, num_candidates));  % equal weight for every candidate
d = sample_discrete(weights);

op = ops{d};
new_dag = digraph_nbrs(:,:,d);
i = nodes(d, 1);
j = nodes(d, 2);

new_A = update_ancestor_matrix(A, op, i, j, new_dag);
matthiasm@8 172
matthiasm@8 173
matthiasm@8 174 %%%%%%%%%%%%%%
matthiasm@8 175
matthiasm@8 176
function A = update_ancestor_matrix(A, op, i, j, dag)
% UPDATE_ANCESTOR_MATRIX Bring the ancestor matrix A up to date after one move.
%
% op is 'add', 'del' or 'rev'; i and j are the nodes the move involves and
% dag is the graph after the move. A reversal is handled as a removal
% followed by an addition with the two nodes swapped. Any other op leaves A
% unchanged.

if strcmp(op, 'add')
  A = do_addition(A, op, i, j, dag);
elseif strcmp(op, 'del')
  A = do_removal(A, op, i, j, dag);
elseif strcmp(op, 'rev')
  A = do_removal(A, op, i, j, dag);
  A = do_addition(A, op, j, i, dag);
end
matthiasm@8 188
matthiasm@8 189
matthiasm@8 190 %%%%%%%%%%%%
matthiasm@8 191
function A = do_addition(A, op, i, j, dag)
% DO_ADDITION Update the ancestor matrix A for a move that makes i a parent of j.
%
% A(r,c) nonzero means node c is an ancestor of node r. op and dag are not
% used; they are kept so the signature matches do_removal.

% i, and every ancestor of i, becomes an ancestor of j
A(j, i) = 1;
A(j, A(i,:) ~= 0) = 1;

% every ancestor of j is now also an ancestor of each descendant of j
is_anc_of_j = (A(j,:) ~= 0);
is_desc_of_j = (A(:,j) ~= 0);
A(is_desc_of_j, is_anc_of_j) = 1;
matthiasm@8 206
matthiasm@8 207 %%%%%%%%%%%
function A = do_removal(A, op, i, j, dag)
% DO_REMOVAL Recompute the ancestor rows of j and of each of its descendants.
%
% dag is the graph after the move. Descendants of j are read from column j of
% A (SML: used in place of a reachability_graph call as equivalent and much
% less expensive) and are visited in the order given by topological_sort(dag).
% op and i are not used.

% rank_of_node(v) is the position of node v in the topological order
topo = topological_sort(dag);
[unused, rank_of_node] = sort(topo);

% descendants of j, rearranged by increasing topological position
desc = find(A(:,j));
[unused, by_rank] = sort(rank_of_node(desc));
desc = desc(by_rank);

% refresh j first, then each descendant in turn
A = update_row(A, j, dag);
for idx = 1:numel(desc)
  A = update_row(A, desc(idx), dag);
end
matthiasm@8 236
matthiasm@8 237 %%%%%%%%%%%
matthiasm@8 238
function A = old_do_removal(A, op, i, j, dag)
% OLD_DO_REMOVAL Earlier variant of do_removal, kept for reference.
%
% Recomputes the ancestor rows of j and of each of its descendants, visiting
% the descendants in topological order of dag. op and i are not used.
% Nothing in the visible part of this file calls this function.
%
% Fix: the descendants used to be sorted by order(descj), which reads the
% nodes stored at positions descj rather than the positions of the nodes
% descj. They are now sorted by their actual rank in the topological order.

% find all the descendants of j (column j of the ancestor matrix)
descj = find(A(:,j));

% put them in topological order
order = topological_sort(dag);
[junk, rank_of_node] = sort(order);   % node v is rank_of_node(v)-th in order
descj_topnum = rank_of_node(descj);   % descj(k) is descj_topnum(k)-th in order
[junk, perm] = sort(descj_topnum);
descj = descj(perm);

% Update j and all its descendants
A = update_row(A, j, dag);
for k = descj(:)'
  A = update_row(A, k, dag);
end
matthiasm@8 258
matthiasm@8 259 %%%%%%%%%
matthiasm@8 260
function A = update_row(A, j, dag)
% UPDATE_ROW Recompute row j of the ancestor matrix A from the rows of j's parents.
%
% Afterwards, row j marks every parent of j in dag together with every node
% marked in those parents' rows. The parents' rows are read as they currently
% stand in A, so they should be up to date before this is called.

A(j, :) = 0;
ps = parents(dag, j);

% the parents themselves
A(j, ps) = 1;

% plus anything marked as an ancestor of at least one parent
inherited = any(A(ps, :) ~= 0, 1);
A(j, inherited) = 1;
matthiasm@8 275
matthiasm@8 276 %%%%%%%%
matthiasm@8 277
function A = init_ancestor_matrix(dag)
% INIT_ANCESTOR_MATRIX Build the ancestor matrix of dag from scratch.
%
% Starts from an all-zero matrix and fills one row per node with update_row,
% visiting the nodes in the order returned by topological_sort(dag)
% (presumably parents before children, so each parent's row is complete
% before it is read - confirm against topological_sort).

num_nodes = length(dag);
A = zeros(num_nodes);

visit = topological_sort(dag);
for pos = 1:numel(visit)
  A = update_row(A, visit(pos), dag);
end