annotate toolboxes/FullBNT-1.0.7/bnt/learning/kpm_learn_struct_mcmc.m @ 0:cc4b1211e677 tip

initial commit to HG from Changeset: 646 (e263d8a21543) added further path and more save "camirversion.m"
author Daniel Wolff
date Fri, 19 Aug 2016 13:07:06 +0200
parents
children
rev   line source
Daniel@0 1 function [sampled_graphs, accept_ratio, num_edges] = learn_struct_mcmc(data, ns, varargin)
Daniel@0 2 % LEARN_STRUCT_MCMC Markov Chain Monte Carlo (MCMC) search over DAGs assuming fully observed data
Daniel@0 3 % [sampled_graphs, accept_ratio, num_edges] = learn_struct_mcmc(data, ns, ...)
Daniel@0 4 %
Daniel@0 5 % data(i,m) is the value of node i in case m.
Daniel@0 6 % ns(i) is the number of discrete values node i can take on.
Daniel@0 7 %
Daniel@0 8 % sampled_graphs{m} is the m'th sampled graph (graphs visited during burn-in are not stored).
Daniel@0 9 % accept_ratio(t) = (1 + number of accepted moves) / (1 + number of rejected moves) after iteration t
Daniel@0 10 % num_edges(t) = number of edges in model at iteration t (t counts burn-in and sampling steps)
Daniel@0 11 %
Daniel@0 12 % The following optional arguments can be specified in the form of name/value pairs:
Daniel@0 13 % [default value in brackets]
Daniel@0 14 %
Daniel@0 15 % scoring_fn - 'bayesian' or 'bic' [ 'bayesian' ]
Daniel@0 16 % Currently, only networks with all tabular nodes support Bayesian scoring.
Daniel@0 17 % type - type{i} is the type of CPD to use for node i, where the type is a string
Daniel@0 18 % of the form 'tabular', 'noisy_or', 'gaussian', etc. [ all cells contain 'tabular' ]
Daniel@0 19 % params - params{i} contains optional arguments passed to the CPD constructor for node i,
Daniel@0 20 % or [] if none. [ all cells contain {'prior_type', 'dirichlet', 'dirichlet_weight', 1} ]
Daniel@0 21 % discrete - the list of discrete nodes [ 1:N ]
Daniel@0 22 % clamped - clamped(i,m) = 1 if node i is clamped in case m [ zeros(N, ncases) ]
Daniel@0 23 % nsamples - number of samples to draw from the chain after burn-in [ 100*N ]
Daniel@0 24 % burnin - number of steps to take before drawing samples [ 5*N ]
Daniel@0 25 % init_dag - starting point for the search [ zeros(N,N) ]
Daniel@0 26 %
Daniel@0 27 % e.g., samples = learn_struct_mcmc(data, ns, 'nsamples', 1000);
Daniel@0 28 %
Daniel@0 29 % This interface is not backwards compatible with BNT2,
Daniel@0 30 % but is designed to be compatible with the other learn_struct_xxx routines.
Daniel@0 31 %
Daniel@0 32 % Note: We currently assume a uniform structural prior.
Daniel@0 33
Daniel@0 34 [n ncases] = size(data); % n = number of nodes (rows), ncases = number of cases (columns)
Daniel@0 35
Daniel@0 36
Daniel@0 37 % set default params
Daniel@0 38 type = cell(1,n);
Daniel@0 39 params = cell(1,n);
Daniel@0 40 for i=1:n
Daniel@0 41 type{i} = 'tabular';
Daniel@0 42 %params{i} = { 'prior', 1 };
Daniel@0 43 params{i} = { 'prior_type', 'dirichlet', 'dirichlet_weight', 1 };
Daniel@0 44 end
Daniel@0 45 scoring_fn = 'bayesian';
Daniel@0 46 discrete = 1:n;
Daniel@0 47 clamped = zeros(n, ncases);
Daniel@0 48 nsamples = 100*n;
Daniel@0 49 burnin = 5*n;
Daniel@0 50 dag = zeros(n);
Daniel@0 51 % override the defaults with the caller's name/value pairs
Daniel@0 52 args = varargin;
Daniel@0 53 nargs = length(args);
Daniel@0 54 for i=1:2:nargs % names sit at odd positions; names with no matching case are silently ignored
Daniel@0 55 switch args{i},
Daniel@0 56 case 'nsamples', nsamples = args{i+1};
Daniel@0 57 case 'burnin', burnin = args{i+1};
Daniel@0 58 case 'init_dag', dag = args{i+1};
Daniel@0 59 case 'scoring_fn', scoring_fn = args{i+1};
Daniel@0 60 case 'type', type = args{i+1};
Daniel@0 61 case 'discrete', discrete = args{i+1};
Daniel@0 62 case 'clamped', clamped = args{i+1};
Daniel@0 63 case 'params', if isempty(args{i+1}), params = cell(1,n); else params = args{i+1}; end % an empty value resets every params{i} to []
Daniel@0 64 end
Daniel@0 65 end
Daniel@0 66
Daniel@0 67 % We implement the fast acyclicity check described by P. Giudici and R. Castelo,
Daniel@0 68 % "Improving MCMC model search for data mining", submitted to J. Machine Learning, 2001.
Daniel@0 69 use_giudici = 1; % hard-coded, so the else branch below is never taken
Daniel@0 70 if use_giudici
Daniel@0 71 [nbrs, ops, nodes] = mk_nbrs_of_digraph(dag);
Daniel@0 72 A = init_ancestor_matrix(dag); % A(j,i) = 1 marks i as an ancestor of j
Daniel@0 73 else
Daniel@0 74 [nbrs, ops, nodes] = mk_nbrs_of_dag(dag);
Daniel@0 75 A = []; % an empty A tells take_step not to use the ancestor matrix
Daniel@0 76 end
Daniel@0 77 % chain bookkeeping
Daniel@0 78 num_accepts = 1; % both counters start at 1, so the ratio below never divides by zero
Daniel@0 79 num_rejects = 1;
Daniel@0 80 T = burnin + nsamples;
Daniel@0 81 accept_ratio = zeros(1, T);
Daniel@0 82 num_edges = zeros(1, T);
Daniel@0 83 sampled_graphs = cell(1, nsamples);
Daniel@0 84 %sampled_bitv = zeros(nsamples, n^2);
Daniel@0 85 % run the chain for T steps, recording statistics at every step
Daniel@0 86 for t=1:T
Daniel@0 87 [dag, nbrs, ops, nodes, A, accept] = take_step(dag, nbrs, ops, nodes, ns, data, clamped, A, ...
Daniel@0 88 scoring_fn, discrete, type, params);
Daniel@0 89 num_edges(t) = sum(dag(:));
Daniel@0 90 num_accepts = num_accepts + accept;
Daniel@0 91 num_rejects = num_rejects + (1-accept);
Daniel@0 92 accept_ratio(t) = num_accepts/num_rejects; % accepts divided by rejects, not accepts divided by total steps
Daniel@0 93 if t > burnin % only graphs visited after burn-in are kept
Daniel@0 94 sampled_graphs{t-burnin} = dag;
Daniel@0 95 %sampled_bitv(t-burnin, :) = dag(:)';
Daniel@0 96 end
Daniel@0 97 end
Daniel@0 98
Daniel@0 99
Daniel@0 100 %%%%%%%%%
Daniel@0 101
Daniel@0 102
Daniel@0 103 function [new_dag, new_nbrs, new_ops, new_nodes, A, accept] = ...
Daniel@0 104 take_step(dag, nbrs, ops, nodes, ns, data, clamped, A, ...
Daniel@0 105 scoring_fn, discrete, type, params)
Daniel@0 106 % TAKE_STEP Propose one neighbour of dag and accept or reject it.
Daniel@0 107 % On rejection dag, nbrs, ops, nodes and A are returned unchanged and accept = 0; otherwise accept = 1.
Daniel@0 108 use_giudici = ~isempty(A); % A is [] when the ancestor matrix is not in use
Daniel@0 109 if use_giudici
Daniel@0 110 [new_dag, op, i, j] = pick_digraph_nbr(dag, nbrs, ops, nodes, A);
Daniel@0 111 %assert(acyclic(new_dag));
Daniel@0 112 [new_nbrs, new_ops, new_nodes] = mk_nbrs_of_digraph(new_dag);
Daniel@0 113 else
Daniel@0 114 d = sample_discrete(normalise(ones(1, length(nbrs)))); % every neighbour gets the same weight
Daniel@0 115 new_dag = nbrs{d}; % in this branch nbrs is indexed as a cell array
Daniel@0 116 op = ops{d};
Daniel@0 117 i = nodes(d, 1); j = nodes(d, 2);
Daniel@0 118 [new_nbrs, new_ops, new_nodes] = mk_nbrs_of_dag(new_dag);
Daniel@0 119 end
Daniel@0 120 % score ratio between the proposed graph and the current one
Daniel@0 121 bf = bayes_factor(dag, new_dag, op, i, j, ns, data, clamped, scoring_fn, discrete, type, params);
Daniel@0 122
Daniel@0 123 %R = bf * (new_prior / prior) * (length(nbrs) / length(new_nbrs));
Daniel@0 124 R = bf * (length(nbrs) / length(new_nbrs)); % no prior term: the structural prior is taken to be uniform
Daniel@0 125 u = rand(1,1);
Daniel@0 126 if u > min(1,R) % reject the move
Daniel@0 127 accept = 0;
Daniel@0 128 new_dag = dag;
Daniel@0 129 new_nbrs = nbrs;
Daniel@0 130 new_ops = ops;
Daniel@0 131 new_nodes = nodes;
Daniel@0 132 else
Daniel@0 133 accept = 1;
Daniel@0 134 if use_giudici
Daniel@0 135 A = update_ancestor_matrix(A, op, i, j, new_dag); % A is only brought up to date for accepted moves
Daniel@0 136 end
Daniel@0 137 end
Daniel@0 138
Daniel@0 139
Daniel@0 140 %%%%%%%%%
Daniel@0 141
Daniel@0 142 function bfactor = bayes_factor(old_dag, new_dag, op, i, j, ns, data, clamped, scoring_fn, discrete, type, params)
Daniel@0 143 % BAYES_FACTOR exp(LLnew - LLold) for the family of node j, times the same quantity for node i when op is 'rev'.
Daniel@0 144 u = find(clamped(j,:)==0); % cases in which node j is not clamped
Daniel@0 145 LLnew = score_family(j, parents(new_dag, j), type{j}, scoring_fn, ns, discrete, data(:,u), params{j});
Daniel@0 146 LLold = score_family(j, parents(old_dag, j), type{j}, scoring_fn, ns, discrete, data(:,u), params{j});
Daniel@0 147 bf1 = exp(LLnew - LLold);
Daniel@0 148
Daniel@0 149 if strcmp(op, 'rev') % must also multiply in the changes to i's family
Daniel@0 150 u = find(clamped(i,:)==0); % cases in which node i is not clamped
Daniel@0 151 LLnew = score_family(i, parents(new_dag, i), type{i}, scoring_fn, ns, discrete, data(:,u), params{i});
Daniel@0 152 LLold = score_family(i, parents(old_dag, i), type{i}, scoring_fn, ns, discrete, data(:,u), params{i});
Daniel@0 153 bf2 = exp(LLnew - LLold);
Daniel@0 154 else
Daniel@0 155 bf2 = 1; % 'add' and 'del' leave i's family untouched
Daniel@0 156 end
Daniel@0 157 bfactor = bf1 * bf2;
Daniel@0 158
Daniel@0 159
Daniel@0 160 %%%%%%%% Giudici stuff follows %%%%%%%%%%
Daniel@0 161
Daniel@0 162
Daniel@0 163 function [new_dag, op, i, j] = pick_digraph_nbr(dag, digraph_nbrs, ops, nodes, A)
Daniel@0 164 % PICK_DIGRAPH_NBR Keep drawing candidate moves until one passes the ancestor-matrix legality test below.
Daniel@0 165 legal = 0;
Daniel@0 166 while ~legal
Daniel@0 167 d = sample_discrete(normalise(ones(1, length(digraph_nbrs)))); % NOTE(review): length() of a 3-D array is its largest dimension - confirm this always equals the number of neighbours
Daniel@0 168 i = nodes(d, 1); j = nodes(d, 2);
Daniel@0 169 switch ops{d}
Daniel@0 170 case 'add',
Daniel@0 171 if A(i,j)==0 % adding i->j is legal only if j is not already an ancestor of i
Daniel@0 172 legal = 1;
Daniel@0 173 end
Daniel@0 174 case 'del',
Daniel@0 175 legal = 1; % deletions are always treated as legal
Daniel@0 176 case 'rev',
Daniel@0 177 ps = mysetdiff(parents(dag, j), i); % parents of j other than i
Daniel@0 178 % if any(A(ps,i)) then there is a path i -> parent of j -> j
Daniel@0 179 % so reversing i->j would create a cycle
Daniel@0 180 legal = ~any(A(ps, i));
Daniel@0 181 end
Daniel@0 182 end
Daniel@0 183 %new_dag = digraph_nbrs{d};
Daniel@0 184 new_dag = digraph_nbrs(:,:,d); % neighbours are indexed along the third dimension here
Daniel@0 185 op = ops{d};
Daniel@0 186 i = nodes(d, 1); j = nodes(d, 2); % same values as already assigned inside the loop
Daniel@0 187
Daniel@0 188
Daniel@0 189 %%%%%%%%%%%%%%
Daniel@0 190
Daniel@0 191
Daniel@0 192 function A = update_ancestor_matrix(A, op, i, j, dag)
Daniel@0 193 % UPDATE_ANCESTOR_MATRIX Bring A up to date after move op on nodes (i,j); take_step passes the already-modified graph as dag.
Daniel@0 194 switch op
Daniel@0 195 case 'add',
Daniel@0 196 A = do_addition(A, op, i, j, dag);
Daniel@0 197 case 'del',
Daniel@0 198 A = do_removal(A, op, i, j, dag);
Daniel@0 199 case 'rev', % handled as removal of i->j followed by addition of j->i
Daniel@0 200 A = do_removal(A, op, i, j, dag);
Daniel@0 201 A = do_addition(A, op, j, i, dag);
Daniel@0 202 end
Daniel@0 203
Daniel@0 204
Daniel@0 205 %%%%%%%%%%%%
Daniel@0 206
Daniel@0 207 function A = do_addition(A, op, i, j, dag)
Daniel@0 208 % DO_ADDITION Update A for a new edge i -> j, where A(r,c) = 1 means c is an ancestor of r. op and dag are not used.
Daniel@0 209 A(j,i) = 1; % i is an ancestor of j
Daniel@0 210 anci = find(A(i,:));
Daniel@0 211 if ~isempty(anci)
Daniel@0 212 A(j,anci) = 1; % all of i's ancestors are added to Anc(j)
Daniel@0 213 end
Daniel@0 214 ancj = find(A(j,:)); % j's ancestors, now including i and i's ancestors
Daniel@0 215 descj = find(A(:,j)); % rows that list j as an ancestor, i.e. j's descendants
Daniel@0 216 if ~isempty(ancj)
Daniel@0 217 for k=descj(:)'
Daniel@0 218 A(k,ancj) = 1; % all of j's ancestors are added to each descendant of j
Daniel@0 219 end
Daniel@0 220 end
Daniel@0 221
Daniel@0 222 %%%%%%%%%%%
Daniel@0 223
Daniel@0 224 function A = do_removal(A, op, i, j, dag)
Daniel@0 225 % DO_REMOVAL Recompute the ancestor rows of j and of every descendant of j from dag. op and i are not used.
Daniel@0 226 % find all the descendants of j, and put them in topological order
Daniel@0 227 %descj = find(A(:,j));
Daniel@0 228 R = reachability_graph(dag);
Daniel@0 229 descj = find(R(j,:));
Daniel@0 230 order = topological_sort(dag); % order(t) is the t'th node, as init_ancestor_matrix relies on
Daniel@0 231 [junk, topnum] = sort(order); % invert the permutation: topnum(k) is the position of node k in order
Daniel@0 232 [junk, perm] = sort(topnum(descj)); % previously order(descj), which looked up nodes instead of positions
Daniel@0 233 descj = descj(perm);
Daniel@0 234 % Update j and all its descendants
Daniel@0 235 A = update_row(A, j, dag);
Daniel@0 236 for k = descj(:)' % parents must be refreshed before children because update_row reads the parents' rows
Daniel@0 237 A = update_row(A, k, dag);
Daniel@0 238 end
Daniel@0 239
Daniel@0 240 %%%%%%%%%
Daniel@0 241
Daniel@0 242 function A = update_row(A, j, dag)
Daniel@0 243 % UPDATE_ROW Rebuild row j of A from dag: the parents of j plus every entry set in the parents' rows of A.
Daniel@0 244 % We compute row j of A
Daniel@0 245 A(j, :) = 0;
Daniel@0 246 ps = parents(dag, j);
Daniel@0 247 if ~isempty(ps)
Daniel@0 248 A(j, ps) = 1;
Daniel@0 249 end
Daniel@0 250 for k=ps(:)' % the rows of the parents are read as they currently stand, so they must already be up to date
Daniel@0 251 anck = find(A(k,:));
Daniel@0 252 if ~isempty(anck)
Daniel@0 253 A(j, anck) = 1;
Daniel@0 254 end
Daniel@0 255 end
Daniel@0 256
Daniel@0 257 %%%%%%%%
Daniel@0 258
Daniel@0 259 function A = init_ancestor_matrix(dag)
Daniel@0 260 % INIT_ANCESTOR_MATRIX Build A from an all-zero matrix by calling update_row on each node in the order given by topological_sort.
Daniel@0 261 order = topological_sort(dag);
Daniel@0 262 A = zeros(length(dag));
Daniel@0 263 for j=order(:)' % order is iterated as a list of node indices
Daniel@0 264 A = update_row(A, j, dag);
Daniel@0 265 end