Mercurial > hg > camir-aes2014
comparison toolboxes/FullBNT-1.0.7/bnt/examples/dynamic/HHMM/Square/learn_square_hhmm_cts.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:e9a9cd732c1e |
---|---|
% Try to learn a 3 level HHMM similar to mk_square_hhmm
% from hand-drawn squares.

% Because startprob should be shared for t=1:T,
% but in the DBN is shared for t=2:T, we train using a single long sequence.

discrete_obs = 0;  % 0 => continuous (Gaussian) observations
supervised = 1;    % 1 => the Q1/Q2 level labels are observed during training
obs_finalF2 = 0;
% It is not possible to observe F2 if we learn
% because the update_ess method for hhmmF_CPD and hhmmQ_CPD assume
% the F nodes are always hidden (for speed).
% However, for generating, we might want to set the final F2=true
% to force all subroutines to finish.

% Fix both random-number generators so runs are reproducible.
seed = 1;
rand('state', seed);
randn('state', seed);

bnet = mk_square_hhmm(discrete_obs, 0);
21 | |
% Node numbering within one slice of the DBN.
ss = 6;  % slice size (number of nodes per time slice)
Q1 = 1; Q2 = 2; Q3 = 3; F3 = 4; F2 = 5; Onode = 6;
Qnodes = [Q1 Q2 Q3]; Fnodes = [F2 F3];
Qsizes = [2 4 1];  % number of states at each level of the Q hierarchy

% In supervised mode the top two levels of the hierarchy are labelled;
% otherwise only the observation node is visible.
if supervised
  bnet.observed = [Q1 Q2 Onode];
else
  bnet.observed = [Onode];
end
32 | |
% Pick an inference engine compatible with the chosen observability pattern.
if obs_finalF2
  engine = jtree_dbn_inf_engine(bnet);
  % can't use ndx version because sometimes F2 is hidden, sometimes observed
  error('can''t observe F when learning')
else
  if supervised
    engine = jtree_ndx_dbn_inf_engine(bnet);
  else
    engine = jtree_hmm_inf_engine(bnet);
  end
end
44 | |
load 'square4_cases' % cases{seq}{i,t} for i=1:ss
%plot_square_hhmm(cases{1})
%long_seq = cat(2, cases{:});
% Train on the first two sequences, concatenated into one long sequence
% (see the startprob-sharing comment at the top of the file).
train_cases = cases(1:2);
long_seq = cat(2, train_cases{:});
if ~supervised
  % Hide the Q1/Q2 labels so EM must infer them.
  T = size(long_seq,2);
  for t=1:T
    long_seq{Q1,t} = [];
    long_seq{Q2,t} = [];
  end
end
[bnet2, LL, engine2] = learn_params_dbn_em(engine, {long_seq}, 'max_iter', 2);
58 | |
% Pull the learned observation CPD (a Gaussian per hidden state) out of
% the trained network so we can inspect and post-process its parameters.
eclass = bnet2.equiv_class;
CPDO = struct(bnet2.CPD{eclass(Onode,1)});
mu = CPDO.mean;
Sigma = CPDO.cov;
CPDO_full = CPDO;  % keep a copy with the full (non-diagonal) covariances

% force diagonal covs after training
for k=1:size(Sigma,3)
  Sigma(:,:,k) = diag(diag(Sigma(:,:,k)));
end
% BUG FIX: the original wrote set_fields(bnet.CPD{6}, ...), which starts
% from the *untrained* model's CPD — discarding the learned means and only
% installing the new covariance. Start from the trained CPD instead, and
% index it via eclass(Onode,1) (consistent with the read above) rather
% than the magic constant 6.
bnet2.CPD{eclass(Onode,1)} = set_fields(bnet2.CPD{eclass(Onode,1)}, 'cov', Sigma);
70 | |
% Disabled debugging aid: render each learned sub-model by stringing
% together its mean observation vector for nsteps consecutive steps.
if 0
  % visualize each model by concatenating means for each model for nsteps in a row
  nsteps = 5;
  ev = cell(ss, nsteps*prod(Qsizes(2:3)));
  t = 1;
  for q2=1:Qsizes(2)
    for q3=1:Qsizes(3)
      for i=1:nsteps
        ev{Onode,t} = mu(:,q2,q3);
        ev{Q2,t} = q2;
        t = t + 1;
      end
    end
  end
  plot_square_hhmm(ev)
