comparison toolboxes/FullBNT-1.0.7/HMM/dhmm_em.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:e9a9cd732c1e
function [LL, prior, transmat, obsmat, nrIterations] = ...
    dhmm_em(data, prior, transmat, obsmat, varargin)
% DHMM_EM Find the ML/MAP parameters of an HMM with discrete outputs using EM.
% [LL, prior, transmat, obsmat, nrIterations] = dhmm_em(data, prior0, transmat0, obsmat0, ...)
%
% Notation: Q(t) = hidden state, Y(t) = observation
%
% INPUTS:
% data{ex} or data(ex,:) if all sequences have the same length
% prior(i)
% transmat(i,j)
% obsmat(i,o)
%
% Optional parameters may be passed as 'param_name', param_value pairs.
% Parameter names are shown below; default values in [] - if none, argument is mandatory.
%
% 'max_iter' - max number of EM iterations [10]
% 'thresh' - convergence threshold [1e-4]
% 'verbose' - if 1, print out loglik at every iteration [1]
% 'obs_prior_weight' - weight to apply to uniform dirichlet prior on observation matrix [0]
%
% To clamp some of the parameters, so learning does not change them:
% 'adj_prior' - if 0, do not change prior [1]
% 'adj_trans' - if 0, do not change transmat [1]
% 'adj_obs' - if 0, do not change obsmat [1]
%
% OUTPUTS:
% LL - row vector; LL(k) is the log-likelihood returned by the E step of
%      iteration k, i.e. computed under the parameters in effect *before*
%      that iteration's M step.
% prior, transmat, obsmat - parameters after the last M step (unchanged
%      where clamped via 'adj_prior' / 'adj_trans' / 'adj_obs').
% nrIterations - number of EM iterations actually run (equals numel(LL)).
%
% Modified by Herbert Jaeger so xi are not computed individually
% but only their sum (over time) as xi_summed; this is the only way how they are used
% and it saves a lot of memory.

[max_iter, thresh, verbose, obs_prior_weight, adj_prior, adj_trans, adj_obs] = ...
    process_options(varargin, 'max_iter', 10, 'thresh', 1e-4, 'verbose', 1, ...
    'obs_prior_weight', 0, 'adj_prior', 1, 'adj_trans', 1, 'adj_obs', 1);

previous_loglik = -inf;
converged = 0;
num_iter = 1;
LL = [];

% Bring equal-length sequences stored as matrix rows into the same
% cell-array form used for variable-length sequences.
if ~iscell(data)
    data = num2cell(data, 2); % each row gets its own cell
end

% The loop and branch conditions are scalars, so use short-circuit &&.
while (num_iter <= max_iter) && ~converged
    % E step: expected sufficient statistics under the current parameters.
    [loglik, exp_num_trans, exp_num_visits1, exp_num_emit] = ...
        compute_ess_dhmm(prior, transmat, obsmat, data, obs_prior_weight);

    % M step: re-estimate every parameter that is not clamped.
    if adj_prior
        prior = normalise(exp_num_visits1);
    end
    if adj_trans && ~isempty(exp_num_trans)
        transmat = mk_stochastic(exp_num_trans);
    end
    if adj_obs
        obsmat = mk_stochastic(exp_num_emit);
    end

    if verbose, fprintf(1, 'iteration %d, loglik = %f\n', num_iter, loglik); end
    num_iter = num_iter + 1;
    % The convergence test compares this E-step log-likelihood with the previous one.
    converged = em_converged(loglik, previous_loglik, thresh);
    previous_loglik = loglik;
    LL = [LL loglik]; %#ok<AGROW>
end
nrIterations = num_iter - 1;
68
69 %%%%%%%%%%%%%%%%%%%%%%%
70
function [loglik, exp_num_trans, exp_num_visits1, exp_num_emit, exp_num_visitsT] = ...
    compute_ess_dhmm(startprob, transmat, obsmat, data, dirichlet)
% COMPUTE_ESS_DHMM Expected sufficient statistics of an HMM with discrete outputs.
% [loglik, exp_num_trans, exp_num_visits1, exp_num_emit, exp_num_visitsT] = ...
%     compute_ess_dhmm(startprob, transmat, obsmat, data, dirichlet)
%
% INPUTS:
% startprob(i)  - initial state distribution
% transmat(i,j) - state transition matrix
% obsmat(i,o)   - emission matrix; its size fixes the number of states and symbols
% data{seq}(t)  - cell array of observation sequences
% dirichlet     - pseudo-count added to every entry of exp_num_emit
%                 (uniform dirichlet prior on the expected emissions)
%
% OUTPUTS:
% loglik             = sum over sequences of the log-likelihood returned by fwdback
% exp_num_trans(i,j) = sum_l sum_{t=2}^T Pr(X(t-1) = i, X(t) = j | Obs(l))
% exp_num_visits1(i) = sum_l Pr(X(1) = i | Obs(l))
% exp_num_visitsT(i) = sum_l Pr(X(T) = i | Obs(l))
% exp_num_emit(i,o)  = dirichlet + sum_l sum_{t=1}^T Pr(X(t) = i, O(t) = o | Obs(l))
% where Obs(l) = O_1 .. O_T for sequence l.

[nStates, nSymbols] = size(obsmat);

loglik = 0;
exp_num_trans   = zeros(nStates, nStates);
exp_num_visits1 = zeros(nStates, 1);
exp_num_visitsT = zeros(nStates, 1);
% Every emission count starts from the dirichlet pseudo-count.
exp_num_emit = dirichlet * ones(nStates, nSymbols);

for seqIdx = 1:length(data)
    seq = data{seqIdx};
    len = length(seq);

    % Per-sequence posteriors: post(:,t) is the state posterior at time t,
    % xiSum is the time-summed pairwise transition posterior.
    lik = multinomial_prob(seq, obsmat);
    [~, ~, post, seqLL, xiSum] = fwdback(startprob, transmat, lik);

    loglik          = loglik + seqLL;
    exp_num_trans   = exp_num_trans + xiSum;
    exp_num_visits1 = exp_num_visits1 + post(:, 1);
    exp_num_visitsT = exp_num_visitsT + post(:, len);

    % Accumulate emission counts. Iterate over whichever is shorter: the
    % symbol alphabet or the time axis.
    if len >= nSymbols
        for sym = 1:nSymbols
            hits = (seq == sym);
            if any(hits)
                exp_num_emit(:, sym) = exp_num_emit(:, sym) + sum(post(:, hits), 2);
            end
        end
    else
        for t = 1:len
            sym = seq(t);
            exp_num_emit(:, sym) = exp_num_emit(:, sym) + post(:, t);
        end
    end
end
115 end
116 else
117 for o=1:O
118 ndx = find(obs==o);
119 if ~isempty(ndx)
120 exp_num_emit(:,o) = exp_num_emit(:,o) + sum(gamma(:, ndx), 2);
121 end
122 end
123 end
124 end