annotate toolboxes/FullBNT-1.0.7/bnt/examples/static/fa1.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
rev   line source
wolffd@0 1 % Factor analysis example: fit with BNT's EM and check against Zoubin Ghahramani's ffa.
wolffd@0 2 % Z -> X, Z in R^k, X in R^D, k << D (high-dimensional observations explained by a low-dimensional latent source)
wolffd@0 3 % Z ~ N(0,I), X|Z ~ N(L*Z, Psi), where Psi is diagonal.
wolffd@0 4 %
wolffd@0 5 % Log-likelihoods, Psi and L learned by learn_params_em are asserted to match those returned by ffa.
wolffd@0 6
wolffd@0 7 state = 0; % seed used to generate the synthetic data
wolffd@0 8 rand('seed', state);
wolffd@0 9 randn('seed', state);
wolffd@0 10 max_iter = 3; % iteration limit passed to both ffa and learn_params_em
wolffd@0 11 k = 2; % latent dimension (size of Z)
wolffd@0 12 D = 4; % observed dimension (size of X)
wolffd@0 13 N = 10; % number of cases
wolffd@0 14 X = randn(N, D); % synthetic data, one case per row
wolffd@0 15
wolffd@0 16 % Initialize as in Zoubin's ffa (fast factor analysis)
wolffd@0 17 X=X-ones(N,1)*mean(X); % center each column; node 2's mean is clamped to 0 below
wolffd@0 18 XX=X'*X/N;
wolffd@0 19 diagXX=diag(XX); % NOTE(review): XX and diagXX are not used later in this script
wolffd@0 20 cX=cov(X);
wolffd@0 21 scale=det(cX)^(1/D); % scale for the random initial loadings
wolffd@0 22 randn('seed', 0); % reset seed so L0 equals ffa's initial loadings (assumes ffa reseeds randn with 0 internally -- confirm)
wolffd@0 23 L0=randn(D,k)*sqrt(scale/k); % initial D-by-k loading matrix
wolffd@0 24 W0 = L0; % initial regression weights for the X|Z CPD
wolffd@0 25 Psi0=diag(cX); % initial noise variances, as a vector (expanded with diag() below)
wolffd@0 26
wolffd@0 27 [L1, Psi1, LL1] = ffa(X,k,max_iter); % reference fit with Zoubin's code
wolffd@0 28
wolffd@0 29
wolffd@0 30 ns = [k D]; % node 1 = Z (size k), node 2 = X (size D)
wolffd@0 31 dag = zeros(2,2);
wolffd@0 32 dag(1,2) = 1; % single edge Z -> X
wolffd@0 33 bnet = mk_bnet(dag, ns, 'discrete', [], 'observed', 2); % no discrete nodes; only X is observed
wolffd@0 34 bnet.CPD{1} = gaussian_CPD(bnet, 1, 'mean', zeros(k,1), 'cov', eye(k), 'cov_type', 'diag', ...
wolffd@0 35 'clamp_mean', 1, 'clamp_cov', 1); % Z ~ N(0,I); clamp flags presumably keep it fixed during EM
wolffd@0 36 bnet.CPD{2} = gaussian_CPD(bnet, 2, 'mean', zeros(D,1), 'cov', diag(Psi0), 'weights', W0, ...
wolffd@0 37 'cov_type', 'diag', 'cov_prior_weight', 0, 'clamp_mean', 1); % X|Z starts at the same init as ffa; mean clamped to 0
wolffd@0 38
wolffd@0 39 engine = jtree_inf_engine(bnet);
wolffd@0 40 evidence = cell(2,N); % one column per case; row 1 (Z) stays empty, i.e. hidden
wolffd@0 41 evidence(2,:) = num2cell(X', 1); % case n observes X(n,:)' at node 2
wolffd@0 42
wolffd@0 43 [bnet2, LL2] = learn_params_em(engine, evidence, max_iter); % BNT EM fit from the same init
wolffd@0 44
wolffd@0 45 s = struct(bnet2.CPD{2}); % expose the learned CPD's fields
wolffd@0 46 L2 = s.weights; % learned loadings
wolffd@0 47 Psi2 = s.cov; % learned noise covariance
wolffd@0 48
wolffd@0 49
wolffd@0 50
wolffd@0 51 % Check that BNT's EM reproduces the results of Zoubin's ffa
wolffd@0 52 assert(approxeq(LL2, LL1)); % log-likelihoods
wolffd@0 53 assert(approxeq(Psi2, diag(Psi1))); % Psi1 presumably a variance vector; diag() makes it a matrix
wolffd@0 54 assert(approxeq(L2, L1)); % loading matrices
wolffd@0 55
wolffd@0 56
wolffd@0 57