function [alpha, beta, gamma, loglik, xi_summed, gamma2] = fwdback(init_state_distrib, ...
   transmat, obslik, varargin)
% FWDBACK Compute the posterior probs. in an HMM using the forwards backwards algo.
%
% [alpha, beta, gamma, loglik, xi_summed, gamma2] = fwdback(init_state_distrib, transmat, obslik, ...)
%
% Notation:
% Y(t) = observation, Q(t) = hidden state, M(t) = mixture variable (for MOG outputs)
% A(t) = discrete input (action) (for POMDP models)
%
% INPUT:
% init_state_distrib(i) = Pr(Q(1) = i)
% transmat(i,j) = Pr(Q(t) = j | Q(t-1)=i)
%  or transmat{a}(i,j) = Pr(Q(t) = j | Q(t-1)=i, A(t-1)=a) if there are discrete inputs
% obslik(i,t) = Pr(Y(t)| Q(t)=i)
%   (Compute obslik using eval_pdf_xxx on your data sequence first.)
%
% Optional parameters may be passed as 'param_name', param_value pairs.
% Parameter names are shown below; default values in [] - if none, argument is mandatory.
%
% For HMMs with MOG outputs: if you want to compute gamma2, you must specify
% 'obslik2' - obslik2(i,j,t) = Pr(Y(t)| Q(t)=i,M(t)=j)  []
% 'mixmat' - mixmat(i,j) = Pr(M(t) = j | Q(t)=i)  []
%
% For HMMs with discrete inputs:
% 'act' - act(t) = action performed at step t; transmat must then be a cell
%         array indexed by action. If omitted, the single transmat is used
%         for every step.  []
%
% Optional arguments:
% 'fwd_only' - if 1, only do a forwards pass and set beta=[], gamma2=[]  [0]
% 'scaled' - if 1, normalize alphas and betas to prevent underflow  [1]
% 'maximize' - if 1, use max-product instead of sum-product  [0]
% 'compute_xi' - if 1, accumulate xi_summed  [1 if nargout >= 5, else 0]
% 'compute_gamma2' - if 1, compute gamma2 (ignored if obslik2 is empty)
%                    [1 if nargout >= 6, else 0]
%
% OUTPUTS:
% alpha(i,t) = p(Q(t)=i | y(1:t))   (or p(Q(t)=i, y(1:t)) if scaled=0)
% beta(i,t)  = p(y(t+1:T) | Q(t)=i); if scaled=1 each column beta(:,t) is
%              rescaled by normalise (the constant cancels in gamma)
% gamma(i,t) = p(Q(t)=i | y(1:T))
% loglik = log p(y(1:T))   (-inf if, with scaled=1, some step has zero likelihood)
% xi_summed(i,j) = sum_{t=1}^{T-1} xi(i,j,t), where
%      xi(i,j,t) = p(Q(t)=i, Q(t+1)=j | y(1:T)).
%      The per-step xi array is NO LONGER returned (change made by Herbert Jaeger).
% gamma2(j,k,t) = p(Q(t)=j, M(t)=k | y(1:T))  (only for MOG outputs)
%
% If fwd_only = 1, these become
% alpha(i,t) = p(Q(t)=i | y(1:t))   (when scaled=1)
% beta = []
% gamma(i,t) = alpha(i,t)
% xi_summed(i,j) = sum_{t=2}^{T} p(Q(t-1)=i, Q(t)=j | y(1:t))
% gamma2 = []
%
% Note: by default xi_summed is only computed if it is requested as a return
% argument; similarly gamma2 is only computed on request (and if using MOG outputs).
%
% Examples:
%
% [alpha, beta, gamma, loglik] = fwdback(pi, A, multinomial_prob(sequence, B));
%
% [B, B2] = mixgauss_prob(data, mu, Sigma, mixmat);
% [alpha, beta, gamma, loglik, xi_summed, gamma2] = fwdback(pi, A, B, 'obslik2', B2, 'mixmat', mixmat);

if nargout >= 5, compute_xi = 1; else compute_xi = 0; end
if nargout >= 6, compute_gamma2 = 1; else compute_gamma2 = 0; end

[obslik2, mixmat, fwd_only, scaled, act, maximize, compute_xi, compute_gamma2] = ...
   process_options(varargin, ...
       'obslik2', [], 'mixmat', [], ...
       'fwd_only', 0, 'scaled', 1, 'act', [], 'maximize', 0, ...
       'compute_xi', compute_xi, 'compute_gamma2', compute_gamma2);

[Q, T] = size(obslik);

% gamma2 needs the per-mixture-component likelihoods.
if isempty(obslik2)
  compute_gamma2 = 0;
end

% Without discrete inputs, wrap the single transition matrix so that the
% loops below can always index transmat{act(t)}.
if isempty(act)
  act = ones(1,T);
  transmat = { transmat } ;
end

scale = ones(1,T);

