Mercurial > hg > camir-aes2014
comparison toolboxes/FullBNT-1.0.7/bnt/examples/dynamic/HHMM/Square/learn_square_hhmm_discrete.m @ 0:e9a9cd732c1e tip
first hg version after svn
author: wolffd
date: Tue, 10 Feb 2015 15:05:51 +0000
parents: (none)
children: (none)
Comparison legend: equal / deleted / inserted / replaced
Changeset: -1:000000000000 → 0:e9a9cd732c1e
% Try to learn a 3 level HHMM similar to mk_square_hhmm
% from synthetic discrete sequences

% --- Configuration flags ---
discrete_obs = 1;  % use discrete (character) observations rather than continuous ones
supervised = 0;    % if 1, the top two Q levels are observed during training
obs_finalF2 = 0;   % if 1, clamp F2=finished at end of each sequence (not supported when learning; see below)

% Fix both RNG streams so the experiment is reproducible.
seed = 1;
rand('state', seed);
randn('state', seed);

% Build the initial (untrained) 3-level HHMM.
% NOTE(review): 2nd argument presumably selects true (1) vs. random (0)
% parameters — confirm against mk_square_hhmm; later in this file it is
% called with 1 to build the "true" model.
bnet_init = mk_square_hhmm(discrete_obs, 0);

% Node numbering within one DBN slice.
ss = 6;  % slice size (number of nodes per slice)
Q1 = 1; Q2 = 2; Q3 = 3; F3 = 4; F2 = 5; Onode = 6;
Qnodes = [Q1 Q2 Q3]; Fnodes = [F2 F3];

% In supervised mode the top two abstraction levels are observed;
% otherwise only the output node is.
if supervised
  bnet_init.observed = [Q1 Q2 Onode];
else
  bnet_init.observed = [Onode];
end
24 | |
% --- Choose an inference engine for learning ---
if obs_finalF2
  engine_init = jtree_dbn_inf_engine(bnet_init);
  % can't use ndx version because sometimes F2 is hidden, sometimes observed
  error('can''t observe F when learning')
  % It is not possible to observe F2 if we learn
  % because the update_ess method for hhmmF_CPD and hhmmQ_CPD assume
  % the F nodes are always hidden (for speed).
  % However, for generating, we might want to set the final F2=true
  % to force all subroutines to finish.
else
  if supervised
    engine_init = jtree_ndx_dbn_inf_engine(bnet_init);
  else
    engine_init = hmm_inf_engine(bnet_init);
  end
end
41 | |
% --- Generate some synthetic data (easier to debug) ---
% Observations are indices into this alphabet of square-tracing moves.
chars = ['L', 'l', 'U', 'u', 'R', 'r', 'D', 'd'];
L=find(chars=='L'); l=find(chars=='l');
U=find(chars=='U'); u=find(chars=='u');
R=find(chars=='R'); r=find(chars=='r');
D=find(chars=='D'); d=find(chars=='d');

cases = {};

% Sequences 1 and 3: trace the square "forwards" (L l U u R r D d).
T = 8;
ev = cell(ss, T);
ev(Onode,:) = num2cell([L l U u R r D d]);
if supervised
  % Label the top two levels: model 1, sub-state changing every 2 steps.
  ev(Q1,:) = num2cell(1*ones(1,T));
  ev(Q2,:) = num2cell( [1 1 2 2 3 3 4 4]);
end
cases{1} = ev;
cases{3} = ev;

% Sequences 2 and 4: trace the square in the reverse direction.
T = 8;
ev = cell(ss, T);
%we start with R then r, even though we are running the model 'backwards'!
ev(Onode,:) = num2cell([R r U u L l D d]);

if supervised
  ev(Q1,:) = num2cell(2*ones(1,T));
  ev(Q2,:) = num2cell( [3 3 2 2 1 1 4 4]);
end

cases{2} = ev;
cases{4} = ev;

% Optionally clamp F2 at the last slice (dead code here: obs_finalF2=0,
% and the learning branch above errors out if it is set).
if obs_finalF2
  for i=1:length(cases)
    T = size(cases{i},2);
    cases{i}(F2,T)={2}; % force F2 to be finished at end of seq
  end
end
81 | |
% --- Train with EM ---
% startprob should be shared for t=1:T,
% but in the DBN it is shared for t=2:T,
% so we train using a single long sequence.
long_seq = cat(2, cases{:});
[bnet_learned, LL, engine_learned] = ...
    learn_params_dbn_em(engine_init, {long_seq}, 'max_iter', 200);

