Mercurial > hg > camir-aes2014
diff toolboxes/FullBNT-1.0.7/bnt/learning/kpm_learn_struct_mcmc.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/toolboxes/FullBNT-1.0.7/bnt/learning/kpm_learn_struct_mcmc.m Tue Feb 10 15:05:51 2015 +0000 @@ -0,0 +1,265 @@ +function [sampled_graphs, accept_ratio, num_edges] = learn_struct_mcmc(data, ns, varargin) +% LEARN_STRUCT_MCMC Monte Carla Markov Chain search over DAGs assuming fully observed data +% [sampled_graphs, accept_ratio, num_edges] = learn_struct_mcmc(data, ns, ...) +% +% data(i,m) is the value of node i in case m. +% ns(i) is the number of discrete values node i can take on. +% +% sampled_graphs{m} is the m'th sampled graph. +% accept_ratio(t) = acceptance ratio at iteration t +% num_edges(t) = number of edges in model at iteration t +% +% The following optional arguments can be specified in the form of name/value pairs: +% [default value in brackets] +% +% scoring_fn - 'bayesian' or 'bic' [ 'bayesian' ] +% Currently, only networks with all tabular nodes support Bayesian scoring. +% type - type{i} is the type of CPD to use for node i, where the type is a string +% of the form 'tabular', 'noisy_or', 'gaussian', etc. [ all cells contain 'tabular' ] +% params - params{i} contains optional arguments passed to the CPD constructor for node i, +% or [] if none. [ all cells contain {'prior', 1}, meaning use uniform Dirichlet priors ] +% discrete - the list of discrete nodes [ 1:N ] +% clamped - clamped(i,m) = 1 if node i is clamped in case m [ zeros(N, ncases) ] +% nsamples - number of samples to draw from the chain after burn-in [ 100*N ] +% burnin - number of steps to take before drawing samples [ 5*N ] +% init_dag - starting point for the search [ zeros(N,N) ] +% +% e.g., samples = learn_struct_mcmc(data, ns, 'nsamples', 1000); +% +% This interface is not backwards compatible with BNT2, +% but is designed to be compatible with the other learn_struct_xxx routines. +% +% Note: We currently assume a uniform structural prior. + +[n ncases] = size(data); + + +% set default params +type = cell(1,n); +params = cell(1,n); +for i=1:n + type{i} = 'tabular'; + %params{i} = { 'prior', 1 }; + params{i} = { 'prior_type', 'dirichlet', 'dirichlet_weight', 1 }; +end +scoring_fn = 'bayesian'; +discrete = 1:n; +clamped = zeros(n, ncases); +nsamples = 100*n; +burnin = 5*n; +dag = zeros(n); + +args = varargin; +nargs = length(args); +for i=1:2:nargs + switch args{i}, + case 'nsamples', nsamples = args{i+1}; + case 'burnin', burnin = args{i+1}; + case 'init_dag', dag = args{i+1}; + case 'scoring_fn', scoring_fn = args{i+1}; + case 'type', type = args{i+1}; + case 'discrete', discrete = args{i+1}; + case 'clamped', clamped = args{i+1}; + case 'params', if isempty(args{i+1}), params = cell(1,n); else params = args{i+1}; end + end +end + +% We implement the fast acyclicity check described by P. Giudici and R. Castelo, +% "Improving MCMC model search for data mining", submitted to J. Machine Learning, 2001. +use_giudici = 1; +if use_giudici + [nbrs, ops, nodes] = mk_nbrs_of_digraph(dag); + A = init_ancestor_matrix(dag); +else + [nbrs, ops, nodes] = mk_nbrs_of_dag(dag); + A = []; +end + +num_accepts = 1; +num_rejects = 1; +T = burnin + nsamples; +accept_ratio = zeros(1, T); +num_edges = zeros(1, T); +sampled_graphs = cell(1, nsamples); +%sampled_bitv = zeros(nsamples, n^2); + +for t=1:T + [dag, nbrs, ops, nodes, A, accept] = take_step(dag, nbrs, ops, nodes, ns, data, clamped, A, ... + scoring_fn, discrete, type, params); + num_edges(t) = sum(dag(:)); + num_accepts = num_accepts + accept; + num_rejects = num_rejects + (1-accept); + accept_ratio(t) = num_accepts/num_rejects; + if t > burnin + sampled_graphs{t-burnin} = dag; + %sampled_bitv(t-burnin, :) = dag(:)'; + end +end + + +%%%%%%%%% + + +function [new_dag, new_nbrs, new_ops, new_nodes, A, accept] = ... + take_step(dag, nbrs, ops, nodes, ns, data, clamped, A, ... + scoring_fn, discrete, type, params) + + +use_giudici = ~isempty(A); +if use_giudici + [new_dag, op, i, j] = pick_digraph_nbr(dag, nbrs, ops, nodes, A); + %assert(acyclic(new_dag)); + [new_nbrs, new_ops, new_nodes] = mk_nbrs_of_digraph(new_dag); +else + d = sample_discrete(normalise(ones(1, length(nbrs)))); + new_dag = nbrs{d}; + op = ops{d}; + i = nodes(d, 1); j = nodes(d, 2); + [new_nbrs, new_ops, new_nodes] = mk_nbrs_of_dag(new_dag); +end + +bf = bayes_factor(dag, new_dag, op, i, j, ns, data, clamped, scoring_fn, discrete, type, params); + +%R = bf * (new_prior / prior) * (length(nbrs) / length(new_nbrs)); +R = bf * (length(nbrs) / length(new_nbrs)); +u = rand(1,1); +if u > min(1,R) % reject the move + accept = 0; + new_dag = dag; + new_nbrs = nbrs; + new_ops = ops; + new_nodes = nodes; +else + accept = 1; + if use_giudici + A = update_ancestor_matrix(A, op, i, j, new_dag); + end +end + + +%%%%%%%%% + +function bfactor = bayes_factor(old_dag, new_dag, op, i, j, ns, data, clamped, scoring_fn, discrete, type, params) + +u = find(clamped(j,:)==0); +LLnew = score_family(j, parents(new_dag, j), type{j}, scoring_fn, ns, discrete, data(:,u), params{j}); +LLold = score_family(j, parents(old_dag, j), type{j}, scoring_fn, ns, discrete, data(:,u), params{j}); +bf1 = exp(LLnew - LLold); + +if strcmp(op, 'rev') % must also multiply in the changes to i's family + u = find(clamped(i,:)==0); + LLnew = score_family(i, parents(new_dag, i), type{i}, scoring_fn, ns, discrete, data(:,u), params{i}); + LLold = score_family(i, parents(old_dag, i), type{i}, scoring_fn, ns, discrete, data(:,u), params{i}); + bf2 = exp(LLnew - LLold); +else + bf2 = 1; +end +bfactor = bf1 * bf2; + + +%%%%%%%% Giudici stuff follows %%%%%%%%%% + + +function [new_dag, op, i, j] = pick_digraph_nbr(dag, digraph_nbrs, ops, nodes, A) + +legal = 0; +while ~legal + d = sample_discrete(normalise(ones(1, length(digraph_nbrs)))); + i = nodes(d, 1); j = nodes(d, 2); + switch ops{d} + case 'add', + if A(i,j)==0 + legal = 1; + end + case 'del', + legal = 1; + case 'rev', + ps = mysetdiff(parents(dag, j), i); + % if any(A(ps,i)) then there is a path i -> parent of j -> j + % so reversing i->j would create a cycle + legal = ~any(A(ps, i)); + end +end +%new_dag = digraph_nbrs{d}; +new_dag = digraph_nbrs(:,:,d); +op = ops{d}; +i = nodes(d, 1); j = nodes(d, 2); + + +%%%%%%%%%%%%%% + + +function A = update_ancestor_matrix(A, op, i, j, dag) + +switch op + case 'add', + A = do_addition(A, op, i, j, dag); + case 'del', + A = do_removal(A, op, i, j, dag); + case 'rev', + A = do_removal(A, op, i, j, dag); + A = do_addition(A, op, j, i, dag); +end + + +%%%%%%%%%%%% + +function A = do_addition(A, op, i, j, dag) + +A(j,i) = 1; % i is an ancestor of j +anci = find(A(i,:)); +if ~isempty(anci) + A(j,anci) = 1; % all of i's ancestors are added to Anc(j) +end +ancj = find(A(j,:)); +descj = find(A(:,j)); +if ~isempty(ancj) + for k=descj(:)' + A(k,ancj) = 1; % all of j's ancestors are added to each descendant of j + end +end + +%%%%%%%%%%% + +function A = do_removal(A, op, i, j, dag) + +% find all the descendants of j, and put them in topological order +%descj = find(A(:,j)); +R = reachability_graph(dag); +descj = find(R(j,:)); +order = topological_sort(dag); +descj_topnum = order(descj); +[junk, perm] = sort(descj_topnum); +descj = descj(perm); +% Update j and all its descendants +A = update_row(A, j, dag); +for k = descj(:)' + A = update_row(A, k, dag); +end + +%%%%%%%%% + +function A = update_row(A, j, dag) + +% We compute row j of A +A(j, :) = 0; +ps = parents(dag, j); +if ~isempty(ps) + A(j, ps) = 1; +end +for k=ps(:)' + anck = find(A(k,:)); + if ~isempty(anck) + A(j, anck) = 1; + end +end + +%%%%%%%% + +function A = init_ancestor_matrix(dag) + +order = topological_sort(dag); +A = zeros(length(dag)); +for j=order(:)' + A = update_row(A, j, dag); +end