function [bnet, Qnodes, Fnodes, Onode] = mk_hhmm(varargin)
% MK_HHMM Make a Hierarchical HMM
% function [bnet, Qnodes, Fnodes, Onode] = mk_hhmm(...)
%
% e.g. 3-layer hierarchical HMM where level 1 only connects to level 2
% and the parents of the observed node are levels 2 and 3.
% (This DBN is the same as Fig 10 in my tech report.)
%
% (ASCII sketch; column alignment is approximate - see the tech report
% figure for the authoritative picture.)
%
%   Q1    ----------> Q1
%   | \             ^  |
%   |   v          /   |
%   |    F2 ------/    |
%   |   ^ ^        \   |
%   |  /  |         \  |
%   | /   |           ||
%   v     |           vv
%   Q2----| --------> Q2
%  /| \   |           ^|
% / |   v |          / |
% | |    F3 --------/  |
% | |   ^           \  |
% | v  /             v v
% | Q3   -----------> Q3
% |                    |
%  \                   |
%   v                  v
%            O
%
%
% Optional arguments in name/value format [default value in brackets]
% Names are matched exactly (case-sensitive); unrecognised names are
% silently ignored.
%
% Qsizes      - sizes at each level [ none - required ]
% allQ        - 1 means level i connects to all Q levels below, 0 means just to i+1 [0]
% transprob   - transprob{d}(i,k,j) = P(Q(d,t)=j|Q(d,t-1)=i,Q(1:d-1,t)=k)
%               ['unif' for d=1, 'leftright' for d>1]
% startprob   - startprob{d}(k,j) = P(Q(d,t)=j|Q(1:d-1,t)=k)
%               ['unif' for d=1, 'leftstart' for d>1]
% termprob    - termprob{d}(k,j) = P(F(d,t)=2|Q(1:d-1,t)=k,Q(d,t)=j) for d>1 ['rightstop']
% selfprob    - prob of a self transition (termprob default = 1-selfprob) [0.8]
% Osize       - size of O node [ none - required ]
% discrete_obs - 1 means O is tabular_CPD, 0 means gaussian_CPD [0]
% Oargs       - cell array of args to pass to the O CPD  [ {} ]
% Ops         - Q parents of O [Qnodes(end)]
% F1          - 1 means level 1 can finish (restart), else there is no F1->Q1 arc [0]
% clamp1      - 1 means we clamp the params of the Q nodes in slice 1 (Qt1params) [1]
%    Note: the Qt1params are startprob, which should be shared with other slices.
%    However, in the current implementation, the Qt1params will only be estimated
%    from the initial state of each sequence.
%
% For d=1, startprob{1}(1,j) is only used in the first slice and
% termprob{1} is ignored, since we assume the top level never resets.
% Also, transprob{1}(i,j) can be used instead of transprob{1}(i,1,j).
%
% leftstart means the model always starts in state 1.
% rightstop means the model can only finish in its last state (Qsize(d)).
% unif means each state is equally likely to reach any other
% rnd means the transition/starting probs are random (drawn from rand)
%
% Q1:QD in slice 1 are of type tabular_CPD
% Q1:QD in slice 2 are of type hhmmQ_CPD.
% F(2:D-1) is of type hhmmF_CPD, FD is of type tabular_CPD.
% (If F1 is set, F1 is also of type hhmmF_CPD.)

args = varargin;
nargs = length(args);

% ---- Pass 1: options that determine node sizes and topology ----
% These must be known before the remaining defaults (which depend on the
% number of levels D) can be built, hence the two parsing passes.
Qsizes = [];
Osize = [];
allQ = 0;
Ops = [];
F1 = 0;
for i=1:2:nargs
  switch args{i}
    case 'Qsizes', Qsizes = args{i+1};
    case 'Osize',  Osize = args{i+1};
    case 'allQ',   allQ = args{i+1};
    case 'Ops',    Ops = args{i+1};
    case 'F1',     F1 = args{i+1};
  end
end
if isempty(Qsizes), error('must specify Qsizes'); end
% Osize defaults to [], and "[]==0" is empty (hence false), so the
% emptiness test is needed for the "not supplied" case to be caught.
if isempty(Osize) || all(Osize(:)==0), error('must specify Osize'); end
D = length(Qsizes);
Qnodes = 1:D; % provisional numbering, only used for the default of Ops

