diff toolboxes/FullBNT-1.0.7/bnt/examples/static/mog1.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/toolboxes/FullBNT-1.0.7/bnt/examples/static/mog1.m	Tue Feb 10 15:05:51 2015 +0000
@@ -0,0 +1,81 @@
+% Fit a mixture of Gaussians using netlab and BNT
+
+rand('state', 0);
+randn('state', 0);
+
+% Q -> Y
+ncenters = 2; dim = 2;
+cov_type = 'full';
+
+% Generate the data from a mixture of 2 Gaussians
+%mu = randn(dim, ncenters);
+mu = zeros(dim, ncenters);
+mu(:,1) = [-1 -1]';
+mu(:,1) = [1 1]';
+Sigma = repmat(0.1*eye(dim),[1 1 ncenters]);
+ndat1 = 8; ndat2 = 8;
+%ndat1 = 2; ndat2 = 2;
+ndata = ndat1+ndat2;
+x1 = gsamp(mu(:,1), Sigma(:,:,1), ndat1);
+x2 = gsamp(mu(:,2), Sigma(:,:,2), ndat2);
+data = [x1; x2];
+%plot(x1(:,1),x1(:,2),'ro', x2(:,1),x2(:,2),'bx')
+
+% Fit using netlab
+max_iter = 3;
+mix = gmm(dim, ncenters, cov_type);
+options = foptions;
+options(1) = 1; % verbose
+options(14) = max_iter;
+
+% extract initial params
+%mix = gmminit(mix, x, options); % Initialize with K-means
+mu0 = mix.centres';
+pi0 = mix.priors(:);
+Sigma0 = mix.covars; % repmat(eye(dim), [1 1 ncenters]);
+
+[mix, options] = gmmem(mix, data, options);
+
+% Final params
+ll1 = options(8);
+mu1 = mix.centres';
+pi1 = mix.priors(:);
+Sigma1 = mix.covars;
+
+
+
+
+% BNT
+
+dag = zeros(2);
+dag(1,2) = 1;
+node_sizes = [ncenters dim];
+discrete_nodes = 1;
+onodes = 2;
+
+bnet = mk_bnet(dag, node_sizes, 'discrete', discrete_nodes, 'observed', onodes);
+bnet.CPD{1} = tabular_CPD(bnet, 1, pi0);
+bnet.CPD{2} = gaussian_CPD(bnet, 2, 'mean', mu0, 'cov', Sigma0, 'cov_type', cov_type, ...
+			   'cov_prior_weight', 0);
+
+engine = jtree_inf_engine(bnet);
+
+evidence = cell(2, ndata);
+evidence(2,:) = num2cell(data', 1);
+
+[bnet2, LL] = learn_params_em(engine, evidence, max_iter);
+
+ll2 = LL(end);
+s1 = struct(bnet2.CPD{1});
+pi2 = s1.CPT(:);
+
+s2 = struct(bnet2.CPD{2});
+mu2 = s2.mean;
+Sigma2 = s2.cov;
+
+% assert(approxeq(ll1, ll2)); % gmmem returns the value after the final M step, GMT before
+assert(approxeq(mu1, mu2));
+assert(approxeq(Sigma1, Sigma2))
+assert(approxeq(pi1, pi2))
+
+