% Try to learn a 3 level HHMM similar to mk_square_hhmm
% from hand-drawn squares.
%
% Because startprob should be shared for t=1:T,
% but in the DBN is shared for t=2:T, we train using a single long sequence.

% ---- Experiment switches -------------------------------------------------
discrete_obs = 0;   % forwarded as the first argument of mk_square_hhmm
supervised   = 1;   % 1: Q1 and Q2 are observed during training; 0: only Onode
obs_finalF2  = 0;   % must stay 0 when learning, for the reason given below
% It is not possible to observe F2 if we learn
% because the update_ess method for hhmmF_CPD and hhmmQ_CPD assume
% the F nodes are always hidden (for speed).
% However, for generating, we might want to set the final F2=true
% to force all subroutines to finish.

% ---- Reproducibility: seed both the uniform and the normal generators ----
seed = 1;
rand('state', seed);
randn('state', seed);

% ---- Initial model -------------------------------------------------------
bnet = mk_square_hhmm(discrete_obs, 0);

% ---- Node indices (row numbers in an evidence cell array) ----------------
ss = 6;             % number of rows in an evidence cell array
Q1 = 1;
Q2 = 2;
Q3 = 3;
F3 = 4;
F2 = 5;
Onode = 6;
Qnodes = [Q1 Q2 Q3];
Fnodes = [F2 F3];
Qsizes = [2 4 1];   % number of values taken by Q1, Q2 and Q3 respectively

% Decide which nodes carry evidence during training: the two upper Q levels
% are given as labels only in the supervised setting.
if supervised
  bnet.observed = [Q1 Q2 Onode];
else
  bnet.observed = [Onode];
end

% Choose the inference engine that EM will use.
if obs_finalF2
  engine = jtree_dbn_inf_engine(bnet);
  % can't use ndx version because sometimes F2 is hidden, sometimes observed
  error('can''t observe F when learning')
elseif supervised
  engine = jtree_ndx_dbn_inf_engine(bnet);
else
  engine = jtree_hmm_inf_engine(bnet);
end

% Load the hand-drawn data into the workspace.
load('square4_cases') % cases{seq}{i,t} for i=1:ss
%plot_square_hhmm(cases{1})
%long_seq = cat(2, cases{:});

% Train on the first two sequences, joined end-to-end along the time axis
% into one long sequence.
train_cases = cases(1:2);
long_seq = cat(2, train_cases{:});

if ~supervised
  % Unsupervised run: blank out the Q1 and Q2 labels at every time step
  % (an empty cell means "no evidence for this node").
  T = size(long_seq,2);
  long_seq([Q1 Q2], :) = {[]};
end

% Two EM iterations on the single long training sequence.
[bnet2, LL, engine2] = learn_params_dbn_em(engine, {long_seq}, 'max_iter', 2);

% Extract the learned observation CPD (slice-1 equivalence class of Onode)
% and expose its fields by converting the object to a plain struct.
eclass = bnet2.equiv_class;
Oclass = eclass(Onode,1);
CPDO = struct(bnet2.CPD{Oclass});
mu = CPDO.mean;
Sigma = CPDO.cov;
CPDO_full = CPDO;   % snapshot taken before the covariances are modified

% force diagonal covs after training
for k=1:size(Sigma,3)
  Sigma(:,:,k) = diag(diag(Sigma(:,:,k)));
end

% Write the diagonalised covariances back into the *learned* CPD.
% The CPD being modified must come from bnet2: starting from bnet.CPD{...}
% (the initial, untrained model) would replace the learned CPD with the
% initial one and so discard the means estimated by EM.  The same
% equivalence-class index used for reading above is used for writing.
% NOTE(review): engine2 was returned by learn_params_dbn_em before this
% change, so it presumably still holds the full-covariance model - confirm
% if the unsupervised path (engine3 = engine2) should see diagonal covs.
bnet2.CPD{Oclass} = set_fields(bnet2.CPD{Oclass}, 'cov', Sigma);

