wolffd@0
|
function [loglik, gamma] = fhmm_infer(inter, CPTs_slice1, CPTs, obsmat, node_sizes)
% FHMM_INFER Exact inference for a factorial HMM.
% [loglik, gamma] = fhmm_infer(inter, CPTs_slice1, CPTs, obsmat, node_sizes)
%
% Inputs:
% inter - the inter-slice adjacency matrix: inter(p,s)==1 means hidden node p
%   in slice t-1 is a parent of hidden node s in slice t
% CPTs_slice1{s}(j) = Pr(Q(s,1) = j) where Q(s,t) = hidden node s in slice t
% CPTs{s}(i1, i2, ..., j) = Pr(Q(s,t) = j | Pa(s,t-1) = i1, i2, ...),
%   where the parents are ordered by increasing node number and the child is last
% obsmat(i,t) = Pr(y(t) | Q(t)=i), one column per time step (T = size(obsmat,2))
% node_sizes is a vector with the cardinality of the hidden nodes
%
% Outputs:
% loglik = log Pr(y(1:T)), computed from the per-step normalisers
% gamma(i,t) = Pr(Q(t)=i | y(1:T)) as in an HMM,
% except that i is interpreted as an M digit, base-K number (if there are M chains each of cardinality K);
% the joint index is encoded/decoded with subv2ind/ind2subv.
%
% For M chains each of cardinality K, the frontiers (i.e., cliques)
% contain M+1 nodes, and it takes M steps to advance the frontier by one time step,
% so the run time is O(T M K^(M+1)).
% An HMM takes O(T S^2) where S is the size of the state space.
% Collapsing the FHMM to an HMM results in S = K^M.
% For details, see
% "The Factored Frontier Algorithm for Approximate Inference in DBNs",
% Kevin Murphy and Yair Weiss, submitted to NIPS 2000.
%
% The frontier algorithm makes the following topological assumptions:
%
% - All nodes are persistent (connect to the next slice)
% - No connections within a timeslice
% - There is a single observation variable, which depends on all the hidden nodes
% - Each node can have several parents in the previous time slice (generalizes a FHMM slightly)
%
% NOTE(review): the sweep advances nodes in the order 1..M and drops the old
% value of node i as soon as it is summed out. It therefore appears to be exact
% only when every parent p of node i satisfies p >= i (inter upper triangular).
% Confirm this before using cross-chain parents.

% The forwards pass of the frontier algorithm can be explained with the following example.
% Suppose we have 3 hidden nodes per slice, A, B, C.
% The goal is to compute alpha(j, t) = Pr( (A_t,B_t,C_t)=j | Y(1:t))
% We move alpha from t to t+1 one node at a time, as follows.
% We define the following quantities:
% s([a1 b1 c1], 1) = Prob(A(t)=a1, B(t)=b1, C(t)=c1 | Y(1:t)) = alpha(j, t)
% s([a2 b1 c1], 2) = Prob(A(t+1)=a2, B(t)=b1, C(t)=c1 | Y(1:t))
% s([a2 b2 c1], 3) = Prob(A(t+1)=a2, B(t+1)=b2, C(t)=c1 | Y(1:t))
% s([a2 b2 c2], 4) = Prob(A(t+1)=a2, B(t+1)=b2, C(t+1)=c2 | Y(1:t))
% s([a2 b2 c2], 5) = Prob(A(t+1)=a2, B(t+1)=b2, C(t+1)=c2 | Y(1:t+1)) = alpha(j, t+1)
%
% These can be computed recursively as follows:
%
% s([a2 b1 c1], 2) = sum_{a1} P(a2|a1) s([a1 b1 c1], 1)
% s([a2 b2 c1], 3) = sum_{b1} P(b2|b1) s([a2 b1 c1], 2)
% s([a2 b2 c2], 4) = sum_{c1} P(c2|c1) s([a2 b2 c1], 3)
% s([a2 b2 c2], 5) = normalise( s([a2 b2 c2], 4) .* P(Y(t+1)|a2,b2,c2) )


[kk,ll,mm] = make_frontier_indices(inter, node_sizes); % topology only; can pass in as args

scaled = 1;

M = length(node_sizes);
S = prod(node_sizes);
T = size(obsmat, 2);

alpha = zeros(S, T);
beta = zeros(S, T);
gamma = zeros(S, T);
scale = zeros(1,T);
tiny = exp(-700);

% Forwards pass, t = 1: prior over the joint state times the evidence.
alpha(:,1) = make_prior_from_CPTs(CPTs_slice1, node_sizes);
alpha(:,1) = alpha(:,1) .* obsmat(:, 1);