% scale(t) = Pr(O(t) | O(1:t-1)) = 1/c(t) as defined by Rabiner (1989).
% Hence prod_t scale(t) = Pr(O(1)) Pr(O(2)|O(1)) Pr(O(3) | O(1:2)) ... = Pr(O(1), ... ,O(T))
% or log P = sum_t log scale(t).
% Rabiner suggests multiplying beta(t) by scale(t), but we can instead
% normalise beta(t) - the constants will cancel when we compute gamma.

loglik = 0;

alpha = zeros(Q,T);
gamma = zeros(Q,T);
if compute_xi
  xi_summed = zeros(Q,Q);
else
  xi_summed = [];
end

%%%%%%%%% Forwards %%%%%%%%%%

t = 1;
alpha(:,1) = init_state_distrib(:) .* obslik(:,t);
if scaled
  [alpha(:,t), scale(t)] = normalise(alpha(:,t));
  % Sanity check only. It must be skipped when scaled=0 (alphas are then
  % joint probabilities and do not sum to 1) and when scale(t)==0 (an
  % impossible observation, reported below as loglik = -inf).
  if scale(t) > 0
    assert(approxeq(sum(alpha(:,t)),1))
  end
end
for t=2:T
  trans = transmat{act(t-1)};
  if maximize
    % Per the original commented-out code, this is equivalent to
    % max(trans .* repmat(alpha(:,t-1), [1 Q]), [], 1), i.e. a max over
    % predecessor states instead of a sum.
    m = max_mult(trans', alpha(:,t-1));
  else
    m = trans' * alpha(:,t-1);
  end
  alpha(:,t) = m(:) .* obslik(:,t);
  if scaled
    [alpha(:,t), scale(t)] = normalise(alpha(:,t));
    if scale(t) > 0
      assert(approxeq(sum(alpha(:,t)),1))
    end
  end
  if compute_xi && fwd_only % useful for online EM
    % Filtered two-slice posterior p(Q(t-1)=i, Q(t)=j | y(1:t)), summed over t.
    xi_summed = xi_summed + normalise((alpha(:,t-1) * obslik(:,t)') .* trans);
  end
end
if scaled
  if any(scale==0)
    loglik = -inf;
  else
    loglik = sum(log(scale));
  end
else
  loglik = log(sum(alpha(:,T)));
end

if fwd_only
  gamma = alpha;
  beta = [];
  gamma2 = [];
  return;
end

%%%%%%%%% Backwards %%%%%%%%%%

beta = zeros(Q,T);
if compute_gamma2
  M = size(mixmat, 2);
  gamma2 = zeros(Q,M,T);
else
  gamma2 = [];
end

beta(:,T) = ones(Q,1);
gamma(:,T) = normalise(alpha(:,T) .* beta(:,T));
t=T;
if compute_gamma2
  denom = obslik(:,t) + (obslik(:,t)==0); % replace 0s with 1s before dividing
  % Note: normalising obslik2 .* mixmat .* gamma instead of dividing by
  % obslik would be wrong (see the derivation at the end of this function).
  gamma2(:,:,t) = obslik2(:,:,t) .* mixmat .* repmat(gamma(:,t), [1 M]) ./ repmat(denom, [1 M]);
end
for t=T-1:-1:1
  % b(j) = p(y(t+1:T) | Q(t+1)=j), up to the scaling constant.
  b = beta(:,t+1) .* obslik(:,t+1);
  trans = transmat{act(t)};
  if maximize
    B = repmat(b(:)', Q, 1);
    beta(:,t) = max(trans .* B, [], 2);
  else
    beta(:,t) = trans * b;
  end
  if scaled
    beta(:,t) = normalise(beta(:,t));
  end
  gamma(:,t) = normalise(alpha(:,t) .* beta(:,t));
  if compute_xi
    % Smoothed two-slice posterior p(Q(t)=i, Q(t+1)=j | y(1:T)), summed over t.
    xi_summed = xi_summed + normalise((trans .* (alpha(:,t) * b')));
  end
  if compute_gamma2
    denom = obslik(:,t) + (obslik(:,t)==0); % replace 0s with 1s before dividing
    gamma2(:,:,t) = obslik2(:,:,t) .* mixmat .* repmat(gamma(:,t), [1 M]) ./ repmat(denom, [1 M]);
  end
end

% We now explain the equation for gamma2.
% Let zt=y(1:t-1,t+1:T) be all observations except y(t).
% gamma2(Q,M,t) = P(Qt,Mt|yt,zt) = P(yt|Qt,Mt,zt) P(Qt,Mt|zt) / P(yt|zt)
%                = P(yt|Qt,Mt) P(Mt|Qt) P(Qt|zt) / P(yt|zt)
% Now gamma(Q,t) = P(Qt|yt,zt) = P(yt|Qt) P(Qt|zt) / P(yt|zt)
% hence
% P(Qt,Mt|yt,zt) = P(yt|Qt,Mt) P(Mt|Qt) [P(Qt|yt,zt) P(yt|zt) / P(yt|Qt)] / P(yt|zt)
%                = P(yt|Qt,Mt) P(Mt|Qt) P(Qt|yt,zt) / P(yt|Qt)
% which is obslik2 .* mixmat .* gamma ./ obslik, as computed above.