comparison toolboxes/FullBNT-1.0.7/bnt/examples/dynamic/fhmm_infer.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:e9a9cd732c1e
function [loglik, gamma] = fhmm_infer(inter, CPTs_slice1, CPTs, obsmat, node_sizes)
% FHMM_INFER Exact inference for a factorial HMM.
% [loglik, gamma] = fhmm_infer(inter, CPTs_slice1, CPTs, obsmat, node_sizes)
%
% Inputs:
% inter - the inter-slice adjacency matrix
% CPTs_slice1{s}(j) = Pr(Q(s,1) = j) where Q(s,t) = hidden node s in slice t
% CPTs{s}(i1, i2, ..., j) = Pr(Q(s,t) = j | Pa(s,t-1) = i1, i2, ...),
% obsmat(i,t) = Pr(y(t) | Q(t)=i)
% node_sizes is a vector with the cardinality of the hidden nodes
%
% Outputs:
% loglik = log-likelihood of the observations, computed as -sum(log(scale))
%   where scale(t) is the reciprocal of the normalising constant of alpha(:,t).
% gamma(i,t) = Pr(X(t)=i | O(1:T)) as in an HMM,
% except that i is interpreted as an M digit, base-K number (if there are M chains each of cardinality K).
%
%
% For M chains each of cardinality K, the frontiers (i.e., cliques)
% contain M+1 nodes, and it takes M steps to advance the frontier by one time step,
% so the run time is O(T M K^(M+1)).
% An HMM takes O(T S^2) where S is the size of the state space.
% Collapsing the FHMM to an HMM results in S = K^M.
% For details, see
% "The Factored Frontier Algorithm for Approximate Inference in DBNs",
% Kevin Murphy and Yair Weiss, submitted to NIPS 2000.
%
% The frontier algorithm makes the following topological assumptions:
%
% - All nodes are persistent (connect to the next slice)
% - No connections within a timeslice
% - There is a single observation variable, which depends on all the hidden nodes
% - Each node can have several parents in the previous time slice (generalizes a FHMM slightly)
%

% The forwards pass of the frontier algorithm can be explained with the following example.
% Suppose we have 3 hidden nodes per slice, A, B, C.
% The goal is to compute alpha(j, t) = Pr( (A_t,B_t,C_t)=j | Y(1:t))
% We move alpha from t to t+1 one node at a time, as follows.
% We define the following quantities:
% s([a1 b1 c1], 1) = Prob(A(t)=a1, B(t)=b1, C(t)=c1 | Y(1:t)) = alpha(j, t)
% s([a2 b1 c1], 2) = Prob(A(t+1)=a2, B(t)=b1, C(t)=c1 | Y(1:t))
% s([a2 b2 c1], 3) = Prob(A(t+1)=a2, B(t+1)=b2, C(t)=c1 | Y(1:t))
% s([a2 b2 c2], 4) = Prob(A(t+1)=a2, B(t+1)=b2, C(t+1)=c2 | Y(1:t))
% s([a2 b2 c2], 5) = Prob(A(t+1)=a2, B(t+1)=b2, C(t+1)=c2 | Y(1:t+1)) = alpha(j, t+1)
%
% These can be computed recursively as follows:
%
% s([a2 b1 c1], 2) = sum_{a1} P(a2|a1) s([a1 b1 c1], 1)
% s([a2 b2 c1], 3) = sum_{b1} P(b2|b1) s([a2 b1 c1], 2)
% s([a2 b2 c2], 4) = sum_{c1} P(c2|c1) s([a2 b2 c1], 3)
% s([a2 b2 c2], 5) = normalise( s([a2 b2 c2], 4) .* P(Y(t+1)|a2,b2,c2) )


% Index tables depend only on the topology; each is max(node_sizes) x S x M.
[kk,ll,mm] = make_frontier_indices(inter, node_sizes); % can pass in as args

scaled = 1;

M = length(node_sizes);
S = prod(node_sizes);
T = size(obsmat, 2);

alpha = zeros(S, T);
beta = zeros(S, T);
gamma = zeros(S, T);
scale = zeros(1,T);
tiny = exp(-700); % substituted for a zero normaliser so that 1/s stays finite


alpha(:,1) = make_prior_from_CPTs(CPTs_slice1, node_sizes);
alpha(:,1) = alpha(:,1) .* obsmat(:, 1);

if scaled
  s = sum(alpha(:,1));
  if s==0, s = s + tiny; end
  scale(1) = 1/s;
