function [LL, prior, transmat, obsmat, nrIterations] = ...
   dhmm_em(data, prior, transmat, obsmat, varargin)
% DHMM_EM Find the ML/MAP parameters of an HMM with discrete outputs using EM.
% [LL, prior, transmat, obsmat, nrIterations] = dhmm_em(data, prior0, transmat0, obsmat0, ...)
%
% Notation: Q(t) = hidden state, Y(t) = observation
%
% INPUTS:
% data{ex} or data(ex,:) if all sequences have the same length.
%   Each sequence is a vector of observation symbols; the symbols are used as
%   column indices into obsmat, so they are expected to be integers in
%   1..size(obsmat,2). Zero-length sequences are ignored.
% prior(i)      - starting guess for the initial-state distribution
% transmat(i,j) - starting guess for the state transition matrix
% obsmat(i,o)   - starting guess for the observation (emission) matrix
%
% Optional parameters may be passed as 'param_name', param_value pairs.
% Parameter names are shown below; default values in [] - if none, argument is mandatory.
%
% 'max_iter' - max number of EM iterations [10]
% 'thresh' - convergence threshold [1e-4]
% 'verbose' - if 1, print out loglik at every iteration [1]
% 'obs_prior_weight' - weight to apply to uniform dirichlet prior on observation matrix [0]
%
% To clamp some of the parameters, so learning does not change them:
% 'adj_prior' - if 0, do not change prior [1]
% 'adj_trans' - if 0, do not change transmat [1]
% 'adj_obs' - if 0, do not change obsmat [1]
%
% OUTPUTS:
% LL(k)         - total log-likelihood of the data computed in the E step of
%                 iteration k, i.e. under the parameters *before* that
%                 iteration's M step
% prior, transmat, obsmat - the re-estimated parameters
% nrIterations  - number of EM iterations actually performed (<= max_iter)
%
% transmat is only re-estimated when at least one transition is observed
% (some sequence has length >= 2); otherwise the supplied transmat is kept.
%
% Modified by Herbert Jaeger so xi are not computed individually
% but only their sum (over time) as xi_summed; this is the only way how they are used
% and it saves a lot of memory.

[max_iter, thresh, verbose, obs_prior_weight, adj_prior, adj_trans, adj_obs] = ...
    process_options(varargin, 'max_iter', 10, 'thresh', 1e-4, 'verbose', 1, ...
                    'obs_prior_weight', 0, 'adj_prior', 1, 'adj_trans', 1, 'adj_obs', 1);

previous_loglik = -inf;
converged = 0;
num_iter = 1;
LL = [];

if ~iscell(data)
  data = num2cell(data, 2); % each row gets its own cell
end

while (num_iter <= max_iter) && ~converged
  % E step: expected sufficient statistics under the current parameters
  [loglik, exp_num_trans, exp_num_visits1, exp_num_emit] = ...
      compute_ess_dhmm(prior, transmat, obsmat, data, obs_prior_weight);

  % M step: re-estimate every parameter that is not clamped
  if adj_prior
    prior = normalise(exp_num_visits1);
  end
  % When no transition was observed at all (e.g. every sequence has length 1)
  % the expected transition counts are all zero; normalising them would
  % replace transmat with a matrix carrying no information, so keep it.
  if adj_trans && any(exp_num_trans(:))
    transmat = mk_stochastic(exp_num_trans);
  end
  if adj_obs
    obsmat = mk_stochastic(exp_num_emit);
  end

  if verbose, fprintf(1, 'iteration %d, loglik = %f\n', num_iter, loglik); end
  num_iter = num_iter + 1;
  converged = em_converged(loglik, previous_loglik, thresh);
  previous_loglik = loglik;
  LL = [LL loglik];
end
nrIterations = num_iter - 1;

%%%%%%%%%%%%%%%%%%%%%%%

function [loglik, exp_num_trans, exp_num_visits1, exp_num_emit, exp_num_visitsT] = ...
   compute_ess_dhmm(startprob, transmat, obsmat, data, dirichlet)
% COMPUTE_ESS_DHMM Compute the Expected Sufficient Statistics for an HMM with discrete outputs
% function [loglik, exp_num_trans, exp_num_visits1, exp_num_emit, exp_num_visitsT] = ...
%    compute_ess_dhmm(startprob, transmat, obsmat, data, dirichlet)
%
% INPUTS:
% startprob(i)
% transmat(i,j)
% obsmat(i,o)
% data{seq}(t)  - symbol indices in 1..size(obsmat,2); empty sequences are skipped
% dirichlet - weighting term for uniform dirichlet prior on expected emissions
%
% OUTPUTS:
% loglik = sum over sequences of the log-likelihood reported by fwdback
% exp_num_trans(i,j) = sum_l sum_{t=2}^T Pr(X(t-1) = i, X(t) = j| Obs(l))
% exp_num_visits1(i) = sum_l Pr(X(1)=i | Obs(l))
% exp_num_visitsT(i) = sum_l Pr(X(T)=i | Obs(l))
% exp_num_emit(i,o) = dirichlet + sum_l sum_{t=1}^T Pr(X(t) = i, O(t)=o| Obs(l))
% where Obs(l) = O_1 .. O_T for sequence l.

numex = length(data);
[S, O] = size(obsmat);
exp_num_trans = zeros(S,S);
exp_num_visits1 = zeros(S,1);
exp_num_visitsT = zeros(S,1);
exp_num_emit = dirichlet*ones(S,O); % uniform Dirichlet pseudo-counts
loglik = 0;

for ex=1:numex
  obs = data{ex};
  T = length(obs);
  if T == 0
    % A zero-length sequence carries no evidence, and gamma(:,T) below
    % would index column 0; skip it.
    continue;
  end
  obslik = multinomial_prob(obs, obsmat);
  [alpha, beta, gamma, current_ll, xi_summed] = fwdback(startprob, transmat, obslik);

  loglik = loglik + current_ll;
  exp_num_trans = exp_num_trans + xi_summed;
  exp_num_visits1 = exp_num_visits1 + gamma(:,1);
  exp_num_visitsT = exp_num_visitsT + gamma(:,T);
  % Accumulate expected emission counts, looping over whichever is shorter:
  % the time steps of this sequence or the symbols of the alphabet.
  if T < O
    for t=1:T
      o = obs(t);
      exp_num_emit(:,o) = exp_num_emit(:,o) + gamma(:,t);
    end
  else
    for o=1:O
      ndx = find(obs==o);
      if ~isempty(ndx)
        exp_num_emit(:,o) = exp_num_emit(:,o) + sum(gamma(:, ndx), 2);
      end
    end
  end
end