% Sigmoid Belief Net
%
% Builds a 6-node discrete Bayes net: three root nodes with tabular CPDs and
% three non-root nodes with softmax CPDs. It then generates T random,
% partially observed cases, enters one case into a junction-tree engine, and
% learns the parameters with EM.
%
% Requires the Bayes Net Toolbox (mk_bnet, tabular_CPD, softmax_CPD,
% jtree_inf_engine, enter_evidence, learn_params_em) on the MATLAB path.

clear all
clc

% Node indices.
dum1 = 1;
dum2 = 2;
dum3 = 3;
Q1 = 4;
Q2 = 5;
Y = 6;
n_nodes = 6;

% Graph structure: dag(i,j) = 1 means an edge i -> j.
dag = zeros(n_nodes, n_nodes);
dag(dum1, [Q1 Y]) = 1;
dag(dum2, Q2) = 1;
dag(dum3, [Q1 Q2]) = 1;
dag(Q1, [Q2 Y]) = 1;
dag(Q2, Y) = 1;

% Number of states of each node, indexed by node number.
ns = [2 2 3 3 4 3];
dnodes = 1:n_nodes;   % every node is passed to mk_bnet as discrete
bnet = mk_bnet(dag, ns, dnodes);

% Legacy generator seeding; kept as-is so the random cases are reproducible.
rand('state',0); randn('state',0);
n_iter = 10;   % iteration limit handed to learn_params_em

% Root nodes (no incoming edges) get tabular CPDs.
bnet.CPD{dum1} = tabular_CPD(bnet, dum1);
bnet.CPD{dum2} = tabular_CPD(bnet, dum2);
bnet.CPD{dum3} = tabular_CPD(bnet, dum3);

% Non-root nodes get softmax CPDs. The 'discrete' option lists the discrete
% parents handed to softmax_CPD for each node.
% NOTE(review): Q2's parent Q1 and Y's parent Q2 are not in the 'discrete'
% lists below -- confirm this is intended.
bnet.CPD{Q1} = softmax_CPD(bnet, Q1, 'discrete', [dum1 dum3]);
bnet.CPD{Q2} = softmax_CPD(bnet, Q2, 'discrete', [dum2 dum3]);
bnet.CPD{Y}  = softmax_CPD(bnet, Y,  'discrete', [dum1 Q1]);

% Training data: cases{i,t} is the value of node i in case t. An empty cell
% means the node is hidden; dum2 and Q2 are left hidden in every case.
T = 5;
cases = cell(n_nodes, T);
observed = [dum1 dum3 Q1 Y];   % draw order matters: it fixes the rand stream
for node = observed
  % Values fall in 1..ns(node). Because of round(), the two end values are
  % half as likely as the interior ones.
  cases(node,:) = num2cell(round(rand(1,T)*(ns(node)-1)) + 1);
end

engine = jtree_inf_engine(bnet);

% enter_evidence takes a single case (one cell per node), so pass one column
% of the data rather than the whole n_nodes x T cases array.
[engine, loglik] = enter_evidence(engine, cases(:,1));

disp('learning-------------------------------------------')
[bnet2, LL2, eng2] = learn_params_em(engine, cases, n_iter);