if scaled
  s = sum(alpha(:,1));
  if s==0, s = s + tiny; end % avoid dividing by zero
  scale(1) = 1/s;
else
  scale(1) = 1;
end
alpha(:,1) = alpha(:,1) * scale(1);

anew = zeros(S,1);
bnew = zeros(S,1);

for t=2:T
  % Advance the frontier from slice t-1 to slice t one hidden node at a time.
  aold = alpha(:,t-1);
  for i=1:M
    ns = node_sizes(i);
    cpt = CPTs{i};
    for j=1:S
      s = 0;
      for xx=1:ns
        % Use all three subscripts: kk/ll are padded to max(node_sizes) rows,
        % so a running linear counter would drift onto the zero padding
        % whenever the node sizes differ.
        s = s + aold(kk(xx,j,i)) * cpt(ll(xx,j,i)); % sum over the old value of node i
      end
      anew(j) = s;
    end
    aold = anew;
  end

  alpha(:,t) = anew .* obsmat(:, t);

  if scaled
    s = sum(alpha(:,t));
    if s==0, s = s + tiny; end
    scale(t) = 1/s;
  else
    scale(t) = 1;
  end
  alpha(:,t) = alpha(:,t) * scale(t);

end


% Backwards pass, using the same per-step scale factors as the forwards pass.
beta(:,T) = ones(S,1) * scale(T);
for t=T-1:-1:1
  bold = beta(:,t+1) .* obsmat(:, t+1);
  for i=1:M
    ns = node_sizes(i);
    cpt = CPTs{i};
    for j=1:S
      s = 0;
      for xx=1:ns
        s = s + bold(kk(xx,j,i)) * cpt(mm(xx,j,i)); % sum over the new value of node i
      end
      bnew(j) = s;
    end
    bold = bnew;
  end
  beta(:,t) = bnew * scale(t);
end


if scaled
  loglik = -sum(log(scale)); % scale(i) is finite
else
  lik = alpha(:,1)' * beta(:,1);
  loglik = log(lik+tiny);
end

for t=1:T
  gamma(:,t) = normalise(alpha(:,t) .* beta(:,t));
end
|
wolffd@0
|
167
|
wolffd@0
|
168 %%%%%%%%%%%
|
wolffd@0
|
169
|
wolffd@0
|
function [kk,ll,mm] = make_frontier_indices(inter, node_sizes)
%
% Precompute indices for use in the frontier algorithm.
% These only depend on the topology, not the parameters or data.
% Hence we can compute them outside of fhmm_infer's inner loops.
% This saves a lot of run-time computation.
%
% For hidden node i, joint state j (u = ind2subv(node_sizes, j)) and value xx:
% kk(xx,j,i) = joint index of u with u(i) replaced by xx
% ll(xx,j,i) = linear index into CPTs{i} with the parent values taken from that
%              modified state and the child value u(i) (forwards: sum over parent)
% mm(xx,j,i) = linear index into CPTs{i} with the parent values taken from u and
%              the child value xx (backwards: sum over child)
%
% The arrays are max(node_sizes) x S x M. Entries with xx > node_sizes(i) stay
% zero, so callers must index them as (xx,j,i) rather than with a running
% linear counter when the node sizes differ.

M = length(node_sizes);
S = prod(node_sizes);

mns = max(node_sizes);
kk = zeros(mns, S, M);
ll = zeros(mns, S, M);
mm = zeros(mns, S, M);

for i=1:M
  % The parents of node i in the previous slice depend only on i, so look
  % them up once per node instead of once per (j, xx).
  % CPTs{i} has dimensions node_sizes([ps i]): parents first, child last.
  ps = find(inter(:,i)==1);
  ps = ps(:)';
  cpt_sizes = node_sizes([ps i]);
  for j=1:S
    u = ind2subv(node_sizes, j);
    x = u(i);
    for xx=1:node_sizes(i)
      uu = u;
      uu(i) = xx;
      kk(xx,j,i) = subv2ind(node_sizes, uu);
      ll(xx,j,i) = subv2ind(cpt_sizes, [uu(ps) x]); % sum over parent
      mm(xx,j,i) = subv2ind(cpt_sizes, [u(ps) xx]); % sum over child
    end
  end