if 0
  % Disabled by default.
  % visualize each model by concatenating means for each model for nsteps in a row
  nsteps = 5;
  ev = cell(ss, nsteps*prod(Qsizes(2:3)));
  t = 1;   % next free column of ev
  for q2 = 1:Qsizes(2)
    for q3 = 1:Qsizes(3)
      % Repeat the mean of state (q2,q3) for nsteps consecutive columns,
      % labelling each column with its Q2 value.
      for i = 1:nsteps
        ev{Onode,t} = mu(:,q2,q3);
        ev{Q2,t} = q2;
        t = t + 1;
      end
    end
  end
  plot_square_hhmm(ev)
end

% bnet3 is the same as the learned model, except we will use it in testing mode
bnet3 = bnet2;
if supervised
  % At test time only the observation node carries evidence, so a new
  % engine is built for the changed observed set.
  bnet3.observed = [Onode];
  engine3 = hmm_inf_engine(bnet3);
  %engine3 = jtree_ndx_dbn_inf_engine(bnet3);
else
  % Training already observed only Onode, so the engine returned by EM is reused.
  engine3 = engine2;
end

if 0
  % Disabled by default.
  % segment whole sequence
  mpe = calc_mpe_dbn(engine3, long_seq);
  pretty_print_hhmm_parse(mpe, Qnodes, Fnodes, Onode, []);
end

% segment each sequence
test_cases = cases(3:4);
for i = 1:2
  ev = test_cases{i};
  T = size(ev, 2);

  % Hide the Q1/Q2 labels at every time step so they have to be inferred.
  ev([Q1 Q2], :) = {[]};

  % Most probable explanation given the remaining evidence.  The missing
  % semicolon is kept: the result is echoed to the console.
  %mpe = calc_mpe_dbn(engine3, ev);
  mpe = find_mpe(engine3, ev)

  subplot(1,2,i)
  plot_square_hhmm(mpe)
  %pretty_print_hhmm_parse(mpe, Qnodes, Fnodes, Onode, []);

  % Classify the sequence by the Q1 value that occurs most often in the MPE.
  q1s = cell2num(mpe(Q1,:));
  h = hist(q1s, 1:Qsizes(1));
  map_q1 = argmax(h);
  str = sprintf('test seq %d is of type %d\n', i, map_q1);
  title(str)
end


if 0
  % Disabled by default.
  % Estimate obtained by counting transitions in the labelled data.
  % Note that a self transition shouldn't count if F2=off.
  % NOTE(review): at this point ev is the last test sequence with its Q2 row
  % blanked out above - confirm which labelled data was intended here.
  Q2ev = cell2num(ev(Q2,:));
  Q2a = Q2ev(1:end-1);   % Q2 at time t
  Q2b = Q2ev(2:end);     % Q2 at time t+1
  counts = compute_counts([Q2a; Q2b], [4 4]);
end

% Convert the learned CPD objects to structs so their fields can be read.
% Q nodes use their slice-2 equivalence class, F nodes their slice-1 class.
eclass = bnet2.equiv_class;
CPDQ1 = struct(bnet2.CPD{eclass(Q1,2)});
CPDQ2 = struct(bnet2.CPD{eclass(Q2,2)});
CPDQ3 = struct(bnet2.CPD{eclass(Q3,2)});
CPDF2 = struct(bnet2.CPD{eclass(F2,1)});
CPDF3 = struct(bnet2.CPD{eclass(F3,1)});


% Combine the Q2 transition matrix with the F2 termination probabilities.
A = add_hhmm_end_state(CPDQ2.transprob, CPDF2.termprob(:,:,2));
% The next two statements are output-suppressed; remove the trailing
% semicolon to display the value.
squeeze(A(:,1,:));
CPDQ2.startprob;

if 0
  % Disabled by default: show fields of the termination sub-CPD of F2.
  S = struct(CPDF2.sub_CPD_term);
  S.nsamples
  reshape(S.counts, [2 4 2])
end