Mercurial > hg > camir-aes2014
diff toolboxes/FullBNT-1.0.7/bnt/examples/static/cmp_inference_static.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/toolboxes/FullBNT-1.0.7/bnt/examples/static/cmp_inference_static.m Tue Feb 10 15:05:51 2015 +0000 @@ -0,0 +1,114 @@ +function [time, engine] = cmp_inference_static(bnet, engine, varargin) +% CMP_INFERENCE Compare several inference engines on a BN +% function [time, engine] = cmp_inference_static(bnet, engine, ...) +% +% engine{i} is the i'th inference engine. +% time(e) = elapsed time for doing inference with engine e +% +% The list below gives optional arguments [default value in brackets]. +% +% exact - specifies which engines do exact inference [ 1:length(engine) ] +% singletons_only - if 1, we only call marginal_nodes, else this and marginal_family [0] +% maximize - 1 means we do max-propagation, 0 means sum-propagation [0] +% check_ll - 1 means we check that the log-likelihoods are correct [1] +% observed - list of the observed ndoes [ bnet.observed ] +% check_converged - list of loopy engines that should be checked for convergence [ [] ] +% If an engine has converged, it is added to the exact list. + + +% set default params +exact = 1:length(engine); +singletons_only = 0; +maximize = 0; +check_ll = 1; +observed = bnet.observed; +check_converged = []; + +args = varargin; +nargs = length(args); +for i=1:2:nargs + switch args{i}, + case 'exact', exact = args{i+1}; + case 'singletons_only', singletons_only = args{i+1}; + case 'maximize', maximize = args{i+1}; + case 'check_ll', check_ll = args{i+1}; + case 'observed', observed = args{i+1}; + case 'check_converged', check_converged = args{i+1}; + otherwise, + error(['unrecognized argument ' args{i}]) + end +end + +E = length(engine); +ref = exact(1); % reference + +N = length(bnet.dag); +ev = sample_bnet(bnet); +evidence = cell(1,N); +evidence(observed) = ev(observed); +%celldisp(evidence(observed)) + +for i=1:E + tic; + if check_ll + [engine{i}, ll(i)] = enter_evidence(engine{i}, evidence, 'maximize', maximize); + else + engine{i} = enter_evidence(engine{i}, evidence, 'maximize', maximize); + end + time(i)=toc; +end + +for i=check_converged(:)' + niter = loopy_converged(engine{i}); + if niter > 0 + fprintf('loopy engine %d converged in %d iterations\n', i, niter); +% exact = myunion(exact, i); + else + fprintf('loopy engine %d has not converged\n', i); + end +end + +cmp = exact(2:end); +if check_ll + for i=cmp(:)' + assert(approxeq(ll(ref), ll(i))); + end +end + +hnodes = mysetdiff(1:N, observed); + +if ~singletons_only + get_marginals(engine, hnodes, exact, 0); +end +get_marginals(engine, hnodes, exact, 1); + +%%%%%%%%%% + +function get_marginals(engine, hnodes, exact, singletons) + +bnet = bnet_from_engine(engine{1}); +N = length(bnet.dag); +cnodes_bitv = zeros(1,N); +cnodes_bitv(bnet.cnodes) = 1; +ref = exact(1); % reference +cmp = exact(2:end); +E = length(engine); + +for n=hnodes(:)' + for e=1:E + if singletons + m{e} = marginal_nodes(engine{e}, n); + else + m{e} = marginal_family(engine{e}, n); + end + end + for e=cmp(:)' + if cnodes_bitv(n) + assert(approxeq(m{ref}.mu, m{e}.mu)) + assert(approxeq(m{ref}.Sigma, m{e}.Sigma)) + else + assert(approxeq(m{ref}.T, m{e}.T)) + end + assert(isequal(m{e}.domain, m{ref}.domain)); + end +end