Mercurial > hg > camir-aes2014
comparison toolboxes/FullBNT-1.0.7/bnt/examples/dynamic/HHMM/Square/Old/learn_square_hhmm.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:e9a9cd732c1e |
---|---|
% Learn a 3-level HHMM similar to mk_square_hhmm.
%
% Because startprob should be shared for t=1:T, but in the DBN it is
% shared only for t=2:T, we train using a single long sequence.

% Experiment switches.
discrete_obs = 0;   % 0 = Gaussian observations, 1 = symbol observations
supervised   = 1;   % observe the Q1/Q2 labels during training?
obs_finalF2  = 0;   % observe the final F2 node? (generation only)
% It is not possible to observe F2 if we learn, because the update_ess
% method for hhmmF_CPD and hhmmQ_CPD assumes the F nodes are always
% hidden (for speed). However, for generating, we might want to set the
% final F2=true to force all subroutines to finish.

% Node numbering within one DBN slice.
ss = 6;  % slice size
Q1 = 1; Q2 = 2; Q3 = 3; F3 = 4; F2 = 5; Onode = 6;
Qnodes = [Q1 Q2 Q3];
Fnodes = [F2 F3];

% Make the run reproducible.
seed = 1;
rand('state', seed);
randn('state', seed);

% State-space sizes per level; the bottom level has a single state in
% the continuous-observation case.
if discrete_obs
  Qsizes = [2 4 2];
else
  Qsizes = [2 4 1];
end
28 | |
% Allocate per-level parameter specs for mk_hhmm; entry d of each cell
% array configures level d (1 = top, D = bottom).
D = 3;
Qnodes = 1:D;
startprob = cell(1,D);
transprob = cell(1,D);
termprob  = cell(1,D);

% Top level: uniform start and transition; termprob{1} is left empty.
startprob{1} = 'unif';
transprob{1} = 'unif';

% Middle level: random initial estimates. In the unsupervised case it is
% essential that we break symmetry in the initial param estimates.
%startprob{2} = 'unif';
%transprob{2} = 'unif';
%termprob{2} = 'unif';
startprob{2} = 'rnd';
transprob{2} = 'rnd';
termprob{2}  = 'rnd';
46 | |
% Bottom level: choose how the base-level sub-models are initialised.
leftright = 0;
if ~leftright
  % Random init, so a base-level model can also be run backwards.
  startprob{3} = 'rnd';
  transprob{3} = 'rnd';
  termprob{3}  = 'rnd';
else
  % Initialise base-level models as left-right. If we initialise with
  % delta functions, they will remain delta functions after learning.
  startprob{3} = 'leftstart';
  transprob{3} = 'leftright';
  termprob{3}  = 'rightstop';
end
61 | |
if discrete_obs
  % Initialise observations of the lowest-level primitives in a way we
  % can interpret: capital letters for one model, lowercase for the other.
  chars = ['L', 'l', 'U', 'u', 'R', 'r', 'D', 'd'];
  L=find(chars=='L'); l=find(chars=='l');
  U=find(chars=='U'); u=find(chars=='u');
  R=find(chars=='R'); r=find(chars=='r');
  D=find(chars=='D'); d=find(chars=='d');  % NB: this D clobbers the depth D
  Osize = length(chars);

  % Each (Q2,Q3) state emits its own symbol with prob p; the remaining
  % mass is spread uniformly, then rows are renormalised.
  p = 0.9;
  obsprob = (1-p)*ones([4 2 Osize]);
  sym = [L l; U u; R r; D d];  % sym(q2,q3) = preferred symbol for that state
  for q2=1:4
    for q3=1:2
      obsprob(q2, q3, sym(q2,q3)) = p;
    end
  end
  obsprob = mk_stochastic(obsprob);
  Oargs = {'CPT', obsprob};
else
  % Initialise means of the lowest-level primitives in a way we can
  % interpret: little vectors in the east, south, west, north directions
  % (left-right=east, up-down=south, right-left=west, down-up=north).
  Osize = 2;
  noise = 0;
  scale = 3;
  dirs = scale * [1 0; 0 -1; -1 0; 0 1]';  % columns: E, S, W, N
  mu = zeros(2, Qsizes(2), Qsizes(3));
  for q2=1:Qsizes(2)
    for q3=1:Qsizes(3)
      mu(:, q2, q3) = dirs(:, q2) + noise*rand(2,1);
    end
  end
  % Identical scale*I starting covariance for every (Q2,Q3) state.
  Sigma = repmat(reshape(scale*eye(2), [2 2 1 1 ]), [1 1 Qsizes(2) Qsizes(3)]);
  Oargs = {'mean', mu, 'cov', Sigma, 'cov_type', 'diag'};
