% Learn a 3-level HHMM similar to mk_square_hhmm.
%
% Because startprob should be shared for t=1:T,
% but in the DBN is shared for t=2:T, we train using a single long sequence.

discrete_obs = 0;
supervised = 1;
obs_finalF2 = 0;
% It is not possible to observe F2 if we learn
% because the update_ess method for hhmmF_CPD and hhmmQ_CPD assume
% the F nodes are always hidden (for speed).
% However, for generating, we might want to set the final F2=true
% to force all subroutines to finish.

ss = 6; % slice size: 6 nodes per time slice
Q1 = 1; Q2 = 2; Q3 = 3; F3 = 4; F2 = 5; Onode = 6;
Qnodes = [Q1 Q2 Q3]; Fnodes = [F2 F3];

seed = 1;
rand('state', seed);
randn('state', seed);

if discrete_obs
  Qsizes = [2 4 2];
else
  Qsizes = [2 4 1];
end

D = 3; % depth of the hierarchy
Qnodes = 1:D;
startprob = cell(1,D);
transprob = cell(1,D);
termprob = cell(1,D);

startprob{1} = 'unif';
transprob{1} = 'unif';

% In the unsupervised case, it is essential that we break symmetry
% in the initial param estimates.
%startprob{2} = 'unif';
%transprob{2} = 'unif';
%termprob{2} = 'unif';
startprob{2} = 'rnd';
transprob{2} = 'rnd';
termprob{2} = 'rnd';

leftright = 0;
if leftright
  % Initialise base-level models as left-right.
  % If we initialise with delta functions,
  % they will remain delta functions after learning.
  startprob{3} = 'leftstart';
  transprob{3} = 'leftright';
  termprob{3} = 'rightstop';
else
  % If we want to be able to run a base-level model backwards...
  startprob{3} = 'rnd';
  transprob{3} = 'rnd';
  termprob{3} = 'rnd';
end

if discrete_obs
  % Initialise observations of lowest level primitives in a way which we can interpret
  chars = ['L', 'l', 'U', 'u', 'R', 'r', 'D', 'd'];
  L=find(chars=='L'); l=find(chars=='l');
  U=find(chars=='U'); u=find(chars=='u');
  R=find(chars=='R'); r=find(chars=='r');
  % NOTE(review): this clobbers D (the hierarchy depth set above); D is not
  % used as the depth again below, so this is harmless, but beware.
  D=find(chars=='D'); d=find(chars=='d');
  Osize = length(chars);

  p = 0.9;
  obsprob = (1-p)*ones([4 2 Osize]);
  %       Q2 Q3 O
  obsprob(1, 1, L) = p;
  obsprob(1, 2, l) = p;
  obsprob(2, 1, U) = p;
  obsprob(2, 2, u) = p;
  obsprob(3, 1, R) = p;
  obsprob(3, 2, r) = p;
  obsprob(4, 1, D) = p;
  obsprob(4, 2, d) = p;
  obsprob = mk_stochastic(obsprob);
  Oargs = {'CPT', obsprob};

else
  % Initialise means of lowest level primitives in a way which we can interpret
  % These means are little vectors in the east, south, west, north directions.
  % (left-right=east, up-down=south, right-left=west, down-up=north)
  Osize = 2;
  mu = zeros(2, Qsizes(2), Qsizes(3));
  noise = 0;
  scale = 3;
  for q3=1:Qsizes(3)
    mu(:, 1, q3) = scale*[1;0] + noise*rand(2,1);
  end
  for q3=1:Qsizes(3)
    mu(:, 2, q3) = scale*[0;-1] + noise*rand(2,1);
  end
  for q3=1:Qsizes(3)
    mu(:, 3, q3) = scale*[-1;0] + noise*rand(2,1);
  end
  for q3=1:Qsizes(3)
    mu(:, 4, q3) = scale*[0;1] + noise*rand(2,1);
  end
  Sigma = repmat(reshape(scale*eye(2), [2 2 1 1]), [1 1 Qsizes(2) Qsizes(3)]);
  Oargs = {'mean', mu, 'cov', Sigma, 'cov_type', 'diag'};
end

% FIX: removed stray transpose on the scalar Osize ("Osize'" in the original).
bnet = mk_hhmm('Qsizes', Qsizes, 'Osize', Osize, 'discrete_obs', discrete_obs,...
	       'Oargs', Oargs, 'Ops', Qnodes(2:3), ...
	       'startprob', startprob, 'transprob', transprob, 'termprob', termprob);

if supervised
  bnet.observed = [Q1 Q2 Onode];
else
  bnet.observed = [Onode];
end

if obs_finalF2
  engine = jtree_dbn_inf_engine(bnet);
  % can't use ndx version because sometimes F2 is hidden, sometimes observed
  error('can''t observe F when learning')
else
  if supervised
    engine = jtree_ndx_dbn_inf_engine(bnet);
  else
    engine = jtree_hmm_inf_engine(bnet);
  end
end

