diff toolboxes/FullBNT-1.0.7/bnt/learning/kpm_learn_struct_mcmc.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/toolboxes/FullBNT-1.0.7/bnt/learning/kpm_learn_struct_mcmc.m	Tue Feb 10 15:05:51 2015 +0000
@@ -0,0 +1,265 @@
+function [sampled_graphs, accept_ratio, num_edges] = learn_struct_mcmc(data, ns, varargin)
+% LEARN_STRUCT_MCMC  Monte Carla Markov Chain search over DAGs assuming fully observed data
+% [sampled_graphs, accept_ratio, num_edges] = learn_struct_mcmc(data, ns, ...)
+% 
+% data(i,m) is the value of node i in case m.
+% ns(i) is the number of discrete values node i can take on.
+%
+% sampled_graphs{m} is the m'th sampled graph.
+% accept_ratio(t) = acceptance ratio at iteration t
+% num_edges(t) = number of edges in model at iteration t
+%
+% The following optional arguments can be specified in the form of name/value pairs:
+% [default value in brackets]
+%
+% scoring_fn - 'bayesian' or 'bic' [ 'bayesian' ]
+%              Currently, only networks with all tabular nodes support Bayesian scoring.
+% type       - type{i} is the type of CPD to use for node i, where the type is a string
+%              of the form 'tabular', 'noisy_or', 'gaussian', etc. [ all cells contain 'tabular' ]
+% params     - params{i} contains optional arguments passed to the CPD constructor for node i,
+%              or [] if none.  [ all cells contain {'prior', 1}, meaning use uniform Dirichlet priors ]
+% discrete   - the list of discrete nodes [ 1:N ]
+% clamped    - clamped(i,m) = 1 if node i is clamped in case m [ zeros(N, ncases) ]
+% nsamples   - number of samples to draw from the chain after burn-in [ 100*N ]
+% burnin     - number of steps to take before drawing samples [ 5*N ]
+% init_dag   - starting point for the search [ zeros(N,N) ]
+%
+% e.g., samples = learn_struct_mcmc(data, ns, 'nsamples', 1000);
+%
+% This interface is not backwards compatible with BNT2,
+% but is designed to be compatible with the other learn_struct_xxx routines.
+%
+% Note: We currently assume a uniform structural prior.
+
+[n ncases] = size(data);
+
+
+% set default params
+type = cell(1,n);
+params = cell(1,n);
+for i=1:n
+  type{i} = 'tabular';
+  %params{i} = { 'prior', 1 };
+  params{i} = { 'prior_type', 'dirichlet', 'dirichlet_weight', 1 };
+end
+scoring_fn = 'bayesian';
+discrete = 1:n;
+clamped = zeros(n, ncases);
+nsamples = 100*n;
+burnin = 5*n;
+dag = zeros(n);
+
+args = varargin;
+nargs = length(args);
+for i=1:2:nargs
+  switch args{i},
+   case 'nsamples',   nsamples = args{i+1};
+   case 'burnin',     burnin = args{i+1};
+   case 'init_dag',   dag = args{i+1};
+   case 'scoring_fn', scoring_fn = args{i+1};
+   case 'type',       type = args{i+1}; 
+   case 'discrete',   discrete = args{i+1}; 
+   case 'clamped',    clamped = args{i+1}; 
+   case 'params',     if isempty(args{i+1}), params = cell(1,n); else params = args{i+1};  end
+  end
+end
+
+% We implement the fast acyclicity check described by P. Giudici and R. Castelo,
+% "Improving MCMC model search for data mining", submitted to J. Machine Learning, 2001.
+use_giudici = 1;
+if use_giudici
+  [nbrs, ops, nodes] = mk_nbrs_of_digraph(dag);
+  A = init_ancestor_matrix(dag);
+else
+  [nbrs, ops, nodes] = mk_nbrs_of_dag(dag);
+  A = [];
+end
+
+num_accepts = 1;
+num_rejects = 1;
+T = burnin + nsamples;
+accept_ratio = zeros(1, T);
+num_edges = zeros(1, T);
+sampled_graphs = cell(1, nsamples);
+%sampled_bitv = zeros(nsamples, n^2);
+
+for t=1:T
+  [dag, nbrs, ops, nodes, A, accept] = take_step(dag, nbrs, ops, nodes, ns, data, clamped, A, ...
+						 scoring_fn, discrete, type, params);
+  num_edges(t) = sum(dag(:));
+  num_accepts = num_accepts + accept;
+  num_rejects = num_rejects + (1-accept);
+  accept_ratio(t) =  num_accepts/num_rejects;
+  if t > burnin
+    sampled_graphs{t-burnin} = dag;
+    %sampled_bitv(t-burnin, :) = dag(:)';
+  end
+end
+
+
+%%%%%%%%%
+
+
+function [new_dag, new_nbrs, new_ops, new_nodes, A, accept] = ...
+    take_step(dag, nbrs, ops, nodes, ns, data, clamped, A, ...
+	      scoring_fn, discrete, type, params)
+
+
+use_giudici = ~isempty(A);
+if use_giudici
+  [new_dag, op, i, j] = pick_digraph_nbr(dag, nbrs, ops, nodes, A);
+  %assert(acyclic(new_dag));
+  [new_nbrs, new_ops, new_nodes] = mk_nbrs_of_digraph(new_dag);
+else
+  d = sample_discrete(normalise(ones(1, length(nbrs))));
+  new_dag = nbrs{d};
+  op = ops{d};
+  i = nodes(d, 1); j = nodes(d, 2);
+  [new_nbrs, new_ops, new_nodes] = mk_nbrs_of_dag(new_dag);
+end
+
+bf =  bayes_factor(dag, new_dag, op, i, j, ns, data, clamped, scoring_fn, discrete, type, params);
+
+%R = bf * (new_prior / prior) * (length(nbrs) / length(new_nbrs)); 
+R = bf * (length(nbrs) / length(new_nbrs)); 
+u = rand(1,1);
+if u > min(1,R) % reject the move
+  accept = 0;
+  new_dag = dag;
+  new_nbrs = nbrs;
+  new_ops = ops;
+  new_nodes = nodes;
+else
+  accept = 1;
+  if use_giudici
+    A = update_ancestor_matrix(A, op, i, j, new_dag);
+  end
+end
+
+
+%%%%%%%%%
+
+function bfactor = bayes_factor(old_dag, new_dag, op, i, j, ns, data, clamped, scoring_fn, discrete, type, params)
+
+u = find(clamped(j,:)==0);
+LLnew = score_family(j, parents(new_dag, j), type{j}, scoring_fn, ns, discrete, data(:,u), params{j});
+LLold = score_family(j, parents(old_dag, j), type{j}, scoring_fn, ns, discrete, data(:,u), params{j});
+bf1 = exp(LLnew - LLold);
+
+if strcmp(op, 'rev')  % must also multiply in the changes to i's family
+  u = find(clamped(i,:)==0);
+  LLnew = score_family(i, parents(new_dag, i), type{i}, scoring_fn, ns, discrete, data(:,u), params{i});
+  LLold = score_family(i, parents(old_dag, i), type{i}, scoring_fn, ns, discrete, data(:,u), params{i});
+  bf2 = exp(LLnew - LLold);
+else
+  bf2 = 1;
+end
+bfactor = bf1 * bf2;
+
+
+%%%%%%%% Giudici stuff follows %%%%%%%%%%
+
+
+function [new_dag, op, i, j] = pick_digraph_nbr(dag, digraph_nbrs, ops, nodes, A)
+
+legal = 0;
+while ~legal
+  d = sample_discrete(normalise(ones(1, length(digraph_nbrs))));
+  i = nodes(d, 1); j = nodes(d, 2);
+  switch ops{d}
+   case 'add',
+    if A(i,j)==0
+      legal = 1;
+    end
+   case 'del',
+    legal = 1;
+   case 'rev',
+    ps = mysetdiff(parents(dag, j), i);
+    % if any(A(ps,i)) then there is a path i -> parent of j -> j
+    % so reversing i->j would create a cycle
+    legal = ~any(A(ps, i));
+  end
+end
+%new_dag = digraph_nbrs{d};
+new_dag = digraph_nbrs(:,:,d);
+op = ops{d};
+i = nodes(d, 1); j = nodes(d, 2);
+
+
+%%%%%%%%%%%%%%
+
+
+function A = update_ancestor_matrix(A, op, i, j, dag)
+
+switch op
+ case 'add',
+  A = do_addition(A, op, i, j, dag);
+ case 'del', 
+  A = do_removal(A, op, i, j, dag);
+ case 'rev', 
+  A = do_removal(A, op, i, j, dag);
+  A = do_addition(A, op, j, i, dag);
+end
+
+  
+%%%%%%%%%%%%
+
+function A = do_addition(A, op, i, j, dag)
+
+A(j,i) = 1; % i is an ancestor of j
+anci = find(A(i,:));
+if ~isempty(anci)
+  A(j,anci) = 1; % all of i's ancestors are added to Anc(j)
+end
+ancj = find(A(j,:));
+descj = find(A(:,j)); 
+if ~isempty(ancj)
+  for k=descj(:)'
+    A(k,ancj) = 1; % all of j's ancestors are added to each descendant of j
+  end
+end
+
+%%%%%%%%%%%
+
+function A = do_removal(A, op, i, j, dag)
+
+% find all the descendants of j, and put them in topological order
+%descj = find(A(:,j)); 
+R = reachability_graph(dag);
+descj = find(R(j,:)); 
+order = topological_sort(dag);
+descj_topnum = order(descj);
+[junk, perm] = sort(descj_topnum);
+descj = descj(perm);
+% Update j and all its descendants
+A = update_row(A, j, dag);
+for k = descj(:)'
+  A = update_row(A, k, dag);
+end
+
+%%%%%%%%%
+
+function A = update_row(A, j, dag)
+
+% We compute row j of A
+A(j, :) = 0;
+ps = parents(dag, j);
+if ~isempty(ps)
+  A(j, ps) = 1;
+end
+for k=ps(:)'
+  anck = find(A(k,:));
+  if ~isempty(anck)
+    A(j, anck) = 1;
+  end
+end
+  
+%%%%%%%%
+
+function A = init_ancestor_matrix(dag)
+
+order = topological_sort(dag);
+A = zeros(length(dag));
+for j=order(:)'
+  A = update_row(A, j, dag);
+end