annotate toolboxes/FullBNT-1.0.7/HMM/dhmm_em.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
rev   line source
function [LL, prior, transmat, obsmat, nrIterations] = ...
    dhmm_em(data, prior, transmat, obsmat, varargin)
% DHMM_EM Find the ML/MAP parameters of an HMM with discrete outputs using EM.
% [LL, prior, transmat, obsmat, nrIterations] = dhmm_em(data, prior0, transmat0, obsmat0, ...)
%
% Notation: Q(t) = hidden state, Y(t) = observation
%
% INPUTS:
% data{ex} or data(ex,:) if all sequences have the same length.
%   A numeric matrix is split into one sequence per row. Observation
%   symbols are used as column indices into obsmat, so they are presumably
%   integers in 1..size(obsmat,2).
% prior(i)       - initial state distribution
% transmat(i,j)  - state transition matrix
% obsmat(i,o)    - emission matrix, P(Y(t)=o | Q(t)=i)
%
% Optional parameters may be passed as 'param_name', param_value pairs.
% Parameter names are shown below; default values in [] - if none, argument is mandatory.
%
% 'max_iter' - max number of EM iterations [10]
% 'thresh' - convergence threshold passed to em_converged [1e-4]
% 'verbose' - if 1, print out loglik at every iteration [1]
% 'obs_prior_weight' - weight to apply to uniform dirichlet prior on observation matrix [0]
%
% To clamp some of the parameters, so learning does not change them:
% 'adj_prior' - if 0, do not change prior [1]
% 'adj_trans' - if 0, do not change transmat [1]
% 'adj_obs' - if 0, do not change obsmat [1]
%
% OUTPUTS:
% LL            - row vector; LL(k) is the total log-likelihood computed in
%                 the E step of iteration k, i.e. under the parameters
%                 *before* the k-th M step
% prior, transmat, obsmat - the re-estimated (or clamped) parameters
% nrIterations  - number of EM iterations actually performed (numel(LL))
%
% Modified by Herbert Jaeger so xi are not computed individually
% but only their sum (over time) as xi_summed; this is the only way how they are used
% and it saves a lot of memory.

[max_iter, thresh, verbose, obs_prior_weight, adj_prior, adj_trans, adj_obs] = ...
    process_options(varargin, 'max_iter', 10, 'thresh', 1e-4, 'verbose', 1, ...
                    'obs_prior_weight', 0, 'adj_prior', 1, 'adj_trans', 1, 'adj_obs', 1);

previous_loglik = -inf;  % no previous iteration yet
converged = 0;
num_iter = 1;
LL = [];

% Equal-length sequences may be given as a matrix; the E step expects a
% cell array with one sequence per cell.
if ~iscell(data)
  data = num2cell(data, 2); % each row gets its own cell
end

while (num_iter <= max_iter) && ~converged
  % E step: expected sufficient statistics under the current parameters.
  [loglik, exp_num_trans, exp_num_visits1, exp_num_emit] = ...
      compute_ess_dhmm(prior, transmat, obsmat, data, obs_prior_weight);

  % M step: re-estimate only the parameters that are not clamped.
  if adj_prior
    prior = normalise(exp_num_visits1);
  end
  if adj_trans && ~isempty(exp_num_trans)
    transmat = mk_stochastic(exp_num_trans);
  end
  if adj_obs
    obsmat = mk_stochastic(exp_num_emit);
  end

  if verbose, fprintf(1, 'iteration %d, loglik = %f\n', num_iter, loglik); end
  num_iter = num_iter + 1;
  converged = em_converged(loglik, previous_loglik, thresh);
  previous_loglik = loglik;
  LL = [LL loglik]; %#ok<AGROW>
end
nrIterations = num_iter - 1;
wolffd@0 68
wolffd@0 69 %%%%%%%%%%%%%%%%%%%%%%%
wolffd@0 70
function [loglik, exp_num_trans, exp_num_visits1, exp_num_emit, exp_num_visitsT] = ...
    compute_ess_dhmm(startprob, transmat, obsmat, data, dirichlet)
% COMPUTE_ESS_DHMM Compute the Expected Sufficient Statistics for an HMM with discrete outputs
% [loglik, exp_num_trans, exp_num_visits1, exp_num_emit, exp_num_visitsT] = ...
%    compute_ess_dhmm(startprob, transmat, obsmat, data, dirichlet)
%
% Runs fwdback on every sequence in data and accumulates, over sequences l:
%   loglik             = sum of the per-sequence log-likelihoods from fwdback
%   exp_num_trans(i,j) = sum_l sum_{t=2}^T Pr(X(t-1) = i, X(t) = j | Obs(l))
%   exp_num_visits1(i) = sum_l Pr(X(1) = i | Obs(l))
%   exp_num_visitsT(i) = sum_l Pr(X(T) = i | Obs(l))
%   exp_num_emit(i,o)  = dirichlet + sum_l sum_{t=1}^T Pr(X(t) = i, O(t) = o | Obs(l))
% where Obs(l) = O_1 .. O_T for sequence l.
%
% INPUTS:
%   startprob(i), transmat(i,j), obsmat(i,o) - current HMM parameters
%   data{seq}(t) - cell array of observation sequences; each symbol is used
%                  directly as a column index into obsmat
%   dirichlet    - pseudo-count added to every entry of exp_num_emit
%                  (uniform Dirichlet prior on the emissions)

[nstates, nsymbols] = size(obsmat);

% Accumulators; emission counts start from the Dirichlet pseudo-counts.
exp_num_trans   = zeros(nstates, nstates);
exp_num_visits1 = zeros(nstates, 1);
exp_num_visitsT = zeros(nstates, 1);
exp_num_emit    = dirichlet * ones(nstates, nsymbols);
loglik = 0;

for seq = 1:length(data)
  y = data{seq};
  len = length(y);

  % Likelihood of each observed symbol under every state, then the
  % smoothed state posteriors and time-summed transition posteriors.
  %lik = eval_pdf_cond_multinomial(y, obsmat);
  lik = multinomial_prob(y, obsmat);
  [alpha_unused, beta_unused, post, seq_ll, xi_summed] = fwdback(startprob, transmat, lik);

  loglik          = loglik + seq_ll;
  exp_num_trans   = exp_num_trans + xi_summed;
  exp_num_visits1 = exp_num_visits1 + post(:, 1);
  exp_num_visitsT = exp_num_visitsT + post(:, len);

  % Scatter posteriors into emission counts, iterating over whichever of
  % {symbols, time steps} is the smaller set.
  if len >= nsymbols
    for sym_idx = 1:nsymbols
      hits = (y == sym_idx);
      if any(hits)
        exp_num_emit(:, sym_idx) = exp_num_emit(:, sym_idx) + sum(post(:, hits), 2);
      end
    end
  else
    for t = 1:len
      sym_idx = y(t);
      exp_num_emit(:, sym_idx) = exp_num_emit(:, sym_idx) + post(:, t);
    end
  end
end