function [sampled_graphs, accept_ratio, num_edges] = learn_struct_mcmc(data, ns, varargin)
% LEARN_STRUCT_MCMC Monte Carlo Markov Chain search over DAGs assuming fully observed data
% [sampled_graphs, accept_ratio, num_edges] = learn_struct_mcmc(data, ns, ...)
%
% data(i,m) is the value of node i in case m.
% ns(i) is the number of discrete values node i can take on.
%
% sampled_graphs{m} is the m'th graph sampled after burn-in.
% accept_ratio(t) = (number of accepted moves + 1) / (number of rejected moves + 1)
%                   after iteration t (both counters start at 1). NOTE: this is an
%                   accepts:rejects ratio, not accepts/total.
% num_edges(t) = number of edges in model at iteration t
%
% The following optional arguments can be specified in the form of name/value pairs:
% [default value in brackets]
%
% scoring_fn - 'bayesian' or 'bic' [ 'bayesian' ]
%              Currently, only networks with all tabular nodes support Bayesian scoring.
% type       - type{i} is the type of CPD to use for node i, where the type is a string
%              of the form 'tabular', 'noisy_or', 'gaussian', etc. [ all cells contain 'tabular' ]
% params     - params{i} contains optional arguments passed to the CPD constructor for node i,
%              or [] if none. Passing [] as the option value sets every params{i} to [].
%              [ all cells contain {'prior_type', 'dirichlet', 'dirichlet_weight', 1},
%                i.e. a Dirichlet prior with weight 1 ]
% discrete   - the list of discrete nodes [ 1:N ]
% clamped    - clamped(i,m) = 1 if node i is clamped in case m [ zeros(N, ncases) ]
%              Cases in which node i is clamped are left out when scoring i's family.
% nsamples   - number of samples to draw from the chain after burn-in [ 100*N ]
% burnin     - number of steps to take before drawing samples [ 5*N ]
% init_dag   - starting point for the search [ zeros(N,N) ]
%
% Unrecognized option names are ignored (a warning is issued).
%
% e.g., samples = learn_struct_mcmc(data, ns, 'nsamples', 1000);
%
% Modified by Sonia Leach (SML) 2/4/02, 9/5/03

[n, ncases] = size(data);

% set default params
type = cell(1,n);
params = cell(1,n);
for i=1:n
  type{i} = 'tabular';
  params{i} = { 'prior_type', 'dirichlet', 'dirichlet_weight', 1 };
end
scoring_fn = 'bayesian';
discrete = 1:n;
clamped = zeros(n, ncases);
nsamples = 100*n;
burnin = 5*n;
dag = zeros(n);

args = varargin;
nargs = length(args);
for i=1:2:nargs
  switch args{i}
    case 'nsamples',   nsamples = args{i+1};
    case 'burnin',     burnin = args{i+1};
    case 'init_dag',   dag = args{i+1};
    case 'scoring_fn', scoring_fn = args{i+1};
    case 'type',       type = args{i+1};
    case 'discrete',   discrete = args{i+1};
    case 'clamped',    clamped = args{i+1};
    case 'params'
      if isempty(args{i+1})
        params = cell(1,n);
      else
        params = args{i+1};
      end
    otherwise
      % Previously unknown options were silently dropped, hiding typos.
      if ischar(args{i})
        optname = args{i};
      else
        optname = '<non-string>';
      end
      warning('learn_struct_mcmc:unknownOption', ...
              'learn_struct_mcmc: ignoring unrecognized option ''%s''', optname);
  end
end

% We implement the fast acyclicity check described by P. Giudici and R. Castelo,
% "Improving MCMC model search for data mining", submitted to J. Machine Learning, 2001.

% SML: also keep the ancestor matrix A, where A(k,j) = 1 iff j is an ancestor of k
% (so column A(:,j) lists the descendants of j).
use_giudici = 1;
if use_giudici
  [nbrs, ops, nodes, A] = mk_nbrs_of_digraph(dag);
else
  [nbrs, ops, nodes] = mk_nbrs_of_dag(dag);
  A = [];   % empty A tells take_step to use the non-Giudici path
end