else
  scale(1) = 1;
end
alpha(:,1) = alpha(:,1) * scale(1);

anew = zeros(S,1);
bnew = zeros(S,1);

% Forwards pass: advance the frontier one hidden node at a time.
for t=2:T
  aold = alpha(:,t-1);

  for i=1:M
    ns = node_sizes(i);
    cpt = CPTs{i};
    for j=1:S
      s = 0;
      for xx=1:ns
        % Index the tables by subscript rather than with a running linear
        % counter: the tables are padded to max(node_sizes) rows, so a
        % counter that advances by ns per state only lines up with the
        % table layout when every node has the same cardinality.
        k = kk(xx,j,i);
        l = ll(xx,j,i);
        s = s + aold(k) * cpt(l);
      end
      anew(j) = s;
    end
    aold = anew;
  end

  alpha(:,t) = anew .* obsmat(:, t);

  if scaled
    s = sum(alpha(:,t));
    if s==0, s = s + tiny; end
    scale(t) = 1/s;
  else
    scale(t) = 1;
  end
  alpha(:,t) = alpha(:,t) * scale(t);

end


% Backwards pass: same frontier sweep, summing over the child value (mm)
% instead of the parent value (ll).
beta(:,T) = ones(S,1) * scale(T);
for t=T-1:-1:1
  bold = beta(:,t+1) .* obsmat(:, t+1);

  for i=1:M
    ns = node_sizes(i);
    cpt = CPTs{i};
    for j=1:S
      s = 0;
      for xx=1:ns
        k = kk(xx,j,i);
        m = mm(xx,j,i);
        s = s + bold(k) * cpt(m);
      end
      bnew(j) = s;
    end
    bold = bnew;
  end
  beta(:,t) = bnew * scale(t);
end


if scaled
  loglik = -sum(log(scale)); % scale(i) is finite
else
  lik = alpha(:,1)' * beta(:,1);
  loglik = log(lik+tiny);
end

% Posterior over the joint hidden state at each time step.
for t=1:T
  gamma(:,t) = normalise(alpha(:,t) .* beta(:,t));
end
167
168 %%%%%%%%%%%
169
function [kk,ll,mm] = make_frontier_indices(inter, node_sizes)
%
% Precompute indices for use in the frontier algorithm.
% These only depend on the topology, not the parameters or data.
% Hence we can compute them outside of fhmm_infer.
% This saves a lot of run-time computation.
%
% Inputs:
% inter - inter-slice adjacency matrix; inter(p,i)==1 marks p as a parent of node i
% node_sizes - vector with the cardinality of each hidden node
%
% Outputs (each of size max(node_sizes) x S x M, with S = prod(node_sizes)
% and M = length(node_sizes)):
% kk(xx,j,i) - joint-state index obtained from state j by setting node i to xx
% ll(xx,j,i) - linear index into node i's CPT with the parents taken from that
%              modified state and the child fixed at node i's value in j
%              (used when summing over the parent)
% mm(xx,j,i) - linear index into node i's CPT with the parents taken from
%              state j and the child set to xx (used when summing over the child)
% Entries with xx > node_sizes(i) are never written and stay 0, so callers
% must address these tables by subscript (xx,j,i) with xx <= node_sizes(i).

M = length(node_sizes);
S = prod(node_sizes);

mns = max(node_sizes);
kk = zeros(mns, S, M);
ll = zeros(mns, S, M);
mm = zeros(mns, S, M);

for i=1:M
  % The parent set and the CPT dimensions depend only on the node, so
  % compute them once per node rather than once per (state, value) pair.
  ps = find(inter(:,i)==1);
  ps = ps(:)';
  fam_sizes = node_sizes([ps i]);
  for j=1:S
    u = ind2subv(node_sizes, j);
    x = u(i);
    for xx=1:node_sizes(i)
      uu = u;
      uu(i) = xx;
      kk(xx,j,i) = subv2ind(node_sizes, uu);
      ll(xx,j,i) = subv2ind(fam_sizes, [uu(ps) x]); % sum over parent
      mm(xx,j,i) = subv2ind(fam_sizes, [u(ps) xx]); % sum over child
    end
  end
end
203
204 %%%%%%%%%
205
function prior=make_prior_from_CPTs(indiv_priors, node_sizes)
%
% prior = make_prior_from_CPTs(indiv_priors, node_sizes)
% Build the prior over the joint (collapsed) state of slice 1 as the
% product of the per-node priors.
% indiv_priors{k}(v) = Pr(X_k = v), where X_k is the k'th node in slice 1.
% prior(i) = Pr(slice1 = i), with i decoded into per-node values by ind2subv.