if isempty(Ops), Ops = Qnodes(end); end


[intra, inter, Qnodes, Fnodes, Onode] = mk_hhmm_topo(D, allQ, Ops, F1);
ss = length(intra); % slice size (number of nodes per slice)
names = {};

% Fnodes_ndx(d) is the F node belonging to level d. Without F1 there is no
% F node at level 1, so a dummy entry keeps the indexing aligned.
if F1
  Fnodes_ndx = Fnodes;
else
  Fnodes_ndx = [-1 Fnodes]; % Fnodes(1) is a dummy index
end

% ---- set default params ----
discrete_obs = 0;
Oargs = {};
startprob = cell(1,D);
startprob{1} = 'unif';
for d=2:D
  startprob{d} = 'leftstart';
end
transprob = cell(1,D);
transprob{1} = 'unif';
for d=2:D
  transprob{d} = 'leftright';
end
termprob = cell(1,D); % termprob{1} stays empty unless supplied by the caller
for d=2:D
  termprob{d} = 'rightstop';
end
selfprob = 0.8;
clamp1 = 1;

% ---- Pass 2: parameter options ----
for i=1:2:nargs
  switch args{i}
    case 'discrete_obs', discrete_obs = args{i+1};
    case 'Oargs',        Oargs = args{i+1};
    case 'startprob',    startprob = args{i+1};
    case 'transprob',    transprob = args{i+1};
    case 'termprob',     termprob = args{i+1};
    case 'selfprob',     selfprob = args{i+1};
    case 'clamp1',       clamp1 = args{i+1};
  end
end

% ---- node sizes and the DBN skeleton ----
ns = zeros(1,ss);
ns(Qnodes) = Qsizes;
ns(Onode) = Osize;
ns(Fnodes) = 2; % F nodes are binary; value 2 means "finished" (see termprob)

dnodes = [Qnodes Fnodes];
if discrete_obs
  dnodes = [dnodes Onode];
end
onodes = Onode;

bnet = mk_dbn(intra, inter, ns, 'observed', onodes, 'discrete', dnodes, 'names', names);
eclass = bnet.equiv_class;

% ---- expand symbolic parameter names into probability tables ----
for d=1:D
  Qps = hhmm_Qparents(Qnodes, d, allQ);
  Qpsz = prod(ns(Qps)); % joint size of the Q parents (1 when there are none)
  Qsz = ns(Qnodes(d));
  if ischar(startprob{d})
    switch startprob{d}
      case 'unif', startprob{d} = mk_stochastic(ones(Qpsz, Qsz));
      case 'rnd',  startprob{d} = mk_stochastic(rand(Qpsz, Qsz));
      case 'leftstart'
        startprob{d} = zeros(Qpsz, Qsz);
        startprob{d}(:,1) = 1;
    end
  end
  if ischar(transprob{d})
    switch transprob{d}
      case 'unif', transprob{d} = mk_stochastic(ones(Qsz, Qpsz, Qsz));
      case 'rnd',  transprob{d} = mk_stochastic(rand(Qsz, Qpsz, Qsz));
      case 'leftright'
        LR = mk_leftright_transmat(Qsz, selfprob);
        temp = repmat(reshape(LR, [1 Qsz Qsz]), [Qpsz 1 1]); % transprob(k,i,j)
        transprob{d} = permute(temp, [2 1 3]); % now transprob(i,k,j)
    end
  end
  if ischar(termprob{d})
    switch termprob{d}
      case 'unif', termprob{d} = mk_stochastic(ones(Qpsz, Qsz, 2));
      case 'rnd',  termprob{d} = mk_stochastic(rand(Qpsz, Qsz, 2));
      case 'rightstop'
        %termprob(k,i,t) Might terminate if i=Qsz; will not terminate if i<Qsz
        stopprob = 1-selfprob;
        termprob{d} = zeros(Qpsz, Qsz, 2);
        termprob{d}(:,Qsz,2) = stopprob;
        termprob{d}(:,Qsz,1) = 1-stopprob;
        termprob{d}(:,1:(Qsz-1),1) = 1;
      otherwise, error(['unrecognized termprob ' termprob{d}])
    end
  elseif d>1 % passed in termprob{d}(k,j)
    % Expand P(F=2|k,j) into a full table over both values of F.
    temp = termprob{d};
    termprob{d} = zeros(Qpsz, Qsz, 2);
    termprob{d}(:,:,2) = temp;
    termprob{d}(:,:,1) = ones(Qpsz,Qsz) - temp;
  end
