comparison toolboxes/FullBNT-1.0.7/bnt/examples/static/mog1.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:e9a9cd732c1e
% Fit a mixture of Gaussians to the same data with netlab (gmmem) and with
% BNT (learn_params_em), start both from identical initial parameters, and
% check that the two EM implementations reach the same estimates.

% Seed the legacy generators so the generated data (and any later random
% initialisation) is reproducible from run to run.
rand('state', 0);
randn('state', 0);

% Model: Q -> Y, with Q a discrete mixture component (ncenters states) and
% Y an observed dim-dimensional Gaussian.
ncenters = 2; dim = 2;
cov_type = 'full';

% Generate the data from a mixture of 2 Gaussians with means (-1,-1) and
% (1,1) and shared isotropic covariance 0.1*I.
%mu = randn(dim, ncenters);
mu = zeros(dim, ncenters);
mu(:,1) = [-1 -1]';
mu(:,2) = [1 1]';   % second component (previously mis-indexed as mu(:,1))
Sigma = repmat(0.1*eye(dim), [1 1 ncenters]);
ndat1 = 8; ndat2 = 8;
%ndat1 = 2; ndat2 = 2;
ndata = ndat1 + ndat2;
% Samples are stacked row-wise: data is ndata-by-dim, first ndat1 rows from
% component 1, the remaining ndat2 rows from component 2.
x1 = gsamp(mu(:,1), Sigma(:,:,1), ndat1);
x2 = gsamp(mu(:,2), Sigma(:,:,2), ndat2);
data = [x1; x2];
%plot(x1(:,1),x1(:,2),'ro', x2(:,1),x2(:,2),'bx')
23
% ---- Fit using netlab ----
max_iter = 3;
mix = gmm(dim, ncenters, cov_type);
options = foptions;
options(1) = 1;          % verbose
options(14) = max_iter;  % cap on the number of EM iterations

% Capture the initial parameters so the BNT model below starts from exactly
% the same point. To try K-means initialisation instead, uncomment the next
% line; it must run on `data` and before the parameters are extracted.
%mix = gmminit(mix, data, options); % Initialize with K-means
mu0 = mix.centres';      % transposed so each column is one component mean (assumes netlab stores centres row-wise)
pi0 = mix.priors(:);     % mixing weights as a column vector
Sigma0 = mix.covars;     % repmat(eye(dim), [1 1 ncenters]);

[mix, options] = gmmem(mix, data, options);

% Final params
ll1 = options(8);        % NOTE(review): gmmem's final error value, believed to be the negative log-likelihood -- confirm
mu1 = mix.centres';
pi1 = mix.priors(:);
Sigma1 = mix.covars;
44
45
46
47
% ---- Fit the same model with BNT ----

% Evidence: one column per data point. Row 1 (the hidden component Q) is
% left empty; row 2 holds that point's observation as a dim-by-1 vector.
evidence = cell(2, ndata);
evidence(2,:) = num2cell(data', 1);

% Two-node network Q -> Y: node 1 is discrete with ncenters states,
% node 2 is the observed dim-dimensional child.
dag = [0 1
       0 0];
node_sizes = [ncenters dim];
discrete_nodes = 1;
onodes = 2;
bnet = mk_bnet(dag, node_sizes, 'discrete', discrete_nodes, 'observed', onodes);

% Initialise the CPDs with netlab's starting parameters (pi0, mu0, Sigma0).
% A zero covariance prior weight presumably turns off BNT's covariance
% regulariser so its M step is comparable to gmmem's -- confirm.
bnet.CPD{1} = tabular_CPD(bnet, 1, pi0);
bnet.CPD{2} = gaussian_CPD(bnet, 2, ...
    'mean', mu0, ...
    'cov', Sigma0, ...
    'cov_type', cov_type, ...
    'cov_prior_weight', 0);

% Run EM for (at most) the same number of iterations as netlab.
engine = jtree_inf_engine(bnet);
[bnet2, LL] = learn_params_em(engine, evidence, max_iter);
67
% ---- Compare the netlab and BNT estimates ----
% Read the learned parameters out of the BNT CPD objects.
s1 = struct(bnet2.CPD{1});
s2 = struct(bnet2.CPD{2});
pi2 = s1.CPT(:);
mu2 = s2.mean;
Sigma2 = s2.cov;
ll2 = LL(end);   % last entry of BNT's log-likelihood trace

% The likelihoods are deliberately not compared: gmmem returns its value
% after the final M step, while BNT's LL(end) is computed before it.
% NOTE(review): the netlab value may also be a negative log-likelihood -- check the sign before enabling.
% assert(approxeq(ll1, ll2));
assert(approxeq(mu1, mu2));
assert(approxeq(Sigma1, Sigma2));
assert(approxeq(pi1, pi2));
80
81