num_nodes = length(indiv_priors);
num_states = prod(node_sizes);

% Start every entry at 1 and multiply in each node's factor in turn.
prior = ones(num_states,1);
for state=1:num_states
  subs = ind2subv(node_sizes, state);
  for node=1:num_nodes
    prior(state) = prior(state) * indiv_priors{node}(subs(node));
  end
end
225
226
227
228 %%%%%%%%%%%
229
function [loglik, alpha, beta] = FHMM_slow(inter, CPTs_slice1, CPTs, obsmat, node_sizes, data)
%
% Same as the above, except we don't use the optimization of computing the indices outside the loop.
%
% Inputs:
% inter - inter-slice adjacency matrix; inter(p,i)==1 marks p as a parent of node i
% CPTs_slice1 - per-node priors for slice 1 (passed to make_prior_from_CPTs)
% CPTs{i} - transition CPT of node i, indexed by (parent values..., child value)
% obsmat - observation likelihoods; here the column is selected by the
%          observed value, obsmat(:, data(t))
% node_sizes - vector with the cardinality of each hidden node
% data - observation sequence; T is taken as its number of columns.
%        NOTE(review): data(t) is indexed linearly, so this appears to
%        assume a single sequence (one row) - confirm against callers.
%
% Outputs:
% loglik - -sum(log(scale)) in the scaled mode used here
% alpha, beta - S x T scaled forwards and backwards messages

scaled = 1;

M = length(node_sizes);
S = prod(node_sizes);
T = size(data, 2);

obs = data;

alpha = zeros(S, T);
beta = zeros(S, T);
a = zeros(S, M+1);
b = zeros(S, M+1);
scale = zeros(1,T);
% Local functions do not see the caller's workspace, so the guard value
% used below must be defined here.
tiny = exp(-700);

alpha(:,1) = make_prior_from_CPTs(CPTs_slice1, node_sizes);
alpha(:,1) = alpha(:,1) .* obsmat(:, obs(1));
if scaled
  s = sum(alpha(:,1));
  if s==0, s = s + tiny; end
  scale(1) = 1/s;
else
  scale(1) = 1;
end
alpha(:,1) = alpha(:,1) * scale(1);

% Forwards pass: column i+1 of a holds the frontier after node i has been
% advanced to the next slice.
for t=2:T
  fprintf(1, 't %d\n', t);
  a(:,1) = alpha(:,t-1);
  for i=1:M
    % The parent set depends only on the node.
    ps = find(inter(:,i)==1);
    ps = ps(:)';
    for j=1:S
      u = ind2subv(node_sizes, j);
      xnew = u(i);
      s = 0;
      for xold=1:node_sizes(i)
        uold = u;
        uold(i) = xold;
        k = subv2ind(node_sizes, uold);
        l = subv2ind(node_sizes([ps i]), [uold(ps) xnew]);
        s = s + a(k,i) * CPTs{i}(l);
      end
      a(j,i+1) = s;
    end
  end
  alpha(:,t) = a(:,M+1) .* obsmat(:, obs(t));

  if scaled
    s = sum(alpha(:,t));
    if s==0, s = s + tiny; end
    scale(t) = 1/s;
  else
    scale(t) = 1;
  end
  alpha(:,t) = alpha(:,t) * scale(t);

end


% Backwards pass: same sweep, summing over the child value.
beta(:,T) = ones(S,1) * scale(T);
for t=T-1:-1:1
  fprintf(1, 't %d\n', t);
  b(:,1) = beta(:,t+1) .* obsmat(:, obs(t+1));
  for i=1:M
    ps = find(inter(:,i)==1);
    ps = ps(:)';
    for j=1:S
      u = ind2subv(node_sizes, j);
      s = 0;
      for xnew=1:node_sizes(i)
        unew = u;
        unew(i) = xnew;
        k = subv2ind(node_sizes, unew);
        l = subv2ind(node_sizes([ps i]), [u(ps) xnew]);
        s = s + b(k,i) * CPTs{i}(l);
      end
      b(j,i+1) = s;
    end
  end
  beta(:,t) = b(:,M+1) * scale(t);
end


if scaled
  loglik = -sum(log(scale)); % scale(i) is finite
else
  lik = alpha(:,1)' * beta(:,1);
  loglik = log(lik+tiny);
end