Mercurial > hg > camir-aes2014
comparison toolboxes/FullBNT-1.0.7/bnt/examples/static/mog1.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:e9a9cd732c1e |
---|---|
% Fit a mixture of Gaussians using netlab and BNT

% Seed the (legacy-syntax) generators so the sampled data, and therefore the
% toolbox comparison below, is reproducible.
rand('state', 0);
randn('state', 0);

% Q -> Y : Q is the discrete mixture-component node, Y the observed
% dim-dimensional Gaussian node.
ncenters = 2; dim = 2;
cov_type = 'full';

% Generate the data from a mixture of 2 Gaussians centred at [-1 -1]' and [1 1]'
%mu = randn(dim, ncenters);
mu = zeros(dim, ncenters);
mu(:,1) = [-1 -1]';
% Fix: this line previously assigned mu(:,1) a second time, which overwrote the
% first centre and left the second centre at the origin.
mu(:,2) = [1 1]';
Sigma = repmat(0.1*eye(dim),[1 1 ncenters]);   % 0.1*I covariance for every component
ndat1 = 8; ndat2 = 8;                          % samples drawn from components 1 and 2
%ndat1 = 2; ndat2 = 2;
ndata = ndat1+ndat2;
x1 = gsamp(mu(:,1), Sigma(:,:,1), ndat1);
x2 = gsamp(mu(:,2), Sigma(:,:,2), ndat2);
data = [x1; x2];                               % one sample per row
%plot(x1(:,1),x1(:,2),'ro', x2(:,1),x2(:,2),'bx')

% Fit using netlab: create a default GMM, then run at most max_iter EM
% iterations of gmmem on the generated data.
max_iter = 3;
mix = gmm(dim, ncenters, cov_type);

% netlab options vector: element 1 = verbose display,
% element 14 = maximum number of EM iterations.
options = foptions;
options([1 14]) = [1, max_iter];

% Snapshot the starting parameters so BNT can be initialised identically.
% K-means initialisation is currently disabled:
%   mix = gmminit(mix, data, options);
mu0 = mix.centres';    % transposed: one component mean per column, like mu above
pi0 = mix.priors(:);   % mixing weights as a column vector
Sigma0 = mix.covars;   % alternative: repmat(eye(dim), [1 1 ncenters])

[mix, options] = gmmem(mix, data, options);

% Parameters after EM, in the same layout as the initial snapshot.
ll1 = options(8);      % final value netlab reports in options(8)
mu1 = mix.centres';
pi1 = mix.priors(:);
Sigma1 = mix.covars;

% BNT: express the same Q -> Y model as a two-node Bayes net.
% Node 1 (Q) is discrete with ncenters states; node 2 (Y) is the observed
% dim-dimensional Gaussian node.
dag = [0 1; 0 0];              % single edge 1 -> 2
node_sizes = [ncenters dim];
discrete_nodes = 1;
onodes = 2;

bnet = mk_bnet(dag, node_sizes, 'discrete', discrete_nodes, 'observed', onodes);

% Start from exactly the parameters netlab started from.
% NOTE(review): cov_prior_weight = 0 presumably disables BNT's covariance
% prior so the M step matches netlab's update -- confirm against gaussian_CPD.
bnet.CPD{1} = tabular_CPD(bnet, 1, pi0);
bnet.CPD{2} = gaussian_CPD(bnet, 2, ...
                           'mean', mu0, ...
                           'cov', Sigma0, ...
                           'cov_type', cov_type, ...
                           'cov_prior_weight', 0);

% Junction-tree inference engine used inside BNT's EM.
engine = jtree_inf_engine(bnet);

% Evidence has one column per training case. Row 1 (Q) is left empty, i.e.
% hidden; row 2 (Y) holds that case's sample as a column vector.
evidence = cell(2, ndata);
evidence(2,:) = num2cell(data', 1);

% Run the same number of EM iterations as netlab.
[bnet2, LL] = learn_params_em(engine, evidence, max_iter);
ll2 = LL(end);   % last entry of the log-likelihood trace returned by learn_params_em

% struct() exposes the CPD objects' fields so the learned parameters can be read.
s1 = struct(bnet2.CPD{1});
pi2 = s1.CPT(:);         % mixing weights as a column vector

s2 = struct(bnet2.CPD{2});
mu2 = s2.mean;
Sigma2 = s2.cov;

% Both toolboxes start from identical parameters and run max_iter EM
% iterations, so the learned parameters must agree up to numerical tolerance.
% The likelihoods are not compared: gmmem reports its value after the final
% M step, whereas BNT reports the value computed before it.
% assert(approxeq(ll1, ll2));
assert(approxeq(mu1, mu2));
assert(approxeq(Sigma1, Sigma2));
assert(approxeq(pi1, pi2));