if discrete_obs
  % generate some synthetic data (easier to debug)
  cases = {};

  T = 8;
  ev = cell(ss, T);
  ev(Onode,:) = num2cell([L l U u R r D d]);
  if supervised
    ev(Q1,:) = num2cell(1*ones(1,T));
    ev(Q2,:) = num2cell([1 1 2 2 3 3 4 4]);
  end
  cases{1} = ev;
  cases{3} = ev;

  T = 8;
  ev = cell(ss, T);
  if leftright % base model is left-right
    ev(Onode,:) = num2cell([R r U u L l D d]);
  else
    ev(Onode,:) = num2cell([r R u U l L d D]);
  end
  if supervised
    ev(Q1,:) = num2cell(2*ones(1,T));
    ev(Q2,:) = num2cell([3 3 2 2 1 1 4 4]);
  end

  cases{2} = ev;
  cases{4} = ev;

  if obs_finalF2
    for i=1:length(cases)
      T = size(cases{i},2);
      cases{i}(F2,T)={2}; % force F2 to be finished at end of seq
    end
  end

  if 0
    ev = cases{4};
    engine2 = enter_evidence(engine2, ev);
    T = size(ev,2);
    for t=1:T
      m=marginal_family(engine2, F2, t);
      fprintf('t=%d\n', t);
      reshape(m.T, [2 2])
    end
  end

  % [bnet2, LL] = learn_params_dbn_em(engine, cases, 'max_iter', 10);
  long_seq = cat(2, cases{:});
  [bnet2, LL, engine2] = learn_params_dbn_em(engine, {long_seq}, 'max_iter', 200);

  % figure out which subsequence each model is responsible for
  mpe = calc_mpe_dbn(engine2, long_seq);
  pretty_print_hhmm_parse(mpe, Qnodes, Fnodes, Onode, chars);

else
  load 'square4_cases' % cases{seq}{i,t} for i=1:ss
  %plot_square_hhmm(cases{1})
  %long_seq = cat(2, cases{:});
  train_cases = cases(1:2);
  long_seq = cat(2, train_cases{:});
  if ~supervised
    T = size(long_seq,2);
    for t=1:T
      long_seq{Q1,t} = [];
      long_seq{Q2,t} = [];
    end
  end
  [bnet2, LL, engine2] = learn_params_dbn_em(engine, {long_seq}, 'max_iter', 100);

  % FIX: eclass must be defined before it is used to index the CPDs
  % (the original only assigned it much later, after this branch).
  eclass = bnet2.equiv_class;
  CPDO = struct(bnet2.CPD{eclass(Onode,1)});
  mu = CPDO.mean;
  Sigma = CPDO.cov;
  CPDO_full = CPDO;

  % force diagonal covs after training
  for k=1:size(Sigma,3)
    Sigma(:,:,k) = diag(diag(Sigma(:,:,k)));
  end
  % FIX: read from the *trained* bnet2 (the original used bnet.CPD{6},
  % which would have discarded the learned means).
  bnet2.CPD{6} = set_fields(bnet2.CPD{6}, 'cov', Sigma);

  if 0
    % visualize each model by concatenating means for each model for nsteps in a row
    nsteps = 5;
    ev = cell(ss, nsteps*prod(Qsizes(2:3)));
    t = 1;
    for q2=1:Qsizes(2)
      for q3=1:Qsizes(3)
	for i=1:nsteps
	  ev{Onode,t} = mu(:,q2,q3);
	  ev{Q2,t} = q2;
	  t = t + 1;
	end
      end
    end
    plot_square_hhmm(ev)
  end

  % bnet3 is the same as the learned model, except we will use it in testing mode
  if supervised
    bnet3 = bnet2;
    bnet3.observed = [Onode];
    engine3 = hmm_inf_engine(bnet3);
    %engine3 = jtree_ndx_dbn_inf_engine(bnet3);
  else
    bnet3 = bnet2;
    engine3 = engine2;
  end

  if 0
    % segment whole sequence
    mpe = calc_mpe_dbn(engine3, long_seq);
    pretty_print_hhmm_parse(mpe, Qnodes, Fnodes, Onode, []);
  end

  % segment each sequence
  test_cases = cases(3:4);
  for i=1:2
    ev = test_cases{i};
    T = size(ev, 2);
    for t=1:T
      ev{Q1,t} = [];
      ev{Q2,t} = [];
    end
    mpe = calc_mpe_dbn(engine3, ev);
    subplot(1,2,i)
    plot_square_hhmm(mpe)
    %pretty_print_hhmm_parse(mpe, Qnodes, Fnodes, Onode, []);
    q1s = cell2num(mpe(Q1,:));
    h = hist(q1s, 1:Qsizes(1));
    map_q1 = argmax(h);
    str = sprintf('test seq %d is of type %d\n', i, map_q1);
    title(str)
  end

end

if 0
  % Estimate gotten by counting transitions in the labelled data.
  % Note that a self transition shouldn't count if F2=off.
  Q2ev = cell2num(ev(Q2,:));
  Q2a = Q2ev(1:end-1);
  Q2b = Q2ev(2:end);
  counts = compute_counts([Q2a; Q2b], [4 4]);
end

eclass = bnet2.equiv_class;
CPDQ1=struct(bnet2.CPD{eclass(Q1,2)});
CPDQ2=struct(bnet2.CPD{eclass(Q2,2)});
CPDQ3=struct(bnet2.CPD{eclass(Q3,2)});
CPDF2=struct(bnet2.CPD{eclass(F2,1)});
CPDF3=struct(bnet2.CPD{eclass(F3,1)});


A=add_hhmm_end_state(CPDQ2.transprob, CPDF2.termprob(:,:,2));
squeeze(A(:,1,:))
squeeze(A(:,2,:))
CPDQ2.startprob

if 0
  S=struct(CPDF2.sub_CPD_term);
  S.nsamples
  reshape(S.counts, [2 4 2])
end