diff toolboxes/FullBNT-1.0.7/bnt/examples/static/cmp_inference_static.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/toolboxes/FullBNT-1.0.7/bnt/examples/static/cmp_inference_static.m	Tue Feb 10 15:05:51 2015 +0000
@@ -0,0 +1,114 @@
+function [time, engine] = cmp_inference_static(bnet, engine, varargin)
+% CMP_INFERENCE Compare several inference engines on a BN
+% function [time, engine] = cmp_inference_static(bnet, engine, ...)
+%
+% engine{i} is the i'th inference engine.
+% time(e) = elapsed time for doing inference with engine e
+%
+% The list below gives optional arguments [default value in brackets].
+%
+% exact - specifies which engines do exact inference [ 1:length(engine) ]
+% singletons_only - if 1, we only call marginal_nodes, else this  and marginal_family [0]
+% maximize - 1 means we do max-propagation, 0 means sum-propagation [0]
+% check_ll - 1 means we check that the log-likelihoods are correct [1]
+% observed - list of the observed ndoes [ bnet.observed ]
+% check_converged - list of loopy engines that should be checked for convergence [ [] ]
+%    If an engine has converged, it is added to the exact list.
+
+
+% set default params
+exact = 1:length(engine);
+singletons_only = 0;
+maximize = 0;
+check_ll = 1;
+observed = bnet.observed;
+check_converged = [];
+
+args = varargin;
+nargs = length(args);
+for i=1:2:nargs
+  switch args{i},
+   case 'exact', exact = args{i+1};
+   case 'singletons_only', singletons_only = args{i+1};
+   case 'maximize', maximize = args{i+1};
+   case 'check_ll', check_ll = args{i+1};
+   case 'observed', observed = args{i+1};
+   case 'check_converged', check_converged = args{i+1};
+   otherwise,
+    error(['unrecognized argument ' args{i}])
+  end
+end
+
+E = length(engine);
+ref = exact(1); % reference
+
+N = length(bnet.dag);
+ev = sample_bnet(bnet);
+evidence = cell(1,N);
+evidence(observed) = ev(observed);
+%celldisp(evidence(observed))
+
+for i=1:E
+  tic;
+  if check_ll
+    [engine{i}, ll(i)] = enter_evidence(engine{i}, evidence, 'maximize', maximize);
+  else
+    engine{i} = enter_evidence(engine{i}, evidence, 'maximize', maximize);
+  end
+  time(i)=toc;
+end
+
+for i=check_converged(:)'
+  niter = loopy_converged(engine{i});
+  if niter > 0
+    fprintf('loopy engine %d  converged in %d iterations\n', i, niter);
+%    exact = myunion(exact, i);
+  else
+    fprintf('loopy engine %d has not converged\n', i);
+  end
+end
+
+cmp = exact(2:end);
+if check_ll
+  for i=cmp(:)'
+    assert(approxeq(ll(ref), ll(i)));
+  end
+end
+
+hnodes = mysetdiff(1:N, observed);
+
+if ~singletons_only
+  get_marginals(engine, hnodes, exact, 0);
+end
+get_marginals(engine, hnodes, exact, 1);
+
+%%%%%%%%%%
+
+function get_marginals(engine, hnodes, exact, singletons)
+
+bnet = bnet_from_engine(engine{1});
+N = length(bnet.dag);
+cnodes_bitv = zeros(1,N);
+cnodes_bitv(bnet.cnodes) = 1;
+ref = exact(1); % reference
+cmp = exact(2:end);
+E = length(engine);
+
+for n=hnodes(:)'
+  for e=1:E
+    if singletons
+      m{e} = marginal_nodes(engine{e}, n);
+    else
+      m{e} = marginal_family(engine{e}, n);
+    end
+  end
+  for e=cmp(:)'
+    if cnodes_bitv(n)
+      assert(approxeq(m{ref}.mu, m{e}.mu))
+      assert(approxeq(m{ref}.Sigma, m{e}.Sigma))
+    else
+      assert(approxeq(m{ref}.T, m{e}.T))
+    end
+    assert(isequal(m{e}.domain, m{ref}.domain));
+  end
+end