diff toolboxes/FullBNT-1.0.7/HMM/dhmm_em.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/toolboxes/FullBNT-1.0.7/HMM/dhmm_em.m	Tue Feb 10 15:05:51 2015 +0000
@@ -0,0 +1,124 @@
+function [LL, prior, transmat, obsmat, nrIterations] = ...
+   dhmm_em(data, prior, transmat, obsmat, varargin)
+% DHMM_EM Fit the ML/MAP parameters of an HMM with discrete outputs using EM.
+% [LL, prior, transmat, obsmat, nrIterations] = dhmm_em(data, prior0, transmat0, obsmat0, ...)
+%
+% Notation: Q(t) = hidden state, Y(t) = observation
+%
+% INPUTS:
+% data{ex}, or data(ex,:) when all sequences have the same length
+% prior(i)      - initial guess for the start distribution
+% transmat(i,j) - initial guess for the transition matrix
+% obsmat(i,o)   - initial guess for the observation matrix
+%
+% OUTPUTS:
+% LL - log-likelihood trace, one entry per iteration; each entry comes from
+%      the E step, i.e. it scores the parameters held before that M step
+% prior, transmat, obsmat - parameters after the final M step
+% nrIterations - number of EM iterations that were carried out
+%
+% Optional arguments are 'param_name', param_value pairs (defaults in []):
+% 'max_iter' - max number of EM iterations [10]
+% 'thresh' - convergence threshold passed to em_converged [1e-4]
+% 'verbose' - if 1, print out loglik at every iteration [1]
+% 'obs_prior_weight' - weight of the uniform dirichlet prior on obsmat [0]
+% 'adj_prior' - if 0, clamp prior so learning does not change it [1]
+% 'adj_trans' - if 0, clamp transmat [1]
+% 'adj_obs' - if 0, clamp obsmat [1]
+%
+% Modified by Herbert Jaeger so the xi are not computed individually; only
+% their sum over time (xi_summed) is kept, which is all that is used here
+% and saves a lot of memory.
+
+[max_iter, thresh, verbose, obs_prior_weight, adj_prior, adj_trans, adj_obs] = ...
+   process_options(varargin, ...
+                   'max_iter', 10, 'thresh', 1e-4, 'verbose', 1, 'obs_prior_weight', 0, ...
+                   'adj_prior', 1, 'adj_trans', 1, 'adj_obs', 1);
+
+% A matrix holds one sequence per row: give each row its own cell.
+if ~iscell(data), data = num2cell(data, 2); end
+
+LL = [];
+prev_ll = -inf;
+done = 0;
+iter = 1;
+while (iter <= max_iter) & ~done
+ % E step: expected sufficient statistics under the current parameters
+ [loglik, ess_trans, ess_start, ess_emit] = ...
+     compute_ess_dhmm(prior, transmat, obsmat, data, obs_prior_weight);
+
+ % M step: renormalise the expected counts, skipping clamped parameters
+ if adj_prior, prior = normalise(ess_start); end
+ if adj_trans & ~isempty(ess_trans), transmat = mk_stochastic(ess_trans); end
+ if adj_obs, obsmat = mk_stochastic(ess_emit); end
+
+ if verbose
+   fprintf(1, 'iteration %d, loglik = %f\n', iter, loglik);
+ end
+ % extend the log-likelihood trace with this iteration's value
+ LL = [LL loglik];
+ % compare against the previous iteration's loglik before overwriting it
+ done = em_converged(loglik, prev_ll, thresh);
+ prev_ll = loglik;
+ iter = iter + 1;
+end
+
+% iter has been advanced one past the last completed iteration
+nrIterations = iter - 1;
+
+%%%%%%%%%%%%%%%%%%%%%%%
+
+function [loglik, exp_num_trans, exp_num_visits1, exp_num_emit, exp_num_visitsT] = ...
+   compute_ess_dhmm(startprob, transmat, obsmat, data, dirichlet)
+% COMPUTE_ESS_DHMM Expected sufficient statistics for an HMM with discrete outputs
+% [loglik, exp_num_trans, exp_num_visits1, exp_num_emit, exp_num_visitsT] = ...
+%    compute_ess_dhmm(startprob, transmat, obsmat, data, dirichlet)
+%
+% INPUTS:
+% startprob(i)  - start distribution handed to fwdback
+% transmat(i,j) - transition matrix handed to fwdback
+% obsmat(i,o)   - observation matrix; its size fixes the S states and O symbols
+% data{seq}(t)  - symbol seen at time t; used as a column index, so it must be in 1..O
+% dirichlet - weighting term for uniform dirichlet prior on expected emissions
+%
+% OUTPUTS:
+% loglik = sum_l of the log-likelihood fwdback reports for sequence l
+% exp_num_trans(i,j) = sum_l sum_{t=2}^T Pr(X(t-1) = i, X(t) = j| Obs(l))
+% exp_num_visits1(i) = sum_l Pr(X(1)=i | Obs(l))
+% exp_num_visitsT(i) = sum_l Pr(X(T)=i | Obs(l))
+% exp_num_emit(i,o) = dirichlet + sum_l sum_{t=1}^T Pr(X(t) = i, O(t)=o| Obs(l))
+% where Obs(l) = O_1 .. O_T for sequence l.
+
+[nstates, nsymbols] = size(obsmat);
+loglik = 0;
+exp_num_trans = zeros(nstates, nstates);
+exp_num_visits1 = zeros(nstates, 1);
+exp_num_visitsT = zeros(nstates, 1);
+% every emission count starts at the dirichlet weight rather than at zero
+exp_num_emit = dirichlet*ones(nstates, nsymbols);
+
+for seq = 1:length(data)
+ obs = data{seq};
+ len = length(obs);
+ %obslik = eval_pdf_cond_multinomial(obs, obsmat);
+ obslik = multinomial_prob(obs, obsmat);
+ [alpha, beta, gamma, seq_ll, xi_summed] = fwdback(startprob, transmat, obslik);
+
+ loglik = loglik + seq_ll;
+ exp_num_trans = exp_num_trans + xi_summed;
+ exp_num_visits1 = exp_num_visits1 + gamma(:,1);
+ exp_num_visitsT = exp_num_visitsT + gamma(:,len);
+ % accumulate emissions by looping over whichever is shorter: symbols or time
+ if len >= nsymbols
+   for sym = 1:nsymbols
+     hits = find(obs==sym);
+     if isempty(hits), continue; end
+     exp_num_emit(:,sym) = exp_num_emit(:,sym) + sum(gamma(:, hits), 2);
+   end
+ else
+   for t = 1:len
+     sym = obs(t);
+     exp_num_emit(:,sym) = exp_num_emit(:,sym) + gamma(:,t);
+   end
+ end
+end