diff toolboxes/FullBNT-1.0.7/bnt/inference/static/@likelihood_weighting_inf_engine/marginal_nodes.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/toolboxes/FullBNT-1.0.7/bnt/inference/static/@likelihood_weighting_inf_engine/marginal_nodes.m	Tue Feb 10 15:05:51 2015 +0000
@@ -0,0 +1,53 @@
+function marginal = marginal_nodes(engine, nodes)
+% MARGINAL_NODES Compute the marginal on the specified query nodes (likelihood_weighting)
+% marginal = marginal_nodes(engine, nodes)
+
+bnet = bnet_from_engine(engine);
+ddom = myintersect(nodes, bnet.dnodes);
+cdom = myintersect(nodes, bnet.cnodes);
+nsamples = size(engine.samples, 1);
+ns = bnet.node_sizes;
+
+%w = normalise(engine.weights);
+w = engine.weights;
+if mysubset(nodes, ddom)
+  T = 0*myones(ns(nodes));
+  P = prod(ns(nodes));
+  indices = ind2subv(ns(nodes), 1:P);
+  samples = reshape(cat(1, engine.samples{:,nodes}), nsamples, length(nodes));
+  for j = 1:P
+    rows = find_rows(samples, indices(j,:));
+    T(j) = sum(w(rows));
+  end
+  T = normalise(T);
+  marginal.T = T;
+elseif subset(nodes, cdom)
+  samples = reshape(cat(1, engine.samples{:,nodes}), nsamples*sum(ns(nodes)), length(nodes));
+  [marginal.mu, marginal.Sigma] =  wstats(samples', normalise(w));
+else
+  error('can''t handle mixed marginals yet');
+end
+
+marginal.domain = nodes;
+
+%%%%%%%%%
+
+function rows = find_rows(M, v)
+% FINDROWS Find rows which are equal to a specified vector
+% rows = findrows(M, v)
+% Each row of M is a sample
+
+temp = abs(M - repmat(v, size(M, 1), 1));
+rows = find(sum(temp,2) == 0);      
+
+%%%%%%%%
+
+function [mu, Sigma] = wstats(X, w)
+
+% Computes the weighted mean and weighted covariance matrix for a given
+% set of observations X(:,i), and a set of normalised weights w(i).
+% Each column of X is a sample.
+
+d = X - repmat(X * w', 1, size(X, 2));
+mu = sum(X .* repmat(w, size(X, 1), 1), 2);
+Sigma = d * diag(w) * d';