Mercurial > hg > camir-aes2014
comparison toolboxes/FullBNT-1.0.7/bnt/examples/dynamic/HHMM/mk_hhmm.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:e9a9cd732c1e |
---|---|
function [bnet, Qnodes, Fnodes, Onode] = mk_hhmm(varargin)
% MK_HHMM Make a Hierarchical HMM
% function [bnet, Qnodes, Fnodes, Onode] = mk_hhmm(...)
%
% Returns the DBN built by mk_dbn, together with the indices of the Q (state)
% nodes, the F (finish) nodes and the observed node O, as given by mk_hhmm_topo.
%
% e.g. 3-layer hierarchical HMM where level 1 only connects to level 2
% and the parents of the observed node are levels 2 and 3.
% (This DBN is the same as Fig 10 in my tech report.)
%
% Q1 ----------> Q1
% | \ ^ |
% | v / |
% | F2 ------/ |
% | ^ ^ \ |
% | / | \ |
% | / | ||
% v | vv
% Q2----| --------> Q2
% /| \ | ^|
% / | v | / |
% | | F3 --------/ |
% | | ^ \ |
% | v / v v
% | Q3 -----------> Q3
% |
% \
% v v
% O
%
%
% Optional arguments in name/value format [default value in brackets]
%
% Qsizes - sizes at each level; needs at least 2 levels [ none, required ]
% allQ - 1 means level i connects to all Q levels below, 0 means just to i+1 [0]
% transprob - transprob{d}(i,k,j) = P(Q(d,t)=j|Q(d,t-1)=i,Q(1:d-1,t)=k)
%             ['unif' for d=1, 'leftright' for d>1]
% startprob - startprob{d}(k,j) = P(Q(d,t)=j|Q(1:d-1,t)=k)
%             ['unif' for d=1, 'leftstart' for d>1]
% termprob - termprob{d}(k,j) = P(F(d,t)=2|Q(1:d-1,t)=k,Q(d,t)=j) for d>1 ['rightstop']
% selfprob - prob of a self transition, used by 'leftright'
%            (the 'rightstop' termprob uses 1-selfprob) [0.8]
% Osize - size of O node [ none, required ]
% discrete_obs - 1 means O is tabular_CPD, 0 means gaussian_CPD [0]
% Oargs - cell array of args to pass to the O CPD [ {} ]
% Ops - Q parents of O [Qnodes(end)]
% F1 - 1 means level 1 can finish (restart), else there is no F1->Q1 arc [0]
% clamp1 - value passed as the 'adjustable' option of the slice-1 Q
%          tabular_CPDs (the Qt1params) [1].
%          NOTE: despite the name, tabular_CPD treats adjustable=0 as
%          "do not learn", so the default 1 leaves the Qt1params learnable;
%          pass 0 to clamp them.
% Note: the Qt1params are startprob, which should be shared with other slices.
% However, in the current implementation, the Qt1params will only be estimated
% from the initial state of each sequence.
%
% For d=1, startprob{1}(1,j) is only used in the first slice and
% termprob{1} is ignored, since we assume the top level never resets.
% Also, transprob{1}(i,j) can be used instead of transprob{1}(i,1,j).
%
% String values for startprob/transprob/termprob:
% leftstart means the model always starts in state 1.
% leftright uses mk_leftright_transmat(Qsize(d), selfprob).
% rightstop means the model can only finish in its last state (Qsize(d)).
% unif means each state is equally likely to reach any other.
% rnd means the transition/starting probs are random (drawn from rand).
% Any other string is an error.
%
% Q1:QD in slice 1 are of type tabular_CPD
% Q1:QD in slice 2 are of type hhmmQ_CPD.
% F(2:D-1) is of type hhmmF_CPD, FD is of type tabular_CPD.

args = varargin;
nargs = length(args);

% --- Pass 1: options that determine node sizes and topology ---
Qsizes = [];
Osize = [];
allQ = 0;
Ops = [];
F1 = 0;
for i=1:2:nargs
  switch args{i}
    case 'Qsizes', Qsizes = args{i+1};
    case 'Osize', Osize = args{i+1};
    case 'allQ', allQ = args{i+1};
    case 'Ops', Ops = args{i+1};
    case 'F1', F1 = args{i+1};
  end