end


% SLICE 1

% NOTE(review): clamp1 is forwarded unchanged as the 'adjustable' option;
% confirm against tabular_CPD that this yields the clamping described in
% the help text above (the names suggest the opposite polarity).
for d=1:D
  bnet.CPD{eclass(Qnodes(d),1)} = tabular_CPD(bnet, Qnodes(d), 'CPT', startprob{d}, 'adjustable', clamp1);
end

if F1
  d = 1;
  bnet.CPD{eclass(Fnodes_ndx(d),1)} = hhmmF_CPD(bnet, Fnodes_ndx(d), Qnodes(d), Fnodes_ndx(d+1), ...
      'termprob', termprob{d});
end
for d=2:D-1
  Qps = hhmm_Qparents(Qnodes, d, allQ);
  bnet.CPD{eclass(Fnodes_ndx(d),1)} = hhmmF_CPD(bnet, Fnodes_ndx(d), Qnodes(d), Fnodes_ndx(d+1), ...
      'Qps', Qps, 'termprob', termprob{d});
end
% The bottom-level F node has no F node below it, so it is a plain table.
bnet.CPD{eclass(Fnodes_ndx(D),1)} = tabular_CPD(bnet, Fnodes_ndx(D), 'CPT', termprob{D});

if discrete_obs
  bnet.CPD{eclass(Onode,1)} = tabular_CPD(bnet, Onode, Oargs{:});
else
  bnet.CPD{eclass(Onode,1)} = gaussian_CPD(bnet, Onode, Oargs{:});
end

% SLICE 2

%for d=1:D
%  bnet.CPD{eclass(Qnodes(d),2)} = hhmmQ_CPD(bnet, Qnodes(d)+ss, Qnodes, d, D, ...
%      'startprob', startprob{d}, 'transprob', transprob{d}, ...
%      'allQ', allQ);
%end

% Top level: no Q parents; has an 'Fself' argument only when F1 is set.
d = 1;
if F1
  bnet.CPD{eclass(Qnodes(d),2)} = hhmmQ_CPD(bnet, Qnodes(d)+ss, 'Fself', Fnodes_ndx(d), ...
      'Fbelow', Fnodes_ndx(d+1), ...
      'startprob', startprob{d}, 'transprob', transprob{d});
else
  bnet.CPD{eclass(Qnodes(d),2)} = hhmmQ_CPD(bnet, Qnodes(d)+ss, ...
      'Fbelow', Fnodes_ndx(d+1), ...
      'startprob', startprob{d}, 'transprob', transprob{d});
end
% Intermediate levels: F node at this level and at the level below.
for d=2:D-1
  Qps = hhmm_Qparents(Qnodes, d, allQ) + ss; % since all in slice 2
  bnet.CPD{eclass(Qnodes(d),2)} = hhmmQ_CPD(bnet, Qnodes(d)+ss, 'Fself', Fnodes_ndx(d), ...
      'Fbelow', Fnodes_ndx(d+1), 'Qps', Qps, ...
      'startprob', startprob{d}, 'transprob', transprob{d});
end
% Bottom level: there is no level below, hence no 'Fbelow'.
d = D;
Qps = hhmm_Qparents(Qnodes, d, allQ) + ss; % since all in slice 2
bnet.CPD{eclass(Qnodes(d),2)} = hhmmQ_CPD(bnet, Qnodes(d)+ss, 'Fself', Fnodes_ndx(d), ...
    'Qps', Qps, ...
    'startprob', startprob{d}, 'transprob', transprob{d});

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

function Qps = hhmm_Qparents(Qnodes, d, allQ)
% HHMM_QPARENTS Q nodes (same-slice numbering) that are parents of level d
% Level 1 has no Q parents. Otherwise the parents are every level above d
% when allQ is true, or just level d-1 when it is false.
if d==1
  Qps = [];
elseif allQ
  Qps = Qnodes(1:d-1);
else
  Qps = Qnodes(d-1);
end