% Fit a mixture of Gaussians using netlab and BNT, then check that the
% two toolboxes produce matching EM parameter estimates.

% Reset both legacy generators (uniform and normal) to state 0 so that
% every subsequent random draw in this script is reproducible.
rand('state', 0);
randn('state', 0);
% Model: discrete mixture component Q -> observed Gaussian Y.
ncenters = 2; dim = 2;
cov_type = 'full';

% Generate the data from a mixture of 2 Gaussians centred at (-1,-1) and
% (1,1), each with covariance 0.1*I.
%mu = randn(dim, ncenters);   % alternative: random centres
mu = zeros(dim, ncenters);
mu(:,1) = [-1 -1]';
% The second centre must go in column 2. Writing column 1 again would
% overwrite the first centre and leave the second one at the origin.
mu(:,2) = [1 1]';
Sigma = repmat(0.1*eye(dim), [1 1 ncenters]);

ndat1 = 8; ndat2 = 8;         % number of samples drawn from each component
%ndat1 = 2; ndat2 = 2;
ndata = ndat1 + ndat2;

% Each row of x1 / x2 is one sample; stack them into a single data matrix.
x1 = gsamp(mu(:,1), Sigma(:,:,1), ndat1);
x2 = gsamp(mu(:,2), Sigma(:,:,2), ndat2);
data = [x1; x2];
%plot(x1(:,1),x1(:,2),'ro', x2(:,1),x2(:,2),'bx')
% ---- Fit using netlab ----
max_iter = 3;                         % EM iterations, shared with the BNT fit
mix = gmm(dim, ncenters, cov_type);
options = foptions;
options([1 14]) = [1 max_iter];       % (1): verbose, (14): max EM cycles

% Capture the starting parameters so BNT can begin EM from the same point.
% K-means initialisation is disabled; to enable it, uncomment:
%mix = gmminit(mix, data, options);
mu0 = mix.centres';                   % transposed; NOTE(review): presumably dim x ncenters for BNT
pi0 = reshape(mix.priors, [], 1);     % mixing weights as a column vector
Sigma0 = mix.covars;                  % e.g. repmat(eye(dim), [1 1 ncenters])

[mix, options] = gmmem(mix, data, options);

% Parameters after EM.
ll1 = options(8);                     % netlab's final error value; NOTE(review): may be a negated log-likelihood
mu1 = mix.centres';
pi1 = reshape(mix.priors, [], 1);
Sigma1 = mix.covars;
% ---- Fit the same model with BNT ----
% Two-node network: node 1 (discrete component) -> node 2 (observed vector).
dag = [0 1; 0 0];
node_sizes = [ncenters dim];
discrete_nodes = 1;
onodes = 2;

bnet = mk_bnet(dag, node_sizes, 'discrete', discrete_nodes, 'observed', onodes);
% Start from netlab's initial parameters. A zero 'cov_prior_weight'
% presumably switches off covariance regularisation so the M step can
% match netlab's -- confirm against the gaussian_CPD docs.
bnet.CPD{1} = tabular_CPD(bnet, 1, pi0);
bnet.CPD{2} = gaussian_CPD(bnet, 2, ...
    'mean', mu0, ...
    'cov', Sigma0, ...
    'cov_type', cov_type, ...
    'cov_prior_weight', 0);

engine = jtree_inf_engine(bnet);

% One case per column: row 1 (node 1) left empty/hidden, row 2 holds one
% data sample as a column vector.
evidence = [cell(1, ndata); num2cell(data', 1)];

[bnet2, LL] = learn_params_em(engine, evidence, max_iter);

% Read the learned parameters back out of the CPD objects.
ll2 = LL(end);
s1 = struct(bnet2.CPD{1});
pi2 = s1.CPT(:);

s2 = struct(bnet2.CPD{2});
mu2 = s2.mean;
Sigma2 = s2.cov;
% ---- Compare the two fits ----
% The log-likelihoods are not compared: gmmem returns its value after the
% final M step, whereas BNT reports it before. NOTE(review): options(8) may
% also have the opposite sign to BNT's LL; check before re-enabling.
% assert(approxeq(ll1, ll2));
assert(approxeq(mu1, mu2));
assert(approxeq(Sigma1, Sigma2));
assert(approxeq(pi1, pi2));