end
if isempty(Qsizes), error('must specify Qsizes'); end
% Osize defaults to []; "[]==0" is empty and therefore false inside an if,
% so an omitted Osize has to be detected with isempty.
if isempty(Osize) || isequal(Osize, 0), error('must specify Osize'); end
D = length(Qsizes);
% The construction below assumes a top level plus at least one lower level
% (e.g. level 1 always refers to Fnodes_ndx(2)).
if D < 2, error('mk_hhmm needs at least 2 levels (length(Qsizes) >= 2)'); end
Qnodes = 1:D;

if isempty(Ops), Ops = Qnodes(end); end


[intra, inter, Qnodes, Fnodes, Onode] = mk_hhmm_topo(D, allQ, Ops, F1);
ss = length(intra); % number of nodes per slice
names = {};

% Fnodes_ndx(d) is the F node of level d. Without F1 there is no level-1
% F node, so a dummy -1 is prepended to keep that indexing aligned.
if F1
  Fnodes_ndx = Fnodes;
else
  Fnodes_ndx = [-1 Fnodes];
end

% Within-slice Q parents of level d: all higher levels if allQ, otherwise
% only level d-1. Empty for the top level.
if allQ
  Qparents = @(d) Qnodes(1:d-1);
else
  Qparents = @(d) Qnodes(max(d-1,1):d-1);
end

% --- Pass 2: parameter options (defaults first) ---
discrete_obs = 0;
Oargs = {};
startprob = cell(1,D);
startprob{1} = 'unif';
for d=2:D
  startprob{d} = 'leftstart';
end
transprob = cell(1,D);
transprob{1} = 'unif';
for d=2:D
  transprob{d} = 'leftright';
end
termprob = cell(1,D);
for d=2:D
  termprob{d} = 'rightstop';
end
selfprob = 0.8;
clamp1 = 1;

for i=1:2:nargs
  switch args{i}
    case 'discrete_obs', discrete_obs = args{i+1};
    case 'Oargs', Oargs = args{i+1};
    case 'startprob', startprob = args{i+1};
    case 'transprob', transprob = args{i+1};
    case 'termprob', termprob = args{i+1};
    case 'selfprob', selfprob = args{i+1};
    case 'clamp1', clamp1 = args{i+1};
  end
end

ns = zeros(1,ss);
ns(Qnodes) = Qsizes;
ns(Onode) = Osize;
ns(Fnodes) = 2;

dnodes = [Qnodes Fnodes];
if discrete_obs
  dnodes = [dnodes Onode];
end
onodes = [Onode];

bnet = mk_dbn(intra, inter, ns, 'observed', onodes, 'discrete', dnodes, 'names', names);
eclass = bnet.equiv_class;

% Expand string-valued parameter specs into numeric arrays.
for d=1:D
  Qps = Qparents(d);
  Qpsz = prod(ns(Qps)); % joint size of the Q parents (1 if there are none)
  Qsz = ns(Qnodes(d));

  if ischar(startprob{d})
    switch startprob{d}
      case 'unif', startprob{d} = mk_stochastic(ones(Qpsz, Qsz));
      case 'rnd', startprob{d} = mk_stochastic(rand(Qpsz, Qsz));
      case 'leftstart'
        startprob{d} = zeros(Qpsz, Qsz);
        startprob{d}(:,1) = 1;
      otherwise, error(['unrecognized startprob ' startprob{d}])
    end
  end

  if ischar(transprob{d})
    switch transprob{d}
      case 'unif', transprob{d} = mk_stochastic(ones(Qsz, Qpsz, Qsz));
      case 'rnd', transprob{d} = mk_stochastic(rand(Qsz, Qpsz, Qsz));
      case 'leftright'
        LR = mk_leftright_transmat(Qsz, selfprob);
        % Replicate LR for every parent value k, giving (k,i,j) ...
        temp = repmat(reshape(LR, [1 Qsz Qsz]), [Qpsz 1 1]);
        % ... then reorder to the documented (i,k,j) layout.
        transprob{d} = permute(temp, [2 1 3]);
      otherwise, error(['unrecognized transprob ' transprob{d}])
    end
  end

  if ischar(termprob{d})
    switch termprob{d}
      case 'unif', termprob{d} = mk_stochastic(ones(Qpsz, Qsz, 2));
      case 'rnd', termprob{d} = mk_stochastic(rand(Qpsz, Qsz, 2));
      case 'rightstop'
        % termprob{d}(k,i,f): finishing (f=2) is only possible when i=Qsz;
        % states i<Qsz never finish.
        stopprob = 1-selfprob;
        termprob{d} = zeros(Qpsz, Qsz, 2);
        termprob{d}(:,Qsz,2) = stopprob;
        termprob{d}(:,Qsz,1) = 1-stopprob;
        termprob{d}(:,1:(Qsz-1),1) = 1;
      otherwise, error(['unrecognized termprob ' termprob{d}])
    end
  elseif d>1 % passed in termprob{d}(k,j) = P(F=2); add the P(F=1) plane
    temp = termprob{d};
    termprob{d} = zeros(Qpsz, Qsz, 2);
    termprob{d}(:,:,2) = temp;
    termprob{d}(:,:,1) = ones(Qpsz,Qsz) - temp;
  end
