annotate _FullBNT/BNT/learning/kpm_learn_struct_mcmc.m @ 9:4ea6619cb3f5 tip

removed log files
author matthiasm
date Fri, 11 Apr 2014 15:55:11 +0100
parents b5b38998ef3b
children
rev   line source
function [sampled_graphs, accept_ratio, num_edges] = learn_struct_mcmc(data, ns, varargin)
% LEARN_STRUCT_MCMC Markov Chain Monte Carlo search over DAGs assuming fully observed data
% [sampled_graphs, accept_ratio, num_edges] = learn_struct_mcmc(data, ns, ...)
%
% data(i,m) is the value of node i in case m.
% ns(i) is the number of discrete values node i can take on.
%
% sampled_graphs{m} is the m'th sampled graph (an N x N adjacency matrix), m = 1..nsamples.
% accept_ratio(t) = (1 + #accepted moves) / (1 + #rejected moves) over iterations 1..t.
%   Note that this is the odds of acceptance (one pseudo-count on each side),
%   not the fraction of proposals that were accepted.
% num_edges(t) = number of edges in model at iteration t
%
% The following optional arguments can be specified in the form of name/value pairs:
% [default value in brackets]
%
% scoring_fn - 'bayesian' or 'bic' [ 'bayesian' ]
%              Currently, only networks with all tabular nodes support Bayesian scoring.
% type       - type{i} is the type of CPD to use for node i, where the type is a string
%              of the form 'tabular', 'noisy_or', 'gaussian', etc. [ all cells contain 'tabular' ]
% params     - params{i} contains optional arguments passed to the CPD constructor for node i,
%              or [] if none. Passing params = [] gives every node an empty argument list.
%              [ all cells contain {'prior_type', 'dirichlet', 'dirichlet_weight', 1} ]
% discrete   - the list of discrete nodes [ 1:N ]
% clamped    - clamped(i,m) = 1 if node i is clamped in case m [ zeros(N, ncases) ]
% nsamples   - number of samples to draw from the chain after burn-in [ 100*N ]
% burnin     - number of steps to take before drawing samples [ 5*N ]
% init_dag   - starting point for the search [ zeros(N,N) ]
%
% Unrecognised option names are ignored with a warning; an option name that is
% not followed by a value is an error.
%
% e.g., samples = learn_struct_mcmc(data, ns, 'nsamples', 1000);
%
% This interface is not backwards compatible with BNT2,
% but is designed to be compatible with the other learn_struct_xxx routines.
%
% Note: We currently assume a uniform structural prior.

[n ncases] = size(data);


% set default params
type = cell(1,n);
params = cell(1,n);
for i=1:n
  type{i} = 'tabular';
  params{i} = { 'prior_type', 'dirichlet', 'dirichlet_weight', 1 };
end
scoring_fn = 'bayesian';
discrete = 1:n;
clamped = zeros(n, ncases);
nsamples = 100*n;
burnin = 5*n;
dag = zeros(n);

args = varargin;
nargs = length(args);
if mod(nargs, 2) ~= 0
  error('learn_struct_mcmc:badArgs', ...
        'learn_struct_mcmc: optional arguments must be given as name/value pairs');
end
for i=1:2:nargs
  switch args{i},
   case 'nsamples', nsamples = args{i+1};
   case 'burnin', burnin = args{i+1};
   case 'init_dag', dag = args{i+1};
   case 'scoring_fn', scoring_fn = args{i+1};
   case 'type', type = args{i+1};
   case 'discrete', discrete = args{i+1};
   case 'clamped', clamped = args{i+1};
   case 'params', if isempty(args{i+1}), params = cell(1,n); else params = args{i+1}; end
   otherwise
    % Previously a misspelled option was dropped silently; keep ignoring it
    % (for backward compatibility) but tell the caller.
    if ischar(args{i}), name = args{i}; else name = '<non-string>'; end
    warning('learn_struct_mcmc:unknownOption', ...
            'learn_struct_mcmc: ignoring unrecognised option ''%s''', name);
  end
end

% We implement the fast acyclicity check described by P. Giudici and R. Castelo,
% "Improving MCMC model search for data mining", submitted to J. Machine Learning, 2001.
use_giudici = 1;
if use_giudici
  [nbrs, ops, nodes] = mk_nbrs_of_digraph(dag);
  A = init_ancestor_matrix(dag);
else
  [nbrs, ops, nodes] = mk_nbrs_of_dag(dag);
  A = []; % an empty ancestor matrix tells take_step to use the plain DAG neighbourhood
end

% Both counters start at 1 (pseudo-counts) so accept_ratio is defined from t = 1.
num_accepts = 1;
num_rejects = 1;
T = burnin + nsamples;
accept_ratio = zeros(1, T);
num_edges = zeros(1, T);
sampled_graphs = cell(1, nsamples);

for t=1:T
  [dag, nbrs, ops, nodes, A, accept] = take_step(dag, nbrs, ops, nodes, ns, data, clamped, A, ...
                                                 scoring_fn, discrete, type, params);
  num_edges(t) = sum(dag(:));
  num_accepts = num_accepts + accept;
  num_rejects = num_rejects + (1-accept);
  accept_ratio(t) = num_accepts/num_rejects;
  if t > burnin
    sampled_graphs{t-burnin} = dag;
  end
end
matthiasm@8 98
matthiasm@8 99
matthiasm@8 100 %%%%%%%%%
matthiasm@8 101
matthiasm@8 102
function [new_dag, new_nbrs, new_ops, new_nodes, A, accept] = ...
    take_step(dag, nbrs, ops, nodes, ns, data, clamped, A, ...
              scoring_fn, discrete, type, params)
% TAKE_STEP One Metropolis-Hastings step over graph structures.
% Proposes a neighbour of DAG, scores it with bayes_factor, and accepts it with
% probability min(1, R). A non-empty ancestor matrix A selects the
% Giudici-Castelo proposal, in which nbrs is an N x N x K array of digraphs.
% Otherwise nbrs is a cell array of DAGs.
% On rejection, every output equals the corresponding input. On acceptance, A is
% updated for the move (op, i, j). accept is 1 if the proposal was accepted and
% 0 otherwise.

use_giudici = ~isempty(A);
if use_giudici
  [new_dag, op, i, j] = pick_digraph_nbr(dag, nbrs, ops, nodes, A);
  [new_nbrs, new_ops, new_nodes] = mk_nbrs_of_digraph(new_dag);
else
  d = sample_discrete(normalise(ones(1, count_nbrs(nbrs))));
  new_dag = nbrs{d};
  op = ops{d};
  i = nodes(d, 1); j = nodes(d, 2);
  [new_nbrs, new_ops, new_nodes] = mk_nbrs_of_dag(new_dag);
end

bf = bayes_factor(dag, new_dag, op, i, j, ns, data, clamped, scoring_fn, discrete, type, params);

% Hastings correction for a uniform choice among neighbours (the structural
% prior is assumed uniform, so it cancels).
% NOTE(review): on the Giudici path these counts include every digraph
% neighbour, whereas pick_digraph_nbr only ever returns acyclic ones; confirm
% whether the correction should use the number of legal neighbours instead.
R = bf * (count_nbrs(nbrs) / count_nbrs(new_nbrs));
u = rand(1,1);
% Accept with probability min(1,R). Testing u <= R directly rejects a NaN ratio;
% min(1,NaN) is 1 in MATLAB, so "u > min(1,R)" would accept such a move.
accept = double(u <= R);
if ~accept
  new_dag = dag;
  new_nbrs = nbrs;
  new_ops = ops;
  new_nodes = nodes;
elseif use_giudici
  A = update_ancestor_matrix(A, op, i, j, new_dag);
end


function K = count_nbrs(nbrs)
% COUNT_NBRS Number of neighbours stored in NBRS.
% This is the number of entries of a cell array, or the number of pages of an
% N x N x K array. (length() would instead return the largest dimension of the
% 3-D array.)
if iscell(nbrs)
  K = numel(nbrs);
else
  K = size(nbrs, 3);
end
matthiasm@8 138
matthiasm@8 139
matthiasm@8 140 %%%%%%%%%
matthiasm@8 141
function bfactor = bayes_factor(old_dag, new_dag, op, i, j, ns, data, clamped, scoring_fn, discrete, type, params)
% BAYES_FACTOR exp of the change in family scores for a single-edge move on (i, j).
% Only families whose parent set changed are rescored: j's for every move, plus
% i's when op is 'rev'. Each family is scored only on the cases in which that
% node is not clamped. score_family values are treated as log-scores (they are
% differenced and exponentiated).
% The per-family differences are summed before a single exp. Exponentiating each
% difference separately and multiplying could give Inf*0 = NaN when one family
% gains a lot and the other loses a lot.

dLL = family_log_delta(j, old_dag, new_dag, ns, data, clamped, scoring_fn, discrete, type, params);
if strcmp(op, 'rev') % must also account for the change to i's family
  dLL = dLL + family_log_delta(i, old_dag, new_dag, ns, data, clamped, scoring_fn, discrete, type, params);
end
bfactor = exp(dLL);


function dLL = family_log_delta(k, old_dag, new_dag, ns, data, clamped, scoring_fn, discrete, type, params)
% FAMILY_LOG_DELTA Change in node k's family score from OLD_DAG to NEW_DAG,
% computed on the cases in which node k is not clamped.
u = find(clamped(k,:)==0);
LLnew = score_family(k, parents(new_dag, k), type{k}, scoring_fn, ns, discrete, data(:,u), params{k});
LLold = score_family(k, parents(old_dag, k), type{k}, scoring_fn, ns, discrete, data(:,u), params{k});
dLL = LLnew - LLold;
matthiasm@8 158
matthiasm@8 159
matthiasm@8 160 %%%%%%%% Giudici stuff follows %%%%%%%%%%
matthiasm@8 161
matthiasm@8 162
function [new_dag, op, i, j] = pick_digraph_nbr(dag, digraph_nbrs, ops, nodes, A)
% PICK_DIGRAPH_NBR Draw a neighbour of DAG uniformly among those that stay acyclic.
% digraph_nbrs(:,:,d) is the d'th candidate graph, obtained by applying ops{d}
% to the edge nodes(d,1) -> nodes(d,2). Candidates are drawn uniformly and
% redrawn until one is legal under the ancestor matrix A, where A(x,y) = 1
% means y is an ancestor of x:
%   'add' i->j is legal iff j is not an ancestor of i;
%   'del' is always legal;
%   'rev' i->j is legal iff i is not an ancestor of any other parent of j.
% Returns the chosen graph, its op string, and the edge endpoints i, j.

% The candidates are pages of a 3-D array; length() would return its largest
% dimension rather than the number of candidates.
K = size(digraph_nbrs, 3);
if K == 0
  error('learn_struct_mcmc:noNeighbours', ...
        'pick_digraph_nbr: the current graph has no neighbours to propose');
end

legal = 0;
while ~legal
  d = sample_discrete(normalise(ones(1, K)));
  i = nodes(d, 1); j = nodes(d, 2);
  switch ops{d}
   case 'add'
    legal = (A(i,j) == 0);
   case 'del'
    legal = 1;
   case 'rev'
    ps = mysetdiff(parents(dag, j), i);
    % if any(A(ps,i)) then there is a path i -> parent of j -> j
    % so reversing i->j would create a cycle
    legal = ~any(A(ps, i));
   otherwise
    % an unknown op would never become legal and the loop would spin forever
    error('learn_struct_mcmc:unknownOp', 'pick_digraph_nbr: unknown move type ''%s''', ops{d});
  end
end
new_dag = digraph_nbrs(:,:,d);
op = ops{d};
matthiasm@8 187
matthiasm@8 188
matthiasm@8 189 %%%%%%%%%%%%%%
matthiasm@8 190
matthiasm@8 191
function A = update_ancestor_matrix(A, op, i, j, dag)
% UPDATE_ANCESTOR_MATRIX Refresh ancestor matrix A after the accepted move OP on edge (i,j).
% DAG is the graph after the move. Deletion and reversal both drop i->j, so both
% begin with do_removal; a reversal then also records the new edge j->i.
% Any other OP leaves A untouched.

if strcmp(op, 'add')
  A = do_addition(A, op, i, j, dag);
elseif strcmp(op, 'del') || strcmp(op, 'rev')
  A = do_removal(A, op, i, j, dag);
  if strcmp(op, 'rev')
    A = do_addition(A, op, j, i, dag);
  end
end
matthiasm@8 203
matthiasm@8 204
matthiasm@8 205 %%%%%%%%%%%%
matthiasm@8 206
function A = do_addition(A, op, i, j, dag)
% DO_ADDITION Update ancestor matrix A for a newly added edge i->j.
% A(x,y) = 1 means y is an ancestor of x. Node j gains i and all of i's
% ancestors. Every node that already lists j as an ancestor then gains j's
% (now enlarged) ancestor set.
% op and dag are not used; they keep the signature parallel to do_removal.

gained = A(i,:) ~= 0;     % i's ancestors ...
gained(i) = true;         % ... together with i itself
A(j, gained) = 1;

below_j = A(:,j) ~= 0;    % nodes that have j as an ancestor
A(below_j, A(j,:) ~= 0) = 1;
matthiasm@8 221
matthiasm@8 222 %%%%%%%%%%%
matthiasm@8 223
function A = do_removal(A, op, i, j, dag)
% DO_REMOVAL Update ancestor matrix A after the edge i->j has been removed.
% DAG is the graph after the removal. Only j and the nodes reachable from j can
% have lost ancestors, so exactly those rows are rebuilt with update_row.
% The rows must be rebuilt parents-first. update_row derives a node's row from
% its parents' rows, so a node visited before one of its parents would inherit
% that parent's stale ancestors.
% op and i are not used here; they keep the signature parallel to do_addition.

% R(j,k) ~= 0 marks k as a descendant of j.
R = reachability_graph(dag);
% ORDER lists node ids in topological order (order(t) is the t'th node), so it
% must be filtered, not indexed by node id. Filtering it by reachability from j
% yields j's descendants already sorted parents-first.
order = topological_sort(dag);
descj = order(R(j, order) ~= 0);

% Update j and all its descendants
A = update_row(A, j, dag);
for k = descj(:)'
  A = update_row(A, k, dag);
end
matthiasm@8 239
matthiasm@8 240 %%%%%%%%%
matthiasm@8 241
function A = update_row(A, j, dag)
% UPDATE_ROW Rebuild row j of the ancestor matrix A from j's parents in DAG.
% j's ancestors are its parents plus every ancestor recorded in those parents'
% rows, so the parents' rows of A should already be up to date.

A(j, :) = 0;
ps = parents(dag, j);
A(j, ps) = 1;                  % direct parents
A(j, any(A(ps, :), 1)) = 1;    % everything above any parent
matthiasm@8 256
matthiasm@8 257 %%%%%%%%
matthiasm@8 258
function A = init_ancestor_matrix(dag)
% INIT_ANCESTOR_MATRIX Build the ancestor matrix of DAG from scratch.
% A(x,y) = 1 when y is an ancestor of x. Rows are filled in topological order,
% so each update_row call can rely on the rows of that node's parents.

A = zeros(length(dag));
topo = topological_sort(dag);
for t = 1:numel(topo)
  A = update_row(A, topo(t), dag);
end