end
108 | |
% Build the 3-level HHMM as a DBN; Q2 and Q3 are the parents of the
% observation node ('Ops').
% FIX: the original passed Osize' (a stray transpose). It is harmless on
% a scalar but misleading, so pass Osize directly.
bnet = mk_hhmm('Qsizes', Qsizes, 'Osize', Osize, 'discrete_obs', discrete_obs,...
	       'Oargs', Oargs, 'Ops', Qnodes(2:3), ...
	       'startprob', startprob, 'transprob', transprob, 'termprob', termprob);
112 | |
% Choose which nodes are observed during training.
if supervised
  bnet.observed = [Q1 Q2 Onode];
else
  bnet.observed = [Onode];
end

% Pick an inference engine that matches the evidence pattern.
if obs_finalF2
  % can't use ndx version because sometimes F2 is hidden, sometimes observed
  engine = jtree_dbn_inf_engine(bnet);
  error('can''t observe F when learning')
elseif supervised
  engine = jtree_ndx_dbn_inf_engine(bnet);
else
  engine = jtree_hmm_inf_engine(bnet);
end
130 | |
if discrete_obs
  % --- Discrete observations: generate synthetic data (easier to debug).
  cases = {};

  % Training/test sequence for model 1 (used twice: train case 1, test case 3).
  T = 8;
  ev = cell(ss, T);
  ev(Onode,:) = num2cell([L l U u R r D d]);
  if supervised
    ev(Q1,:) = num2cell(1*ones(1,T));
    ev(Q2,:) = num2cell( [1 1 2 2 3 3 4 4]);
  end
  cases{1} = ev;
  cases{3} = ev;

  % Training/test sequence for model 2 (train case 2, test case 4).
  T = 8;
  ev = cell(ss, T);
  if leftright % base model is left-right
    ev(Onode,:) = num2cell([R r U u L l D d]);
  else
    ev(Onode,:) = num2cell([r R u U l L d D]);
  end
  if supervised
    ev(Q1,:) = num2cell(2*ones(1,T));
    ev(Q2,:) = num2cell( [3 3 2 2 1 1 4 4]);
  end

  cases{2} = ev;
  cases{4} = ev;

  if obs_finalF2
    for i=1:length(cases)
      T = size(cases{i},2);
      cases{i}(F2,T)={2}; % force F2 to be finished at end of seq
    end
  end

  if 0
    % Debug: inspect the F2 family marginals on one case.
    ev = cases{4};
    engine2 = enter_evidence(engine2, ev);
    T = size(ev,2);
    for t=1:T
      m=marginal_family(engine2, F2, t);
      fprintf('t=%d\n', t);
      reshape(m.T, [2 2])
    end
  end

  % Train on a single long concatenated sequence (see header comment on
  % startprob sharing).
  % [bnet2, LL] = learn_params_dbn_em(engine, cases, 'max_iter', 10);
  long_seq = cat(2, cases{:});
  [bnet2, LL, engine2] = learn_params_dbn_em(engine, {long_seq}, 'max_iter', 200);

  % figure out which subsequence each model is responsible for
  mpe = calc_mpe_dbn(engine2, long_seq);
  pretty_print_hhmm_parse(mpe, Qnodes, Fnodes, Onode, chars);

