toolboxes/FullBNT-1.0.7/bnt/learning/kpm_learn_struct_mcmc.m @ 0:e9a9cd732c1e (tip)

changeset: first hg version after svn
author:    wolffd
date:      Tue, 10 Feb 2015 15:05:51 +0000
function [sampled_graphs, accept_ratio, num_edges] = learn_struct_mcmc(data, ns, varargin)
% LEARN_STRUCT_MCMC Monte Carlo Markov Chain search over DAGs assuming fully observed data
% [sampled_graphs, accept_ratio, num_edges] = learn_struct_mcmc(data, ns, ...)
%
% data(i,m) is the value of node i in case m.
% ns(i) is the number of discrete values node i can take on.
%
% sampled_graphs{m} is the m'th sampled graph.
% accept_ratio(t) = acceptance ratio at iteration t
% num_edges(t) = number of edges in model at iteration t
%
% The following optional arguments can be specified in the form of name/value pairs:
% [default value in brackets]
%
% scoring_fn - 'bayesian' or 'bic' [ 'bayesian' ]
%              Currently, only networks with all tabular nodes support Bayesian scoring.
% type       - type{i} is the type of CPD to use for node i, where the type is a string
%              of the form 'tabular', 'noisy_or', 'gaussian', etc. [ all cells contain 'tabular' ]
% params     - params{i} contains optional arguments passed to the CPD constructor for node i,
%              or [] if none. [ all cells contain {'prior', 1}, meaning use uniform Dirichlet priors ]
% discrete   - the list of discrete nodes [ 1:N ]
% clamped    - clamped(i,m) = 1 if node i is clamped in case m [ zeros(N, ncases) ]
% nsamples   - number of samples to draw from the chain after burn-in [ 100*N ]
% burnin     - number of steps to take before drawing samples [ 5*N ]
% init_dag   - starting point for the search [ zeros(N,N) ]
%
% e.g., samples = learn_struct_mcmc(data, ns, 'nsamples', 1000);
%
% This interface is not backwards compatible with BNT2,
% but is designed to be compatible with the other learn_struct_xxx routines.
%
% Note: We currently assume a uniform structural prior.

[n ncases] = size(data);

% set default params
type = cell(1,n);
params = cell(1,n);
for i=1:n
  type{i} = 'tabular';
  %params{i} = { 'prior', 1 };
  params{i} = { 'prior_type', 'dirichlet', 'dirichlet_weight', 1 };
end
scoring_fn = 'bayesian';
discrete = 1:n;
clamped = zeros(n, ncases);
nsamples = 100*n;
burnin = 5*n;
dag = zeros(n);

args = varargin;
nargs = length(args);
for i=1:2:nargs
  switch args{i},
   case 'nsamples',   nsamples = args{i+1};
   case 'burnin',     burnin = args{i+1};
   case 'init_dag',   dag = args{i+1};
   case 'scoring_fn', scoring_fn = args{i+1};
   case 'type',       type = args{i+1};
   case 'discrete',   discrete = args{i+1};
   case 'clamped',    clamped = args{i+1};
   case 'params',     if isempty(args{i+1}), params = cell(1,n); else params = args{i+1}; end
  end
end

% We implement the fast acyclicity check described by P. Giudici and R. Castelo,
% "Improving MCMC model search for data mining", submitted to J. Machine Learning, 2001.
use_giudici = 1;
if use_giudici
  [nbrs, ops, nodes] = mk_nbrs_of_digraph(dag);
  A = init_ancestor_matrix(dag);
else
  [nbrs, ops, nodes] = mk_nbrs_of_dag(dag);
  A = [];
end

num_accepts = 1;
num_rejects = 1;
T = burnin + nsamples;
accept_ratio = zeros(1, T);
num_edges = zeros(1, T);
sampled_graphs = cell(1, nsamples);
%sampled_bitv = zeros(nsamples, n^2);

for t=1:T
  [dag, nbrs, ops, nodes, A, accept] = take_step(dag, nbrs, ops, nodes, ns, data, clamped, A, ...
      scoring_fn, discrete, type, params);
  num_edges(t) = sum(dag(:));
  num_accepts = num_accepts + accept;
  num_rejects = num_rejects + (1-accept);
  accept_ratio(t) = num_accepts/num_rejects;
  if t > burnin
    sampled_graphs{t-burnin} = dag;
    %sampled_bitv(t-burnin, :) = dag(:)';
  end
end

%%%%%%%%%
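% One MCMC step, implemented by take_step below: a neighbouring graph is
% proposed by a single edge addition, deletion or reversal, its Bayes factor
% against the current graph is computed, and the proposal ratio
% length(nbrs)/length(new_nbrs) corrects for the differing neighbourhood
% sizes before the usual Metropolis-Hastings accept/reject test.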
function [new_dag, new_nbrs, new_ops, new_nodes, A, accept] = ...
    take_step(dag, nbrs, ops, nodes, ns, data, clamped, A, ...
              scoring_fn, discrete, type, params)

use_giudici = ~isempty(A);
if use_giudici
  [new_dag, op, i, j] = pick_digraph_nbr(dag, nbrs, ops, nodes, A);
  %assert(acyclic(new_dag));
  [new_nbrs, new_ops, new_nodes] = mk_nbrs_of_digraph(new_dag);
else
  d = sample_discrete(normalise(ones(1, length(nbrs))));
  new_dag = nbrs{d};
  op = ops{d};
  i = nodes(d, 1); j = nodes(d, 2);
  [new_nbrs, new_ops, new_nodes] = mk_nbrs_of_dag(new_dag);
end

bf = bayes_factor(dag, new_dag, op, i, j, ns, data, clamped, scoring_fn, discrete, type, params);

%R = bf * (new_prior / prior) * (length(nbrs) / length(new_nbrs));
R = bf * (length(nbrs) / length(new_nbrs));
u = rand(1,1);
if u > min(1,R) % reject the move
  accept = 0;
  new_dag = dag;
  new_nbrs = nbrs;
  new_ops = ops;
  new_nodes = nodes;
else
  accept = 1;
  if use_giudici
    A = update_ancestor_matrix(A, op, i, j, new_dag);
  end
end

%%%%%%%%%

function bfactor = bayes_factor(old_dag, new_dag, op, i, j, ns, data, clamped, scoring_fn, discrete, type, params)

u = find(clamped(j,:)==0);
LLnew = score_family(j, parents(new_dag, j), type{j}, scoring_fn, ns, discrete, data(:,u), params{j});
LLold = score_family(j, parents(old_dag, j), type{j}, scoring_fn, ns, discrete, data(:,u), params{j});
bf1 = exp(LLnew - LLold);

if strcmp(op, 'rev')  % must also multiply in the changes to i's family
  u = find(clamped(i,:)==0);
  LLnew = score_family(i, parents(new_dag, i), type{i}, scoring_fn, ns, discrete, data(:,u), params{i});
  LLold = score_family(i, parents(old_dag, i), type{i}, scoring_fn, ns, discrete, data(:,u), params{i});
  bf2 = exp(LLnew - LLold);
else
  bf2 = 1;
end
bfactor = bf1 * bf2;

%%%%%%%% Giudici stuff follows %%%%%%%%%%
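% The ancestor matrix A caches ancestor relations in the current DAG:
% A(j,i) = 1 iff node i is an ancestor of node j. Acyclicity of a proposed
% edge change can then be checked by table lookup (adding i->j is legal iff
% j is not an ancestor of i) rather than by traversing the whole graph.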
function [new_dag, op, i, j] = pick_digraph_nbr(dag, digraph_nbrs, ops, nodes, A)

legal = 0;
while ~legal
  d = sample_discrete(normalise(ones(1, length(digraph_nbrs))));
  i = nodes(d, 1); j = nodes(d, 2);
  switch ops{d}
   case 'add',
    if A(i,j)==0
      legal = 1;
    end
   case 'del',
    legal = 1;
   case 'rev',
    ps = mysetdiff(parents(dag, j), i);
    % if any(A(ps,i)) then there is a path i -> parent of j -> j
    % so reversing i->j would create a cycle
    legal = ~any(A(ps, i));
  end
end
%new_dag = digraph_nbrs{d};
new_dag = digraph_nbrs(:,:,d);
op = ops{d};
i = nodes(d, 1); j = nodes(d, 2);

%%%%%%%%%%%%%%

function A = update_ancestor_matrix(A, op, i, j, dag)

switch op
 case 'add',
  A = do_addition(A, op, i, j, dag);
 case 'del',
  A = do_removal(A, op, i, j, dag);
 case 'rev',
  A = do_removal(A, op, i, j, dag);
  A = do_addition(A, op, j, i, dag);
end

%%%%%%%%%%%%

function A = do_addition(A, op, i, j, dag)

A(j,i) = 1; % i is an ancestor of j
anci = find(A(i,:));
if ~isempty(anci)
  A(j,anci) = 1; % all of i's ancestors are added to Anc(j)
end
ancj = find(A(j,:));
descj = find(A(:,j));
if ~isempty(ancj)
  for k=descj(:)'
    A(k,ancj) = 1; % all of j's ancestors are added to each descendant of j
  end
end

%%%%%%%%%%%

function A = do_removal(A, op, i, j, dag)

% find all the descendants of j, and put them in topological order
%descj = find(A(:,j));
R = reachability_graph(dag);
descj = find(R(j,:));
order = topological_sort(dag);
descj_topnum = order(descj);
[junk, perm] = sort(descj_topnum);
descj = descj(perm);
% Update j and all its descendants
A = update_row(A, j, dag);
for k = descj(:)'
  A = update_row(A, k, dag);
end

%%%%%%%%%

function A = update_row(A, j, dag)

% We compute row j of A
A(j, :) = 0;
ps = parents(dag, j);
if ~isempty(ps)
  A(j, ps) = 1;
end
for k=ps(:)'
  anck = find(A(k,:));
  if ~isempty(anck)
    A(j, anck) = 1;
  end
end

%%%%%%%%

function A = init_ancestor_matrix(dag)

order = topological_sort(dag);
A = zeros(length(dag));
for j=order(:)'
  A = update_row(A, j, dag);
end
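As a usage illustration, here is a minimal sketch, not part of the file above: the data is random just to make the snippet self-contained, and N, ncases, nsamples and burnin are placeholder values. Since MATLAB dispatches on file names, the call may need to be kpm_learn_struct_mcmc if only this copy of the function is on the path. The sampled DAGs are averaged into posterior edge probabilities, and accept_ratio and num_edges give a rough view of how the chain behaved.

% Sketch: run the sampler on placeholder data and summarise the output.
N = 4;                          % number of nodes (assumed)
ncases = 500;                   % number of fully observed cases (assumed)
ns = 2*ones(1, N);              % each node is binary
data = randi(2, N, ncases);     % random stand-in for real data, values in 1..ns(i)

[sampled_graphs, accept_ratio, num_edges] = ...
    learn_struct_mcmc(data, ns, 'nsamples', 200, 'burnin', 50);

% Posterior probability of each directed edge, estimated by averaging the
% adjacency matrices of the sampled DAGs.
nsamp = length(sampled_graphs);
edge_prob = zeros(N, N);
for m = 1:nsamp
  edge_prob = edge_prob + sampled_graphs{m};
end
edge_prob = edge_prob / nsamp;

% Rough diagnostics over the whole run (burn-in + sampling phase).
figure;
subplot(2,1,1); plot(accept_ratio); title('running accept/reject ratio');
subplot(2,1,2); plot(num_edges);    title('number of edges in current DAG');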