comparison toolboxes/FullBNT-1.0.7/bnt/examples/dynamic/HHMM/Square/learn_square_hhmm_cts.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:e9a9cd732c1e
1 % Try to learn a 3 level HHMM similar to mk_square_hhmm
2 % from hand-drawn squares.
3
4 % Because startprob should be shared for t=1:T,
5 % but in the DBN it is shared for t=2:T, we train using a single long sequence.
6
7 discrete_obs = 0;
8 supervised = 1;
9 obs_finalF2 = 0;
10 % It is not possible to observe F2 while learning,
11 % because the update_ess methods for hhmmF_CPD and hhmmQ_CPD assume
12 % the F nodes are always hidden (for speed).
13 % However, for generating, we might want to set the final F2=true
14 % to force all subroutines to finish.
15
16 seed = 1;
17 rand('state', seed);
18 randn('state', seed);
19
20 bnet = mk_square_hhmm(discrete_obs, 0);
21
22 ss = 6;
23 Q1 = 1; Q2 = 2; Q3 = 3; F3 = 4; F2 = 5; Onode = 6;
24 Qnodes = [Q1 Q2 Q3]; Fnodes = [F2 F3];
25 Qsizes = [2 4 1];
26
27 if supervised
28 bnet.observed = [Q1 Q2 Onode];
29 else
30 bnet.observed = [Onode];
31 end
32
33 if obs_finalF2
34 engine = jtree_dbn_inf_engine(bnet);
35 % can't use ndx version because sometimes F2 is hidden, sometimes observed
36 error('can''t observe F when learning')
37 else
38 if supervised
39 engine = jtree_ndx_dbn_inf_engine(bnet);
40 else
41 engine = jtree_hmm_inf_engine(bnet);
42 end
43 end
44
45 load 'square4_cases' % cases{seq}{i,t} for i=1:ss
46 %plot_square_hhmm(cases{1})
47 %long_seq = cat(2, cases{:});
48 train_cases = cases(1:2);
49 long_seq = cat(2, train_cases{:});
50 if ~supervised
51 T = size(long_seq,2);
52 for t=1:T
53 long_seq{Q1,t} = [];
54 long_seq{Q2,t} = [];
55 end
56 end
57 [bnet2, LL, engine2] = learn_params_dbn_em(engine, {long_seq}, 'max_iter', 2);
58
59 eclass = bnet2.equiv_class;
60 CPDO=struct(bnet2.CPD{eclass(Onode,1)});
61 mu = CPDO.mean;
62 Sigma = CPDO.cov;
63 CPDO_full = CPDO;
64
65 % Force diagonal covs after training: zero the off-diagonal entries of each learned Sigma(:,:,k) and write the result into the learned observation CPD of bnet2 (not the untrained bnet), so the other learned fields are not replaced.
66 for k=1:size(Sigma,3)
67 Sigma(:,:,k) = diag(diag(Sigma(:,:,k)));
68 end
69 bnet2.CPD{eclass(Onode,1)} = set_fields(bnet2.CPD{eclass(Onode,1)}, 'cov', Sigma);
70
71 if 0
72 % visualize each model by concatenating means for each model for nsteps in a row
73 nsteps = 5;
74 ev = cell(ss, nsteps*prod(Qsizes(2:3)));
75 t = 1;
76 for q2=1:Qsizes(2)
77 for q3=1:Qsizes(3)
78 for i=1:nsteps
79 ev{Onode,t} = mu(:,q2,q3);
80 ev{Q2,t} = q2;
81 t = t + 1;
82 end
83 end
84 end
85 plot_square_hhmm(ev)
86 end
87
88 % bnet3 is the same as the learned model, except we will use it in testing mode
89 if supervised
90 bnet3 = bnet2;
91 bnet3.observed = [Onode];
92 engine3 = hmm_inf_engine(bnet3);
93 %engine3 = jtree_ndx_dbn_inf_engine(bnet3);
94 else
95 bnet3 = bnet2;
96 engine3 = engine2;
97 end
98
99 if 0
100 % segment whole sequence
101 mpe = calc_mpe_dbn(engine3, long_seq);
102 pretty_print_hhmm_parse(mpe, Qnodes, Fnodes, Onode, []);
103 end
104
105 % segment each sequence
106 test_cases = cases(3:4);
107 for i=1:2
108 ev = test_cases{i};
109 T = size(ev, 2);
110 for t=1:T
111 ev{Q1,t} = [];
112 ev{Q2,t} = [];
113 end
114 %mpe = calc_mpe_dbn(engine3, ev);
115 mpe = find_mpe(engine3, ev)
116 subplot(1,2,i)
117 plot_square_hhmm(mpe)
118 %pretty_print_hhmm_parse(mpe, Qnodes, Fnodes, Onode, []);
119 q1s = cell2num(mpe(Q1,:));
120 h = hist(q1s, 1:Qsizes(1));
121 map_q1 = argmax(h);
122 str = sprintf('test seq %d is of type %d\n', i, map_q1);
123 title(str)
124 end
125
126
127 if 0
128 % Estimate obtained by counting transitions in the labelled data.
129 % Note that a self transition shouldn't count if F2=off.
130 Q2ev = cell2num(ev(Q2,:));
131 Q2a = Q2ev(1:end-1);
132 Q2b = Q2ev(2:end);
133 counts = compute_counts([Q2a; Q2b], [4 4]);
134 end
135
136 eclass = bnet2.equiv_class;
137 CPDQ1=struct(bnet2.CPD{eclass(Q1,2)});
138 CPDQ2=struct(bnet2.CPD{eclass(Q2,2)});
139 CPDQ3=struct(bnet2.CPD{eclass(Q3,2)});
140 CPDF2=struct(bnet2.CPD{eclass(F2,1)});
141 CPDF3=struct(bnet2.CPD{eclass(F3,1)});
142
143
144 A=add_hhmm_end_state(CPDQ2.transprob, CPDF2.termprob(:,:,2));
145 squeeze(A(:,1,:));
146 CPDQ2.startprob;
147
148 if 0
149 S=struct(CPDF2.sub_CPD_term);
150 S.nsamples
151 reshape(S.counts, [2 4 2])
152 end