Mercurial > hg > camir-aes2014
comparison toolboxes/FullBNT-1.0.7/HMM/dhmm_em.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:e9a9cd732c1e |
---|---|
function [LL, prior, transmat, obsmat, nrIterations] = ...
    dhmm_em(data, prior, transmat, obsmat, varargin)
% DHMM_EM Find the ML/MAP parameters of an HMM with discrete outputs using EM.
% [LL, prior, transmat, obsmat, nrIterations] = dhmm_em(data, prior0, transmat0, obsmat0, ...)
%
% Notation: Q(t) = hidden state, Y(t) = observation
%
% INPUTS:
% data{ex} or data(ex,:) if all sequences have the same length
%   (a non-cell matrix is split so that each row becomes one sequence)
% prior(i)       - initial state distribution
% transmat(i,j)  - state transition matrix
% obsmat(i,o)    - emission matrix; observations index its columns
%
% Optional parameters may be passed as 'param_name', param_value pairs.
% Parameter names are shown below; default values in [] - if none, argument is mandatory.
%
% 'max_iter' - max number of EM iterations [10]
% 'thresh' - convergence threshold passed to em_converged [1e-4]
% 'verbose' - if 1, print out loglik at every iteration [1]
% 'obs_prior_weight' - weight to apply to uniform dirichlet prior on observation matrix [0]
%
% To clamp some of the parameters, so learning does not change them:
% 'adj_prior' - if 0, do not change prior [1]
% 'adj_trans' - if 0, do not change transmat [1]
% 'adj_obs' - if 0, do not change obsmat [1]
%
% OUTPUTS:
% LL(k)          - summed log-likelihood of all sequences computed in the E step of
%                  iteration k, i.e. under the parameters BEFORE that iteration's M step
% prior, transmat, obsmat - parameters after the last M step performed
% nrIterations   - number of EM iterations actually run (<= max_iter)
%
% Modified by Herbert Jaeger so xi are not computed individually
% but only their sum (over time) as xi_summed; this is the only way how they are used
% and it saves a lot of memory.

[max_iter, thresh, verbose, obs_prior_weight, adj_prior, adj_trans, adj_obs] = ...
    process_options(varargin, 'max_iter', 10, 'thresh', 1e-4, 'verbose', 1, ...
                    'obs_prior_weight', 0, 'adj_prior', 1, 'adj_trans', 1, 'adj_obs', 1);

% -inf makes the first convergence test compare against "no previous value".
previous_loglik = -inf;
converged = 0;
num_iter = 1;
LL = [];

if ~iscell(data)
  data = num2cell(data, 2); % each row gets its own cell
end

while (num_iter <= max_iter) && ~converged
  % E step: expected sufficient statistics under the current parameters
  [loglik, exp_num_trans, exp_num_visits1, exp_num_emit] = ...
      compute_ess_dhmm(prior, transmat, obsmat, data, obs_prior_weight);

  % M step: renormalise expected counts (skipping clamped parameters)
  if adj_prior
    prior = normalise(exp_num_visits1);
  end
  if adj_trans && ~isempty(exp_num_trans)
    transmat = mk_stochastic(exp_num_trans);
  end
  if adj_obs
    obsmat = mk_stochastic(exp_num_emit);
  end

  if verbose, fprintf(1, 'iteration %d, loglik = %f\n', num_iter, loglik); end
  num_iter = num_iter + 1;
  converged = em_converged(loglik, previous_loglik, thresh);
  previous_loglik = loglik;
  LL = [LL loglik];
end
nrIterations = num_iter - 1;
68 | |
69 %%%%%%%%%%%%%%%%%%%%%%% | |
70 | |
function [loglik, exp_num_trans, exp_num_visits1, exp_num_emit, exp_num_visitsT] = ...
    compute_ess_dhmm(startprob, transmat, obsmat, data, dirichlet)
% COMPUTE_ESS_DHMM E step for a discrete-output HMM: accumulate expected sufficient statistics.
% [loglik, exp_num_trans, exp_num_visits1, exp_num_emit, exp_num_visitsT] = ...
%     compute_ess_dhmm(startprob, transmat, obsmat, data, dirichlet)
%
% INPUTS:
% startprob(i)   - initial state distribution
% transmat(i,j)  - transition matrix
% obsmat(i,o)    - emission matrix; each observation value is used directly as a
%                  column index, so observations must be integers in 1..size(obsmat,2)
% data{seq}(t)   - cell array of observation sequences
% dirichlet      - pseudo-count added to every cell of exp_num_emit
%                  (uniform dirichlet prior on the expected emissions)
%
% OUTPUTS (sums run over the sequences l in data):
% loglik                 = sum_l of the per-sequence log-likelihood returned by fwdback
% exp_num_trans(i,j)     = sum_l sum_{t=2}^T Pr(X(t-1) = i, X(t) = j | Obs(l))
% exp_num_visits1(i)     = sum_l Pr(X(1) = i | Obs(l))
% exp_num_visitsT(i)     = sum_l Pr(X(T) = i | Obs(l))
% exp_num_emit(i,o)      = dirichlet + sum_l sum_{t=1}^T Pr(X(t) = i, O(t) = o | Obs(l))
% where Obs(l) = O_1 .. O_T for sequence l.

[nstates, nsymbols] = size(obsmat);

loglik = 0;
exp_num_trans = zeros(nstates, nstates);
exp_num_visits1 = zeros(nstates, 1);
exp_num_visitsT = zeros(nstates, 1);
exp_num_emit = dirichlet * ones(nstates, nsymbols);

for seq = 1:length(data)
  y = data{seq};
  len = length(y);
  lik = multinomial_prob(y, obsmat);
  % All five outputs are requested so fwdback also returns the time-summed xi.
  [fwd, bwd, post, seq_ll, xi_summed] = fwdback(startprob, transmat, lik);

  loglik = loglik + seq_ll;
  exp_num_trans = exp_num_trans + xi_summed;
  exp_num_visits1 = exp_num_visits1 + post(:, 1);
  exp_num_visitsT = exp_num_visitsT + post(:, len);

  % Emission counts: iterate over whichever is shorter, the symbols or the time steps.
  if len >= nsymbols
    for sym = 1:nsymbols
      hits = (y == sym);
      if any(hits)
        exp_num_emit(:, sym) = exp_num_emit(:, sym) + sum(post(:, hits), 2);
      end
    end
  else
    for t = 1:len
      sym = y(t);
      exp_num_emit(:, sym) = exp_num_emit(:, sym) + post(:, t);
    end
  end
end