function [sampled_graphs, accept_ratio, num_edges] = learn_struct_mcmc(data, ns, varargin)
% LEARN_STRUCT_MCMC Markov Chain Monte Carlo search over DAGs assuming fully observed data
% [sampled_graphs, accept_ratio, num_edges] = learn_struct_mcmc(data, ns, ...)
%
% data(i,m) is the value of node i in case m.
% ns(i) is the number of discrete values node i can take on.
%
% sampled_graphs{m} is the m'th sampled graph.
% accept_ratio(t) = ratio of accepted to rejected moves up to iteration t (both counts start at 1)
% num_edges(t) = number of edges in model at iteration t
%
% The following optional arguments can be specified in the form of name/value pairs:
% [default value in brackets]
%
% scoring_fn - 'bayesian' or 'bic' [ 'bayesian' ]
%              Currently, only networks with all tabular nodes support Bayesian scoring.
% type       - type{i} is the type of CPD to use for node i, where the type is a string
%              of the form 'tabular', 'noisy_or', 'gaussian', etc. [ all cells contain 'tabular' ]
% params     - params{i} contains optional arguments passed to the CPD constructor for node i,
%              or [] if none.
%              [ all cells contain {'prior_type', 'dirichlet', 'dirichlet_weight', 1},
%                meaning use Dirichlet priors with weight 1 ]
% discrete   - the list of discrete nodes [ 1:N ]
% clamped    - clamped(i,m) = 1 if node i is clamped in case m [ zeros(N, ncases) ]
% nsamples   - number of samples to draw from the chain after burn-in [ 100*N ]
% burnin     - number of steps to take before drawing samples [ 5*N ]
% init_dag   - starting point for the search [ zeros(N,N) ]
%
% e.g., samples = learn_struct_mcmc(data, ns, 'nsamples', 1000);
%
% Modified by Sonia Leach (SML) 2/4/02, 9/5/03



[n, ncases] = size(data);


% set default params
type = cell(1,n);
params = cell(1,n);
for i=1:n
  type{i} = 'tabular';
  %params{i} = { 'prior', 1};
  params{i} = { 'prior_type', 'dirichlet', 'dirichlet_weight', 1 };
end
scoring_fn = 'bayesian';
discrete = 1:n;
clamped = zeros(n, ncases);
nsamples = 100*n;
burnin = 5*n;
dag = zeros(n);

% override the defaults with any name/value pairs supplied by the caller
args = varargin;
nargs = length(args);
for i=1:2:nargs
  switch args{i},
   case 'nsamples',   nsamples = args{i+1};
   case 'burnin',     burnin = args{i+1};
   case 'init_dag',   dag = args{i+1};
   case 'scoring_fn', scoring_fn = args{i+1};
   case 'type',       type = args{i+1};
   case 'discrete',   discrete = args{i+1};
   case 'clamped',    clamped = args{i+1};
   case 'params',     if isempty(args{i+1}), params = cell(1,n); else params = args{i+1}; end
  end
end

% We implement the fast acyclicity check described by P. Giudici and R.
% Castelo, "Improving MCMC model search for data mining", submitted to J. Machine Learning, 2001.

% SML: also keep the ancestor matrix A (A(j,i) = 1 if i is an ancestor of j)
use_giudici = 1;
if use_giudici
  [nbrs, ops, nodes, A] = mk_nbrs_of_digraph(dag);
else
  [nbrs, ops, nodes] = mk_nbrs_of_dag(dag);
  A = [];
end

% both counters start at 1, so the ratio below never divides by zero
num_accepts = 1;
num_rejects = 1;
T = burnin + nsamples;
accept_ratio = zeros(1, T);
num_edges = zeros(1, T);
sampled_graphs = cell(1, nsamples);
%sampled_bitv = zeros(nsamples, n^2);

for t=1:T
  [dag, nbrs, ops, nodes, A, accept] = take_step(dag, nbrs, ops, ...
                                                 nodes, ns, data, clamped, A, ...
                                                 scoring_fn, discrete, type, params);
  num_edges(t) = sum(dag(:));
  num_accepts = num_accepts + accept;
  num_rejects = num_rejects + (1-accept);
  accept_ratio(t) = num_accepts/num_rejects; % accepted moves relative to rejected moves
  if t > burnin
    sampled_graphs{t-burnin} = dag;
    %sampled_bitv(t-burnin, :) = dag(:)';
  end
end


%%%%%%%%%


function [new_dag, new_nbrs, new_ops, new_nodes, A, accept] = ...
    take_step(dag, nbrs, ops, nodes, ns, data, clamped, A, ...
              scoring_fn, discrete, type, params, prior_w)

% prior_w is accepted but never used in this function.

use_giudici = ~isempty(A);
if use_giudici
  [new_dag, op, i, j, new_A] = pick_digraph_nbr(dag, nbrs, ops, nodes, A); % updates A
  [new_nbrs, new_ops, new_nodes] = mk_nbrs_of_digraph(new_dag, new_A);
else
  d = sample_discrete(normalise(ones(1, length(nbrs))));
  new_dag = nbrs{d};
  op = ops{d};
  i = nodes(d, 1); j = nodes(d, 2);
  [new_nbrs, new_ops, new_nodes] = mk_nbrs_of_dag(new_dag);
end

bf = bayes_factor(dag, new_dag, op, i, j, ns, data, clamped, scoring_fn, discrete, type, params);

% Acceptance ratio: the Bayes factor times the correction for proposing
% uniformly among the neighbors of the current graph.
%R = bf * (new_prior / prior) * (length(nbrs) / length(new_nbrs));
R = bf * (length(nbrs) / length(new_nbrs));
u = rand(1,1);
if u > min(1,R) % reject the move
  accept = 0;
  new_dag = dag;
  new_nbrs = nbrs;
  new_ops = ops;
  new_nodes = nodes;
else
  accept = 1;
  if use_giudici
    A = new_A; % new_A already updated in pick_digraph_nbr
  end
end


%%%%%%%%%

function bfactor = bayes_factor(old_dag, new_dag, op, i, j, ns, data, clamped, scoring_fn, discrete, type, params)

% score j's family under both graphs, using only the cases where j is not clamped
u = find(clamped(j,:)==0);
LLnew = score_family(j, parents(new_dag, j), type{j}, scoring_fn, ns, discrete, data(:,u), params{j});
LLold = score_family(j, parents(old_dag, j), type{j}, scoring_fn, ns, discrete, data(:,u), params{j});
bf1 = exp(LLnew - LLold);

if strcmp(op, 'rev')  % must also multiply in the changes to i's family
  u = find(clamped(i,:)==0);
  LLnew = score_family(i, parents(new_dag, i), type{i}, ...
                       scoring_fn, ns, discrete, data(:,u), params{i});
  LLold = score_family(i, parents(old_dag, i), type{i}, scoring_fn, ns, discrete, data(:,u), params{i});
  bf2 = exp(LLnew - LLold);
else
  bf2 = 1;
end
bfactor = bf1 * bf2;


%%%%%%%% Giudici stuff follows %%%%%%%%%%


% SML: This now updates A as it goes, based on the digraph it chooses
function [new_dag, op, i, j, new_A] = pick_digraph_nbr(dag, digraph_nbrs, ops, nodes, A)

d = sample_discrete(normalise(ones(1, length(digraph_nbrs))));
%d = myunidrnd(length(digraph_nbrs),1,1);
i = nodes(d, 1); j = nodes(d, 2);
new_dag = digraph_nbrs(:,:,d);
op = ops{d};
new_A = update_ancestor_matrix(A, op, i, j, new_dag);


%%%%%%%%%%%%%%


function A = update_ancestor_matrix(A, op, i, j, dag)

switch op
case 'add',
  A = do_addition(A, op, i, j, dag);
case 'del',
  A = do_removal(A, op, i, j, dag);
case 'rev',
  A = do_removal(A, op, i, j, dag);
  A = do_addition(A, op, j, i, dag);
end


%%%%%%%%%%%%

function A = do_addition(A, op, i, j, dag)

A(j,i) = 1; % i is an ancestor of j
anci = find(A(i,:));
if ~isempty(anci)
  A(j,anci) = 1; % all of i's ancestors are added to Anc(j)
end
ancj = find(A(j,:));
descj = find(A(:,j));
if ~isempty(ancj)
  for k=descj(:)'
    A(k,ancj) = 1; % all of j's ancestors are added to each descendant of j
  end
end

%%%%%%%%%%%
function A = do_removal(A, op, i, j, dag)

% find all the descendants of j, and put them in topological order

% SML: originally Kevin had the next line commented and the %* lines
% being used but I think this is equivalent and much less expensive
% I assume he put it there for debugging and never changed it back...?
descj = find(A(:,j));
%*  R = reachability_graph(dag);
%*  descj = find(R(j,:));

order = topological_sort(dag);

% SML: originally Kevin used the %* line but this was extracting the
% wrong things to sort
%* descj_topnum = order(descj);
[junk, perm] = sort(order);  %SML: node i is perm(i)-th in order
descj_topnum = perm(descj);  %SML: descj(i) is descj_topnum(i)-th in order

% SML: now re-sort descj by rank in descj_topnum
[junk, perm] = sort(descj_topnum);
descj = descj(perm);

% Update j and all its descendants
A = update_row(A, j, dag);
for k = descj(:)'
  A = update_row(A, k, dag);
end

%%%%%%%%%%%

function A = old_do_removal(A, op, i, j, dag)

% Not called anywhere in this file; kept for reference. It sorts by
% order(descj), which do_removal above replaces with the inverse permutation.

% find all the descendants of j, and put them in topological order
% SML: originally Kevin had the next line commented and the %* lines
% being used but I think this is equivalent and much less expensive
% I assume he put it there for debugging and never changed it back...?
descj = find(A(:,j));
%*  R = reachability_graph(dag);
%*  descj = find(R(j,:));

order = topological_sort(dag);
descj_topnum = order(descj);
[junk, perm] = sort(descj_topnum);
descj = descj(perm);
% Update j and all its descendants
A = update_row(A, j, dag);
for k = descj(:)'
  A = update_row(A, k, dag);
end

%%%%%%%%%

function A = update_row(A, j, dag)

% We compute row j of A
A(j, :) = 0;
ps = parents(dag, j);
if ~isempty(ps)
  A(j, ps) = 1;
end
for k=ps(:)'
  anck = find(A(k,:));
  if ~isempty(anck)
    A(j, anck) = 1;
  end
end

%%%%%%%%

function A = init_ancestor_matrix(dag)

order = topological_sort(dag);
A = zeros(length(dag));
for j=order(:)'
  A = update_row(A, j, dag);
end
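
%%%%%%%%

% Usage sketch (illustrative only, not part of the original code).
% It is left commented out so that it cannot change the behaviour of this file.
% Assuming data and ns are set up as described in the header, the sampled
% graphs could be averaged to estimate how often each edge appears in the
% chain. The names edge_prob and M are hypothetical, and the burnin value
% is arbitrary.
%
%   [sampled_graphs, accept_ratio, num_edges] = ...
%       learn_struct_mcmc(data, ns, 'nsamples', 1000, 'burnin', 100);
%   M = length(sampled_graphs);
%   edge_prob = zeros(size(data, 1));
%   for m = 1:M
%     edge_prob = edge_prob + sampled_graphs{m};   % each sample is a 0/1 adjacency matrix
%   end
%   edge_prob = edge_prob / M;   % edge_prob(i,j) = fraction of samples containing edge i->j
%
% Plotting accept_ratio and num_edges against the iteration number is one
% informal way to check whether the chain has settled after burn-in.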