end
87 | |
% bnet3 is the same as the learned model, except we will use it in testing mode
if supervised
  % The labels were observed during training but are hidden at test time,
  % so declare a smaller observed set and build a fresh engine for it.
  bnet3 = bnet2;
  bnet3.observed = [Onode];
  engine3 = hmm_inf_engine(bnet3);
  %engine3 = jtree_ndx_dbn_inf_engine(bnet3);
else
  % Unsupervised training already used this observability pattern;
  % reuse the engine returned by EM.
  bnet3 = bnet2;
  engine3 = engine2;
end

% Disabled: segment the whole training sequence at once.
if 0
  % segment whole sequence
  mpe = calc_mpe_dbn(engine3, long_seq);
  pretty_print_hhmm_parse(mpe, Qnodes, Fnodes, Onode, []);
end
104 | |
% segment each sequence
test_cases = cases(3:4);  % held-out sequences 3 and 4
for i=1:2
  ev = test_cases{i};
  T = size(ev, 2);
  % Hide the Q1/Q2 labels; only the observations are given at test time.
  for t=1:T
    ev{Q1,t} = [];
    ev{Q2,t} = [];
  end
  %mpe = calc_mpe_dbn(engine3, ev);
  mpe = find_mpe(engine3, ev)  % NOTE(review): no semicolon, so mpe is echoed to the console — confirm this is intentional
  subplot(1,2,i)
  plot_square_hhmm(mpe)
  %pretty_print_hhmm_parse(mpe, Qnodes, Fnodes, Onode, []);
  % Classify the sequence by the most frequent Viterbi value of Q1.
  q1s = cell2num(mpe(Q1,:));
  h = hist(q1s, 1:Qsizes(1));
  map_q1 = argmax(h);
  str = sprintf('test seq %d is of type %d\n', i, map_q1);
  title(str)
end
125 | |
126 | |
% Disabled sanity check: estimate the Q2 transition matrix by counting
% transitions in the labelled data (uses ev left over from the loop above).
if 0
  % Estimate gotten by counting transitions in the labelled data
  % Note that a self transition shouldnt count if F2=off.
  Q2ev = cell2num(ev(Q2,:));
  Q2a = Q2ev(1:end-1);
  Q2b = Q2ev(2:end);
  counts = compute_counts([Q2a; Q2b], [4 4]);
end

% Pull out the learned CPDs for manual inspection.
eclass = bnet2.equiv_class;
CPDQ1=struct(bnet2.CPD{eclass(Q1,2)});
CPDQ2=struct(bnet2.CPD{eclass(Q2,2)});
CPDQ3=struct(bnet2.CPD{eclass(Q3,2)});
CPDF2=struct(bnet2.CPD{eclass(F2,1)});
CPDF3=struct(bnet2.CPD{eclass(F3,1)});


% Level-2 transition matrix augmented with an explicit end state.
A=add_hhmm_end_state(CPDQ2.transprob, CPDF2.termprob(:,:,2));
squeeze(A(:,1,:));  % NOTE(review): result is discarded (trailing semicolon) — presumably leftover debugging
CPDQ2.startprob;    % NOTE(review): likewise discarded

% Disabled: inspect the sufficient statistics of the F2 termination CPD.
if 0
  S=struct(CPDF2.sub_CPD_term);
  S.nsamples
  reshape(S.counts, [2 4 2])
end