comparison toolboxes/FullBNT-1.0.7/bnt/examples/dynamic/HHMM/Square/Old/learn_square_hhmm.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:e9a9cd732c1e
1 % Learn a 3-level HHMM similar to mk_square_hhmm.
2
3 % Because startprob should be shared for t=1:T, but in the DBN it is only
4 % shared for t=2:T, we train using a single long sequence.
5
6 discrete_obs = 0; % 1: discrete observation symbols (CPT); 0: 2-D observations with mean/cov
7 supervised = 1; % 1: Q1 and Q2 are observed during training; 0: only Onode is observed
8 obs_finalF2 = 0; % must stay 0 here: the script raises an error when it is set
9 % It is not possible to observe F2 while learning,
10 % because the update_ess methods for hhmmF_CPD and hhmmQ_CPD assume
11 % the F nodes are always hidden (for speed).
12 % However, for generating, we might want to set the final F2=true
13 % to force all subroutines to finish.
14
15 ss = 6; % number of rows in every evidence cell array (one row per node index below)
16 Q1 = 1; Q2 = 2; Q3 = 3; F3 = 4; F2 = 5; Onode = 6;
17 Qnodes = [Q1 Q2 Q3]; Fnodes = [F2 F3];
18
19 seed = 1; % same fixed seed for both generators so runs are repeatable
20 rand('state', seed);
21 randn('state', seed);
22
23 if discrete_obs
24 Qsizes = [2 4 2]; % passed to mk_hhmm as 'Qsizes'; one entry per Q level
25 else
26 Qsizes = [2 4 1]; % bottom level has a single state in the continuous case
27 end
28 % Per-level parameter specs (strings such as 'unif' / 'rnd') later handed to mk_hhmm.
29 D = 3; % number of levels in the hierarchy
30 Qnodes = 1:D;
31 startprob = cell(1,D);
32 transprob = cell(1,D);
33 termprob = cell(1,D);
34
35 startprob{1} = 'unif';
36 transprob{1} = 'unif';
37 % termprob{1} is never assigned, so it stays empty.
38 % In the unsupervised case, it is essential that we break symmetry
39 % in the initial param estimates.
40 %startprob{2} = 'unif';
41 %transprob{2} = 'unif';
42 %termprob{2} = 'unif';
43 startprob{2} = 'rnd';
44 transprob{2} = 'rnd';
45 termprob{2} = 'rnd';
46
47 leftright = 0; % 1: left-to-right base-level models; 0: random initialisation
48 if leftright
49 % Initialise base-level models as left-right.
50 % If we initialise with delta functions,
51 % they will remain delta functions after learning.
52 startprob{3} = 'leftstart';
53 transprob{3} = 'leftright';
54 termprob{3} = 'rightstop';
55 else
56 % If we want to be able to run a base-level model backwards...
57 startprob{3} = 'rnd';
58 transprob{3} = 'rnd';
59 termprob{3} = 'rnd';
60 end
61 % Build Oargs: the initial observation-model parameters passed to mk_hhmm.
62 if discrete_obs
63 % Initialise observations of lowest level primitives in a way which we can interpret
64 chars = ['L', 'l', 'U', 'u', 'R', 'r', 'D', 'd'];
65 L=find(chars=='L'); l=find(chars=='l');
66 U=find(chars=='U'); u=find(chars=='u');
67 R=find(chars=='R'); r=find(chars=='r');
68 D=find(chars=='D'); d=find(chars=='d'); % NOTE: reuses the name D, overwriting the depth value D = 3
69 Osize = length(chars);
70
71 p = 0.9; % weight given to the intended symbol before normalisation
72 obsprob = (1-p)*ones([4 2 Osize]);
73 % Q2 Q3 O
74 obsprob(1, 1, L) = p;
75 obsprob(1, 2, l) = p;
76 obsprob(2, 1, U) = p;
77 obsprob(2, 2, u) = p;
78 obsprob(3, 1, R) = p;
79 obsprob(3, 2, r) = p;
80 obsprob(4, 1, D) = p;
81 obsprob(4, 2, d) = p;
82 obsprob = mk_stochastic(obsprob); % presumably renormalises over the O dimension - confirm
83 Oargs = {'CPT', obsprob};
84
85 else
86 % Initialise means of lowest level primitives in a way which we can interpret
87 % These means are little vectors in the east, south, west, north directions.
88 % (left-right=east, up-down=south, right-left=west, down-up=north)
89 Osize = 2;
90 mu = zeros(2, Qsizes(2), Qsizes(3));
91 noise = 0; % with noise = 0 the rand terms below contribute nothing to the means
92 scale = 3; % length of each mean vector; also scales the initial covariance
93 for q3=1:Qsizes(3)
94 mu(:, 1, q3) = scale*[1;0] + noise*rand(2,1); % Q2=1: east
95 end
96 for q3=1:Qsizes(3)
97 mu(:, 2, q3) = scale*[0;-1] + noise*rand(2,1); % Q2=2: south
98 end
99 for q3=1:Qsizes(3)
100 mu(:, 3, q3) = scale*[-1;0] + noise*rand(2,1); % Q2=3: west
101 end
102 for q3=1:Qsizes(3)
103 mu(:, 4, q3) = scale*[0;1] + noise*rand(2,1); % Q2=4: north
104 end
105 Sigma = repmat(reshape(scale*eye(2), [2 2 1 1 ]), [1 1 Qsizes(2) Qsizes(3)]); % one copy of scale*eye(2) per (Q2,Q3) pair
106 Oargs = {'mean', mu, 'cov', Sigma, 'cov_type', 'diag'};
107 end
108 % Build the HHMM network from the specs above and choose the inference engine used for learning.
109 bnet = mk_hhmm('Qsizes', Qsizes, 'Osize', Osize, 'discrete_obs', discrete_obs,...
110 'Oargs', Oargs, 'Ops', Qnodes(2:3), ... % 'Ops' presumably lists the observation node's parents (Q2, Q3) - confirm
111 'startprob', startprob, 'transprob', transprob, 'termprob', termprob);
112
113 if supervised
114 bnet.observed = [Q1 Q2 Onode]; % top two levels are given as labels during training
115 else
116 bnet.observed = [Onode];
117 end
118
119 if obs_finalF2
120 engine = jtree_dbn_inf_engine(bnet);
121 % can't use ndx version because sometimes F2 is hidden, sometimes observed
122 error('can''t observe F when learning')
123 else
124 if supervised
125 engine = jtree_ndx_dbn_inf_engine(bnet);
126 else
127 engine = jtree_hmm_inf_engine(bnet);
128 end
129 end
130 % Train with EM on one long concatenated sequence, then inspect the result.
131 if discrete_obs
132 % generate some synthetic data (easier to debug)
133 cases = {};
134 % Sequence type 1: symbols in the order L l U u R r D d, labelled Q1=1.
135 T = 8;
136 ev = cell(ss, T);
137 ev(Onode,:) = num2cell([L l U u R r D d]);
138 if supervised
139 ev(Q1,:) = num2cell(1*ones(1,T));
140 ev(Q2,:) = num2cell( [1 1 2 2 3 3 4 4]);
141 end
142 cases{1} = ev;
143 cases{3} = ev; % cases 1 and 3 are identical copies
144 % Sequence type 2: a different symbol order, labelled Q1=2.
145 T = 8;
146 ev = cell(ss, T);
147 if leftright % base model is left-right
148 ev(Onode,:) = num2cell([R r U u L l D d]);
149 else
150 ev(Onode,:) = num2cell([r R u U l L d D]);
151 end
152 if supervised
153 ev(Q1,:) = num2cell(2*ones(1,T));
154 ev(Q2,:) = num2cell( [3 3 2 2 1 1 4 4]);
155 end
156
157 cases{2} = ev;
158 cases{4} = ev; % cases 2 and 4 are identical copies
159
160 if obs_finalF2
161 for i=1:length(cases)
162 T = size(cases{i},2);
163 cases{i}(F2,T)={2}; % force F2 to be finished at end of seq
164 end
165 end
166 % NOTE(review): disabled debug code; engine2 is only assigned by learn_params_dbn_em further down.
167 if 0
168 ev = cases{4};
169 engine2 = enter_evidence(engine2, ev);
170 T = size(ev,2);
171 for t=1:T
172 m=marginal_family(engine2, F2, t);
173 fprintf('t=%d\n', t);
174 reshape(m.T, [2 2])
175 end
176 end
177 % Concatenate all cases along time and run EM on that single sequence.
178 % [bnet2, LL] = learn_params_dbn_em(engine, cases, 'max_iter', 10);
179 long_seq = cat(2, cases{:});
180 [bnet2, LL, engine2] = learn_params_dbn_em(engine, {long_seq}, 'max_iter', 200);
181
182 % figure out which subsequence each model is responsible for
183 mpe = calc_mpe_dbn(engine2, long_seq);
184 pretty_print_hhmm_parse(mpe, Qnodes, Fnodes, Onode, chars);
185
186 else
187 load 'square4_cases' % cases{seq}{i,t} for i=1:ss
188 %plot_square_hhmm(cases{1})
189 %long_seq = cat(2, cases{:});
190 train_cases = cases(1:2); % first two sequences train, cases 3 and 4 are used for testing below
191 long_seq = cat(2, train_cases{:});
192 if ~supervised
193 T = size(long_seq,2);
194 for t=1:T
195 long_seq{Q1,t} = []; % hide the labels stored in the loaded data
196 long_seq{Q2,t} = [];
197 end
198 end
199 [bnet2, LL, engine2] = learn_params_dbn_em(engine, {long_seq}, 'max_iter', 100);
200 % Index bnet2.equiv_class directly: the variable eclass is not assigned until later in the script.
201 CPDO=struct(bnet2.CPD{bnet2.equiv_class(Onode,1)});
202 mu = CPDO.mean;
203 Sigma = CPDO.cov;
204 CPDO_full = CPDO; % copy of the learned observation CPD taken before Sigma is diagonalised
205
206 % force diagonal covs after training
207 for k=1:size(Sigma,3)
208 Sigma(:,:,k) = diag(diag(Sigma(:,:,k)));
209 end
210 bnet2.CPD{6} = set_fields(bnet.CPD{6}, 'cov', Sigma);
211 % NOTE(review): the line above starts from bnet.CPD{6} (pre-training), not bnet2.CPD{6}, and hard-codes 6 - confirm intent.
212 if 0
213 % visualize each model by concatenating means for each model for nsteps in a row
214 nsteps = 5;
215 ev = cell(ss, nsteps*prod(Qsizes(2:3)));
216 t = 1;
217 for q2=1:Qsizes(2)
218 for q3=1:Qsizes(3)
219 for i=1:nsteps
220 ev{Onode,t} = mu(:,q2,q3);
221 ev{Q2,t} = q2;
222 t = t + 1;
223 end
224 end
225 end
226 plot_square_hhmm(ev)
227 end
228
229 % bnet3 is the same as the learned model, except we will use it in testing mode
230 if supervised
231 bnet3 = bnet2;
232 bnet3.observed = [Onode]; % at test time only the observations are given
233 engine3 = hmm_inf_engine(bnet3);
234 %engine3 = jtree_ndx_dbn_inf_engine(bnet3);
235 else
236 bnet3 = bnet2;
237 engine3 = engine2;
238 end
239
240 if 0
241 % segment whole sequence
242 mpe = calc_mpe_dbn(engine3, long_seq);
243 pretty_print_hhmm_parse(mpe, Qnodes, Fnodes, Onode, []);
244 end
245
246 % segment each sequence
247 test_cases = cases(3:4);
248 for i=1:2
249 ev = test_cases{i};
250 T = size(ev, 2);
251 for t=1:T
252 ev{Q1,t} = []; % clear the labels so they have to be inferred
253 ev{Q2,t} = [];
254 end
255 mpe = calc_mpe_dbn(engine3, ev);
256 subplot(1,2,i)
257 plot_square_hhmm(mpe)
258 %pretty_print_hhmm_parse(mpe, Qnodes, Fnodes, Onode, []);
259 q1s = cell2num(mpe(Q1,:));
260 h = hist(q1s, 1:Qsizes(1)); % count how often each Q1 value appears in the decoded sequence
261 map_q1 = argmax(h); % most frequent Q1 value is reported as the sequence type
262 str = sprintf('test seq %d is of type %d\n', i, map_q1);
263 title(str)
264 end
265
266 end
267 % Inspect the learned parameters stored in bnet2.
268 if 0
269 % Estimate obtained by counting transitions in the labelled data.
270 % Note that a self transition shouldn't count if F2=off.
271 Q2ev = cell2num(ev(Q2,:));
272 Q2a = Q2ev(1:end-1);
273 Q2b = Q2ev(2:end);
274 counts = compute_counts([Q2a; Q2b], [4 4]);
275 end
276 % Convert the CPD objects to structs so their fields can be read directly.
277 eclass = bnet2.equiv_class;
278 CPDQ1=struct(bnet2.CPD{eclass(Q1,2)}); % Q nodes: equivalence class looked up in column 2
279 CPDQ2=struct(bnet2.CPD{eclass(Q2,2)});
280 CPDQ3=struct(bnet2.CPD{eclass(Q3,2)});
281 CPDF2=struct(bnet2.CPD{eclass(F2,1)}); % F nodes: equivalence class looked up in column 1
282 CPDF3=struct(bnet2.CPD{eclass(F3,1)});
283
284 % The statements below have no trailing semicolon, so their values are printed.
285 A=add_hhmm_end_state(CPDQ2.transprob, CPDF2.termprob(:,:,2));
286 squeeze(A(:,1,:))
287 squeeze(A(:,2,:))
288 CPDQ2.startprob
289
290 if 0
291 S=struct(CPDF2.sub_CPD_term);
292 S.nsamples
293 reshape(S.counts, [2 4 2])
294 end