end
|
wolffd@0
|
203
|
wolffd@0
|
204 %%%%%%%%%
|
wolffd@0
|
205
|
wolffd@0
|
function prior=make_prior_from_CPTs(indiv_priors, node_sizes)
%
% prior = make_prior_from_CPTs(indiv_priors, node_sizes)
% Build the prior over the collapsed (joint) hidden state of slice 1 as the
% product of the per-node priors of the equivalent DBN.
% indiv_priors{k}(v) = Pr(X_k = v), where X_k is the k'th hidden node in slice 1.
% prior(i) = Pr(slice1 = i) = prod_k indiv_priors{k}(sub(k)),
%   where sub = ind2subv(node_sizes, i).
% prior is a prod(node_sizes) x 1 column vector.

num_nodes = length(indiv_priors);
num_states = prod(node_sizes);
prior = zeros(num_states,1);

for state = 1:num_states
  sub = ind2subv(node_sizes, state);   % per-node values of this joint state
  prob = 1;
  k = 1;
  while k <= num_nodes
    prob = prob * indiv_priors{k}(sub(k));
    k = k + 1;
  end
  prior(state) = prob;
end
|
wolffd@0
|
225
|
wolffd@0
|
226
|
wolffd@0
|
227
|
wolffd@0
|
228 %%%%%%%%%%%
|
wolffd@0
|
229
|
wolffd@0
|
function [loglik, alpha, beta] = FHMM_slow(inter, CPTs_slice1, CPTs, obsmat, node_sizes, data)
%
% Same as fhmm_infer, except we don't use the optimization of computing the
% frontier indices outside the loop. Kept as a slow reference implementation.
% Unlike fhmm_infer, obsmat is indexed by observation symbol here:
% the evidence at time t is obsmat(:, obs(t)), where obs = data.
% NOTE(review): obs(t) uses linear indexing into data, so this presumably
% expects data to be a 1 x T row of symbols; confirm before passing several rows.
% Prints the current time step to stdout on every iteration of both passes.


scaled = 1;
% Must be defined locally: subfunctions do not share fhmm_infer's workspace,
% and the zero-normaliser guard and the unscaled loglik path both use it.
tiny = exp(-700);

M = length(node_sizes);
S = prod(node_sizes);
[numex T] = size(data);

obs = data;

alpha = zeros(S, T);
beta = zeros(S, T);
a = zeros(S, M+1);   % a(:,i) = frontier before advancing node i
b = zeros(S, M+1);
scale = zeros(1,T);

alpha(:,1) = make_prior_from_CPTs(CPTs_slice1, node_sizes);
alpha(:,1) = alpha(:,1) .* obsmat(:, obs(1));
if scaled
  s = sum(alpha(:,1));
  if s==0, s = s + tiny; end
  scale(1) = 1/s;
else
  scale(1) = 1;
end
alpha(:,1) = alpha(:,1) * scale(1);

for t=2:T
  fprintf(1, 't %d\n', t);
  a(:,1) = alpha(:,t-1);
  for i=1:M
    % Parents of node i depend only on i; CPTs{i} is indexed by [ps i].
    ps = find(inter(:,i)==1);
    ps = ps(:)';
    for j=1:S
      u = ind2subv(node_sizes, j);
      xnew = u(i);
      s = 0;
      for xold=1:node_sizes(i)
        uold = u;
        uold(i) = xold;
        k = subv2ind(node_sizes, uold);
        l = subv2ind(node_sizes([ps i]), [uold(ps) xnew]);
        s = s + a(k,i) * CPTs{i}(l);
      end
      a(j,i+1) = s;
    end
  end
  alpha(:,t) = a(:,M+1) .* obsmat(:, obs(t));

  if scaled
    s = sum(alpha(:,t));
    if s==0, s = s + tiny; end
    scale(t) = 1/s;
  else
    scale(t) = 1;
  end
  alpha(:,t) = alpha(:,t) * scale(t);

end


beta(:,T) = ones(S,1) * scale(T);
for t=T-1:-1:1
  fprintf(1, 't %d\n', t);
  b(:,1) = beta(:,t+1) .* obsmat(:, obs(t+1));
  for i=1:M
    ps = find(inter(:,i)==1);
    ps = ps(:)';
    for j=1:S
      u = ind2subv(node_sizes, j);
      xold = u(i);
      s = 0;
      for xnew=1:node_sizes(i)
        unew = u;
        unew(i) = xnew;
        k = subv2ind(node_sizes, unew);
        l = subv2ind(node_sizes([ps i]), [u(ps) xnew]);
        s = s + b(k,i) * CPTs{i}(l);
      end
      b(j,i+1) = s;
    end
  end
  beta(:,t) = b(:,M+1) * scale(t);
end


if scaled
  loglik = -sum(log(scale)); % scale(i) is finite
else
  lik = alpha(:,1)' * beta(:,1);
  loglik = log(lik+tiny);
end
|