else
  % --- Continuous observations: train on pre-generated square data.
  load 'square4_cases' % cases{seq}{i,t} for i=1:ss
  %plot_square_hhmm(cases{1})
  %long_seq = cat(2, cases{:});
  train_cases = cases(1:2);
  long_seq = cat(2, train_cases{:});
  if ~supervised
    % Hide the Q labels so EM must infer them.
    T = size(long_seq,2);
    for t=1:T
      long_seq{Q1,t} = [];
      long_seq{Q2,t} = [];
    end
  end
  [bnet2, LL, engine2] = learn_params_dbn_em(engine, {long_seq}, 'max_iter', 100);

  % BUG FIX: eclass was used here before being assigned (it was only set
  % near the bottom of the script), so define it before its first use.
  eclass = bnet2.equiv_class;
  CPDO=struct(bnet2.CPD{eclass(Onode,1)});
  mu = CPDO.mean;
  Sigma = CPDO.cov;
  CPDO_full = CPDO;

  % force diagonal covs after training
  for k=1:size(Sigma,3)
    Sigma(:,:,k) = diag(diag(Sigma(:,:,k)));
  end
  % NOTE(review): this takes the *untrained* bnet.CPD{6} and only swaps in
  % the diagonalised covariance, discarding the learned means in bnet2.
  % It looks like it should read set_fields(bnet2.CPD{6}, ...) — confirm
  % before relying on bnet2's observation means downstream.
  bnet2.CPD{6} = set_fields(bnet.CPD{6}, 'cov', Sigma);

  if 0
    % visualize each model by concatenating means for each model for nsteps in a row
    nsteps = 5;
    ev = cell(ss, nsteps*prod(Qsizes(2:3)));
    t = 1;
    for q2=1:Qsizes(2)
      for q3=1:Qsizes(3)
        for i=1:nsteps
          ev{Onode,t} = mu(:,q2,q3);
          ev{Q2,t} = q2;
          t = t + 1;
        end
      end
    end
    plot_square_hhmm(ev)
  end

  % bnet3 is the same as the learned model, except we will use it in testing mode
  if supervised
    bnet3 = bnet2;
    bnet3.observed = [Onode];
    engine3 = hmm_inf_engine(bnet3);
    %engine3 = jtree_ndx_dbn_inf_engine(bnet3);
  else
    bnet3 = bnet2;
    engine3 = engine2;
  end

  if 0
    % segment whole sequence
    mpe = calc_mpe_dbn(engine3, long_seq);
    pretty_print_hhmm_parse(mpe, Qnodes, Fnodes, Onode, []);
  end

  % Segment each held-out sequence with the Q labels hidden, plot the
  % Viterbi parse, and report which top-level model explains it best.
  test_cases = cases(3:4);
  for i=1:2
    ev = test_cases{i};
    T = size(ev, 2);
    for t=1:T
      ev{Q1,t} = [];
      ev{Q2,t} = [];
    end
    mpe = calc_mpe_dbn(engine3, ev);
    subplot(1,2,i)
    plot_square_hhmm(mpe)
    %pretty_print_hhmm_parse(mpe, Qnodes, Fnodes, Onode, []);
    q1s = cell2num(mpe(Q1,:));
    h = hist(q1s, 1:Qsizes(1));
    map_q1 = argmax(h);  % most frequent Q1 value along the parse
    str = sprintf('test seq %d is of type %d\n', i, map_q1);
    title(str)
  end

end
267 | |
if 0
  % Sanity check: estimate the Q2 transition matrix by counting
  % consecutive (Q2(t), Q2(t+1)) pairs in the labelled data.
  % Note that a self transition shouldn't count if F2=off.
  Q2ev = cell2num(ev(Q2,:));
  Q2a = Q2ev(1:end-1);  % sources
  Q2b = Q2ev(2:end);    % destinations
  counts = compute_counts([Q2a; Q2b], [4 4]);
end
276 | |
% --- Inspect the learned parameters (bare expressions print to console). ---
% eclass(i,t) indexes into bnet2.CPD to find the (shared) CPD object for
% node i of slice t; the Q CPDs are taken from slice 2, the F CPDs from
% slice 1.
277 eclass = bnet2.equiv_class; | |
278 CPDQ1=struct(bnet2.CPD{eclass(Q1,2)}); | |
279 CPDQ2=struct(bnet2.CPD{eclass(Q2,2)}); | |
280 CPDQ3=struct(bnet2.CPD{eclass(Q3,2)}); | |
281 CPDF2=struct(bnet2.CPD{eclass(F2,1)}); | |
282 CPDF3=struct(bnet2.CPD{eclass(F3,1)}); | |
283 | |
284 | |
% Combine the level-2 transition model with its termination probabilities
% into a single matrix with an explicit end state, then display two
% slices of it plus the start distribution.
% NOTE(review): termprob(:,:,2) presumably selects the F2=on ("finish")
% component — confirm against hhmmF_CPD before relying on this.
285 A=add_hhmm_end_state(CPDQ2.transprob, CPDF2.termprob(:,:,2)); | |
286 squeeze(A(:,1,:)) | |
287 squeeze(A(:,2,:)) | |
288 CPDQ2.startprob | |
289 | |
% Disabled: dump the raw sufficient statistics accumulated for the F2
% termination sub-CPD during EM.
290 if 0 | |
291 S=struct(CPDF2.sub_CPD_term); | |
292 S.nsamples | |
293 reshape(S.counts, [2 4 2]) | |
294 end | |