num_accepts = 1;
num_rejects = 1;
T = burnin + nsamples;
accept_ratio = zeros(1, T);
num_edges = zeros(1, T);
sampled_graphs = cell(1, nsamples);

for t=1:T
  [dag, nbrs, ops, nodes, A, accept] = take_step(dag, nbrs, ops, ...
                                                 nodes, ns, data, clamped, A, ...
                                                 scoring_fn, discrete, type, params);
  num_edges(t) = sum(dag(:));
  num_accepts = num_accepts + accept;
  num_rejects = num_rejects + (1-accept);
  accept_ratio(t) = num_accepts/num_rejects;
  if t > burnin
    sampled_graphs{t-burnin} = dag;
  end
end


%%%%%%%%%

function [new_dag, new_nbrs, new_ops, new_nodes, A, accept] = ...
    take_step(dag, nbrs, ops, nodes, ns, data, clamped, A, ...
              scoring_fn, discrete, type, params)
% TAKE_STEP One Metropolis-Hastings step: propose a uniformly chosen neighbour of dag,
% then accept it with probability min(1, R).
% R = Bayes factor * (#neighbours of dag / #neighbours of new_dag).
% On rejection, every returned value is the corresponding input (A is left untouched).
% A non-empty A selects the Giudici path, where nbrs is indexed as nbrs(:,:,d).
% An empty A selects the mk_nbrs_of_dag path, where nbrs is indexed as nbrs{d}.

use_giudici = ~isempty(A);
if use_giudici
  [new_dag, op, i, j, new_A] = pick_digraph_nbr(dag, nbrs, ops, nodes, A); % updates A
  [new_nbrs, new_ops, new_nodes] = mk_nbrs_of_digraph(new_dag, new_A);
else
  d = sample_discrete(normalise(ones(1, numel(ops))));
  new_dag = nbrs{d};
  op = ops{d};
  i = nodes(d, 1); j = nodes(d, 2);
  [new_nbrs, new_ops, new_nodes] = mk_nbrs_of_dag(new_dag);
end

bf = bayes_factor(dag, new_dag, op, i, j, ns, data, clamped, scoring_fn, discrete, type, params);

% Hastings correction for the uniform neighbour proposal. The neighbour count is
% numel(ops) (one op per neighbour). length(nbrs) is avoided because in the
% Giudici path nbrs is N x N x K and length() returns the largest dimension.
%R = bf * (new_prior / prior) * (numel(ops) / numel(new_ops));
R = bf * (numel(ops) / numel(new_ops));
u = rand(1,1);
if u > min(1,R) % reject the move
  accept = 0;
  new_dag = dag;
  new_nbrs = nbrs;
  new_ops = ops;
  new_nodes = nodes;
else
  accept = 1;
  if use_giudici
    A = new_A; % new_A already updated in pick_digraph_nbr
  end
end


%%%%%%%%%

function bfactor = bayes_factor(old_dag, new_dag, op, i, j, ns, data, clamped, scoring_fn, discrete, type, params)
% BAYES_FACTOR exp(score(new_dag) - score(old_dag)), restricted to the families the move changed.
% The move always changes node j's family; a 'rev' also changes node i's family.
% The log-score differences are summed before a single exp(). Exponentiating
% each family separately could give Inf * 0 = NaN for a reversal. min(1,NaN)
% is 1 in MATLAB, so such a move would always have been accepted.

log_bf = family_log_ratio(j, old_dag, new_dag, ns, data, clamped, scoring_fn, discrete, type, params);
if strcmp(op, 'rev') % must also include the change to i's family
  log_bf = log_bf + family_log_ratio(i, old_dag, new_dag, ns, data, clamped, scoring_fn, discrete, type, params);
end
bfactor = exp(log_bf);


%%%%%%%%%

function delta = family_log_ratio(k, old_dag, new_dag, ns, data, clamped, scoring_fn, discrete, type, params)
% FAMILY_LOG_RATIO score_family of node k under new_dag minus that under old_dag.
% Only the cases in which node k is not clamped are used.

u = find(clamped(k,:)==0);
LLnew = score_family(k, parents(new_dag, k), type{k}, scoring_fn, ns, discrete, data(:,u), params{k});
LLold = score_family(k, parents(old_dag, k), type{k}, scoring_fn, ns, discrete, data(:,u), params{k});
delta = LLnew - LLold;


%%%%%%%% Giudici stuff follows %%%%%%%%%%


% SML: This now updates A as it goes from the digraph it chooses
function [new_dag, op, i, j, new_A] = pick_digraph_nbr(dag, digraph_nbrs, ops, nodes, A)
% PICK_DIGRAPH_NBR Choose a neighbour uniformly at random from digraph_nbrs(:,:,d).
% Also return the ancestor matrix updated for that move.
% The count is numel(ops) rather than length(digraph_nbrs). The latter is the
% largest dimension of the N x N x K array, not necessarily K.

d = sample_discrete(normalise(ones(1, numel(ops))));
i = nodes(d, 1); j = nodes(d, 2);
new_dag = digraph_nbrs(:,:,d);
op = ops{d};
new_A = update_ancestor_matrix(A, op, i, j, new_dag);


%%%%%%%%%%%%%%


function A = update_ancestor_matrix(A, op, i, j, dag)
% UPDATE_ANCESTOR_MATRIX Incrementally update A after applying op to the edge i->j.
% dag is the graph after the move.
% 'rev' is handled as removal of i->j followed by addition of j->i.

switch op
  case 'add'
    A = do_addition(A, op, i, j, dag);
  case 'del'
    A = do_removal(A, op, i, j, dag);
  case 'rev'
    A = do_removal(A, op, i, j, dag);
    A = do_addition(A, op, j, i, dag);
  otherwise
    % Previously an unknown op left A silently stale.
    error('learn_struct_mcmc:unknownOp', 'unknown graph operation');
end


%%%%%%%%%%%%

function A = do_addition(A, op, i, j, dag)
% DO_ADDITION Update A after adding the edge i->j.
% j gains i and all of i's ancestors. Every descendant of j then gains all of j's ancestors.

A(j,i) = 1; % i is an ancestor of j
anci = find(A(i,:));
if ~isempty(anci)
  A(j,anci) = 1; % all of i's ancestors are added to Anc(j)
end
ancj = find(A(j,:));
descj = find(A(:,j));
if ~isempty(ancj)
  for k=descj(:)'
    A(k,ancj) = 1; % all of j's ancestors are added to each descendant of j
  end
end

%%%%%%%%%%%
function A = do_removal(A, op, i, j, dag)
% DO_REMOVAL Update A after removing the edge i->j.
% Rows for j and its descendants are recomputed from their parents in dag.
% Descendants are visited in topological order, so each parent's row is
% already up to date when a child's row is rebuilt.

% Descendants of j, read from column j of A. Removing i->j does not change j's
% descendant set. (SML: replaces the more expensive reachability_graph(dag) lookup.)
descj = find(A(:,j));

order = topological_sort(dag);

% SML: order(p) is the node at position p, so the rank of each node is obtained
% by inverting the permutation. The earlier order(descj) extracted the wrong values.
[junk, perm] = sort(order); % node m is perm(m)-th in order
descj_topnum = perm(descj); % descj(m) is descj_topnum(m)-th in order

% SML: now re-sort descj by rank in descj_topnum
[junk, perm] = sort(descj_topnum);
descj = descj(perm);

% Update j and all its descendants
A = update_row(A, j, dag);
for k = descj(:)'
  A = update_row(A, k, dag);
end

%%%%%%%%%

function A = update_row(A, j, dag)
% UPDATE_ROW Recompute row j of A as the union of j's parents and their ancestor rows.

A(j, :) = 0;
ps = parents(dag, j);
if ~isempty(ps)
  A(j, ps) = 1;
end
for k=ps(:)'
  anck = find(A(k,:));
  if ~isempty(anck)
    A(j, anck) = 1;
  end
end

%%%%%%%%

function A = init_ancestor_matrix(dag)
% INIT_ANCESTOR_MATRIX Build A from scratch by filling rows in topological order.
% NOTE(review): not called anywhere in this file.

order = topological_sort(dag);
A = zeros(length(dag));
for j=order(:)'
  A = update_row(A, j, dag);
end