end


% SLICE 1

for d=1:D
  bnet.CPD{eclass(Qnodes(d),1)} = tabular_CPD(bnet, Qnodes(d), 'CPT', startprob{d}, 'adjustable', clamp1);
end

if F1
  d = 1;
  bnet.CPD{eclass(Fnodes_ndx(d),1)} = hhmmF_CPD(bnet, Fnodes_ndx(d), Qnodes(d), Fnodes_ndx(d+1), ...
                                                'termprob', termprob{d});
end
for d=2:D-1
  bnet.CPD{eclass(Fnodes_ndx(d),1)} = hhmmF_CPD(bnet, Fnodes_ndx(d), Qnodes(d), Fnodes_ndx(d+1), ...
                                                'Qps', Qparents(d), 'termprob', termprob{d});
end
% The bottom F node has no F node below it, so a plain table suffices.
bnet.CPD{eclass(Fnodes_ndx(D),1)} = tabular_CPD(bnet, Fnodes_ndx(D), 'CPT', termprob{D});

if discrete_obs
  bnet.CPD{eclass(Onode,1)} = tabular_CPD(bnet, Onode, Oargs{:});
else
  bnet.CPD{eclass(Onode,1)} = gaussian_CPD(bnet, Onode, Oargs{:});
end

% SLICE 2
% (Node indices in slice 2 are the slice-1 indices offset by ss.)

%for d=1:D
%  bnet.CPD{eclass(Qnodes(d),2)} = hhmmQ_CPD(bnet, Qnodes(d)+ss, Qnodes, d, D, ...
%	  'startprob', startprob{d}, 'transprob', transprob{d}, ...
%	  'allQ', allQ);
%end

% Top level: has its own F node only when F1 is set.
d = 1;
if F1
  bnet.CPD{eclass(Qnodes(d),2)} = hhmmQ_CPD(bnet, Qnodes(d)+ss, 'Fself', Fnodes_ndx(d), ...
                                            'Fbelow', Fnodes_ndx(d+1), ...
                                            'startprob', startprob{d}, 'transprob', transprob{d});
else
  bnet.CPD{eclass(Qnodes(d),2)} = hhmmQ_CPD(bnet, Qnodes(d)+ss, ...
                                            'Fbelow', Fnodes_ndx(d+1), ...
                                            'startprob', startprob{d}, 'transprob', transprob{d});
end

% Middle levels: F node of their own and of the level below, plus Q parents.
for d=2:D-1
  Qps = Qparents(d) + ss; % since all in slice 2
  bnet.CPD{eclass(Qnodes(d),2)} = hhmmQ_CPD(bnet, Qnodes(d)+ss, 'Fself', Fnodes_ndx(d), ...
                                            'Fbelow', Fnodes_ndx(d+1), 'Qps', Qps, ...
                                            'startprob', startprob{d}, 'transprob', transprob{d});
end

% Bottom level: no F node below it.
d = D;
Qps = Qparents(d) + ss; % since all in slice 2
bnet.CPD{eclass(Qnodes(d),2)} = hhmmQ_CPD(bnet, Qnodes(d)+ss, 'Fself', Fnodes_ndx(d), ...
                                          'Qps', Qps, ...
                                          'startprob', startprob{d}, 'transprob', transprob{d});