% Figure out which subsequence each sub-model is responsible for
% by computing and printing the most probable explanation.
mpe = calc_mpe_dbn(engine_learned, long_seq);
pretty_print_hhmm_parse(mpe, Qnodes, Fnodes, Onode, chars);
93 | |
% The "true" segmentation of the training sequence is
% Q1: 1 2
% O: L l U u R r D d | R r U u L l D d | etc.
%
% When we learn in a supervised fashion, we recover the "truth".

% When we learn in an unsupervised fashion with seed=1, we get
% Q1: 2 1
% O: L l U u R r D d R r | U u L l D d | etc.
%
% This means for model 1:
% starts in state 2
% transitions 2->1, 1->4, 4->e, 3->2
%
% For model 2,
% starts in state 1
% transitions 1->2, 2->3, 3->4 or e, 4->3

% --- Examine the learned params ---
% Slice-2 equivalence classes hold the transition CPDs for the Q nodes;
% slice-1 classes are used for the F and O nodes.
eclass = bnet_learned.equiv_class;
CPDQ1=struct(bnet_learned.CPD{eclass(Q1,2)});
CPDQ2=struct(bnet_learned.CPD{eclass(Q2,2)});
CPDQ3=struct(bnet_learned.CPD{eclass(Q3,2)});
CPDF2=struct(bnet_learned.CPD{eclass(F2,1)});
CPDF3=struct(bnet_learned.CPD{eclass(F3,1)});
CPDO=struct(bnet_learned.CPD{eclass(Onode,1)});

% Augment the level-2 transition matrix with an explicit end state,
% then display it for each Q1 parent context (no semicolon => printed).
A_learned =add_hhmm_end_state(CPDQ2.transprob, CPDF2.termprob(:,:,2));
squeeze(A_learned(:,1,:))
squeeze(A_learned(:,2,:))
124 | |
125 | |
% --- Compare against the "true" model ---
% Does the "true" model have higher likelihood than the learned one?
% i.e., Does the unsupervised method learn the wrong model because
% we have the wrong cost fn, or because of local minima?

bnet_true = mk_square_hhmm(discrete_obs,1);

% Examine the params of the true model.
% BUGFIX: index bnet_true.CPD with bnet_true's OWN equivalence classes.
% The original reused bnet_learned.equiv_class here; that only works
% because both nets are built by mk_square_hhmm with identical structure,
% and would silently break if the two structures ever diverged.
eclass = bnet_true.equiv_class;
CPDQ1_true=struct(bnet_true.CPD{eclass(Q1,2)});
CPDQ2_true=struct(bnet_true.CPD{eclass(Q2,2)});
CPDQ3_true=struct(bnet_true.CPD{eclass(Q3,2)});
CPDF2_true=struct(bnet_true.CPD{eclass(F2,1)});
CPDF3_true=struct(bnet_true.CPD{eclass(F3,1)});

% Level-2 transition matrix of the true model with explicit end state,
% displayed for the first Q1 context (no semicolon => printed).
A_true =add_hhmm_end_state(CPDQ2_true.transprob, CPDF2_true.termprob(:,:,2));
squeeze(A_true(:,1,:))

% Build an inference engine for the true model, using the same
% engine-selection rule as for training.
if supervised
  engine_true = jtree_ndx_dbn_inf_engine(bnet_true);
else
  engine_true = hmm_inf_engine(bnet_true);
end

% Score a single sequence under each model (displayed, no semicolon).
%[engine_learned, ll_learned] = enter_evidence(engine_learned, long_seq);
%[engine_true, ll_true] = enter_evidence(engine_true, long_seq);
[engine_learned, ll_learned] = enter_evidence(engine_learned, cases{2});
[engine_true, ll_true] = enter_evidence(engine_true, cases{2});
ll_learned
ll_true

% Remove concatenation artefacts: score each training sequence
% separately and sum the per-sequence log-likelihoods.
ll_learned = 0;
ll_true = 0;
for m=1:length(cases)
  [engine_learned, ll_learned_tmp] = enter_evidence(engine_learned, cases{m});
  [engine_true, ll_true_tmp] = enter_evidence(engine_true, cases{m});
  ll_learned = ll_learned + ll_learned_tmp;
  ll_true = ll_true + ll_true_tmp;
end
ll_learned
ll_true

% In both cases, ll_learned >> ll_true
% which shows we are using the wrong cost function!