% Factor analysis
% Z -> X,  Z in R^k, X in R^D, k << D (high dimensional observations explained by small source)
% Z ~ N(0,I),   X|Z ~ N(L z, Psi), where Psi is diagonal.
%
% We compare to Zoubin Ghahramani's code.
%
% The script fits the model twice on the same random data: once with ffa,
% and once by running EM on a two-node Bayes net (Z -> X). It then asserts
% that the log-likelihood output, the noise covariance and the loading
% matrix returned by the two routes agree (up to approxeq's tolerance).

% Seed both generators so the run is reproducible.
% NOTE(review): this is the legacy seeding syntax; it is kept as-is because
% changing the generator would change the random stream -- confirm before
% modernising.
state = 0;
rand('seed', state);
randn('seed', state);

max_iter = 3;      % iteration limit handed to both ffa and learn_params_em
k = 2;             % dimension of the latent variable Z
D = 4;             % dimension of the observed variable X
N = 10;            % number of samples
X = randn(N, D);   % one sample per row

% Initialize as in Zoubin's ffa (fast factor analysis)
X = X - ones(N,1)*mean(X);   % subtract the column means
XX = X'*X/N;
diagXX = diag(XX);           % XX and diagXX are not used again in this script
cX = cov(X);
scale = det(cX)^(1/D);
randn('seed', 0); % must reset seed here so initial params are identical to mfa
L0 = randn(D,k)*sqrt(scale/k);
W0 = L0;                     % initial weights for the X node below
Psi0 = diag(cX);             % column vector holding the diagonal of cX

% Reference fit.
[L1, Psi1, LL1] = ffa(X,k,max_iter);

% Two-node net: node 1 (size k) -> node 2 (size D); no discrete nodes,
% node 2 is observed.
ns = [k D];
dag = zeros(2,2);
dag(1,2) = 1;
bnet = mk_bnet(dag, ns, 'discrete', [], 'observed', 2);

% Node 1: zero mean, identity covariance, both clamped.
bnet.CPD{1} = gaussian_CPD(bnet, 1, 'mean', zeros(k,1), 'cov', eye(k), 'cov_type', 'diag', ...
                           'clamp_mean', 1, 'clamp_cov', 1);

% Node 2: zero mean (clamped), diagonal covariance started at diag(Psi0),
% weights started at W0.
bnet.CPD{2} = gaussian_CPD(bnet, 2, 'mean', zeros(D,1), 'cov', diag(Psi0), 'weights', W0, ...
                           'cov_type', 'diag', 'cov_prior_weight', 0, 'clamp_mean', 1);

engine = jtree_inf_engine(bnet);

% evidence{i,n} is the value of node i in case n; row 1 (the latent node)
% stays empty, row 2 gets the n-th sample as a D x 1 column.
evidence = cell(2,N);
evidence(2,:) = num2cell(X', 1);

[bnet2, LL2] = learn_params_em(engine, evidence, max_iter);

% Pull the learned parameters of node 2 out of the CPD object.
s = struct(bnet2.CPD{2});
L2 = s.weights;
Psi2 = s.cov;

% Compare to Zoubin's code
assert(approxeq(LL2, LL1));
assert(approxeq(Psi2, diag(Psi1)));
assert(approxeq(L2, L1));