Mercurial > hg > camir-aes2014
comparison toolboxes/FullBNT-1.0.7/HMM/mhmm_em.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:e9a9cd732c1e |
---|---|
function [LL, prior, transmat, mu, Sigma, mixmat] = ...
    mhmm_em(data, prior, transmat, mu, Sigma, mixmat, varargin);
% MHMM_EM Compute the ML parameters of an HMM with (mixtures of) Gaussians output using EM.
% [LL, prior, transmat, mu, Sigma, mixmat] = mhmm_em(data, ...
%   prior0, transmat0, mu0, Sigma0, mixmat0, ...)
%
% Notation: Q(t) = hidden state, Y(t) = observation, M(t) = mixture variable
%
% INPUTS:
% data{ex}(:,t) or data(:,t,ex) if all sequences have the same length
% prior(i) = Pr(Q(1) = i),
% transmat(i,j) = Pr(Q(t+1)=j | Q(t)=i)
% mu(:,j,k) = E[Y(t) | Q(t)=j, M(t)=k ]
% Sigma(:,:,j,k) = Cov[Y(t) | Q(t)=j, M(t)=k]
% mixmat(j,k) = Pr(M(t)=k | Q(t)=j) : set to [] or ones(Q,1) if only one mixture component
%
% OUTPUTS:
% LL(i) = total log-likelihood of all sequences computed in the E step of
%         iteration i, i.e. under the parameters in force BEFORE that
%         iteration's M step. LL is empty if no iteration runs.
% prior, transmat, mu, Sigma, mixmat = parameters after the last M step
%         (clamped parameters are returned unchanged; mixmat is returned as
%         ones(Q,1) if it was passed in as []).
%
% Optional parameters may be passed as 'param_name', param_value pairs.
% They may also be omitted entirely, in which case all defaults are used.
% Parameter names are shown below; default values in [] - if none, argument is mandatory.
%
% 'max_iter' - max number of EM iterations [10]
% 'thresh' - convergence threshold [1e-4]
% 'verbose' - if 1, print out loglik at every iteration [1]
% 'cov_type' - 'full', 'diag' or 'spherical' ['full']
%
% To clamp some of the parameters, so learning does not change them:
% 'adj_prior' - if 0, do not change prior [1]
% 'adj_trans' - if 0, do not change transmat [1]
% 'adj_mix' - if 0, do not change mixmat [1]
% 'adj_mu' - if 0, do not change mu [1]
% 'adj_Sigma' - if 0, do not change Sigma [1]
%
% If the number of mixture components differs from state to state, set the
% unused trailing entries of mixmat to 0. For example, with 2 components in
% state 1 and 3 components in state 2, set mixmat(1,3)=0. In this case,
% B2(1,3,:)=1.0.

% Reject the old positional calling convention, but allow no options at all.
% varargin{1} must not be indexed when varargin is empty.
if ~isempty(varargin) && ~ischar(varargin{1}) % catch old syntax
  error('optional arguments should be passed as string/value pairs')
end

[max_iter, thresh, verbose, cov_type, adj_prior, adj_trans, adj_mix, adj_mu, adj_Sigma] = ...
    process_options(varargin, 'max_iter', 10, 'thresh', 1e-4, 'verbose', 1, ...
                    'cov_type', 'full', 'adj_prior', 1, 'adj_trans', 1, 'adj_mix', 1, ...
                    'adj_mu', 1, 'adj_Sigma', 1);

previous_loglik = -inf;
converged = 0;
num_iter = 1;
LL = [];

if ~iscell(data)
  data = num2cell(data, [1 2]); % each elt of the 3rd dim gets its own cell
end

O = size(data{1},1);
Q = length(prior);
if isempty(mixmat)
  mixmat = ones(Q,1);
end
M = size(mixmat,2);
if M == 1
  adj_mix = 0; % with one component per state there is no mixing weight to learn
end

while (num_iter <= max_iter) && ~converged
  % E step: expected sufficient statistics under the current parameters.
  [loglik, exp_num_trans, exp_num_visits1, postmix, m, ip, op] = ...
      ess_mhmm(prior, transmat, mixmat, mu, Sigma, data);

  % M step: re-estimate every parameter that is not clamped.
  if adj_prior
    prior = normalise(exp_num_visits1);
  end
  if adj_trans
    transmat = mk_stochastic(exp_num_trans);
  end
  if adj_mix
    mixmat = mk_stochastic(postmix);
  end
  if adj_mu || adj_Sigma
    % NOTE(review): when adj_mu is 0 but adj_Sigma is 1, Sigma2 is computed
    % by mixgauss_Mstep around its own re-estimated mean, not around the
    % clamped mu. Confirm whether mixgauss_Mstep's clamping options should be
    % used in that case.
    [mu2, Sigma2] = mixgauss_Mstep(postmix, m, op, ip, 'cov_type', cov_type);
    if adj_mu
      mu = reshape(mu2, [O Q M]);
    end
    if adj_Sigma
      Sigma = reshape(Sigma2, [O O Q M]);
    end
  end

  if verbose, fprintf(1, 'iteration %d, loglik = %f\n', num_iter, loglik); end
  num_iter = num_iter + 1;
  converged = em_converged(loglik, previous_loglik, thresh);
  previous_loglik = loglik;
  LL = [LL loglik];
end
99 | |
100 | |
101 %%%%%%%%% | |
102 | |
function [loglik, exp_num_trans, exp_num_visits1, postmix, m, ip, op] = ...
    ess_mhmm(prior, transmat, mixmat, mu, Sigma, data)
% ESS_MHMM Expected sufficient statistics of an HMM with mixture-of-Gaussians outputs.
%
% Runs forwards-backwards (fwdback) on every sequence data{s} and accumulates
% the statistics over all sequences s and time steps t:
%   loglik             = sum_s of the log-likelihood returned by fwdback
%   exp_num_trans(i,j) = sum_s sum_{t=2}^T Pr(Q(t-1)=i, Q(t)=j | Obs(s))
%   exp_num_visits1(i) = sum_s Pr(Q(1)=i | Obs(s))
%
% Writing w(i,k,t,s) = Pr(Q(t)=i, M(t)=k | Obs(s)):
%   postmix(i,k)  = sum_s sum_t w(i,k,t,s)   (posterior mixing weights / responsibilities)
%   m(:,i,k)      = sum_s sum_t w(i,k,t,s) * Obs(:,t,s)
%   ip(i,k)       = sum_s sum_t w(i,k,t,s) * Obs(:,t,s)' * Obs(:,t,s)
%   op(:,:,i,k)   = sum_s sum_t w(i,k,t,s) * Obs(:,t,s) * Obs(:,t,s)'
%
% data is a cell array. Each data{s} has one row per observation dimension
% and one column per time step. When size(mixmat,2) == 1, the state
% posterior from fwdback is used directly as w(i,1,t,s).

verbose = 0; % hard-wired off; flip to 1 to trace progress per sequence

nseq    = length(data);
obs_dim = size(data{1}, 1);
nstates = length(prior);
nmix    = size(mixmat, 2);
use_mix = (nmix > 1);

loglik          = 0;
exp_num_trans   = zeros(nstates, nstates);
exp_num_visits1 = zeros(nstates, 1);
postmix         = zeros(nstates, nmix);
m               = zeros(obs_dim, nstates, nmix);
ip              = zeros(nstates, nmix);
op              = zeros(obs_dim, obs_dim, nstates, nmix);

if verbose, fprintf(1, 'forwards-backwards example # '); end
for seq = 1:nseq
  if verbose, fprintf(1, '%d ', seq); end
  Y   = data{seq};
  len = size(Y, 2);

  if ~use_mix
    % Single component per state: the joint posterior is the state posterior.
    obslik = mixgauss_prob(Y, mu, Sigma);
    [fwd, bwd, post, seq_loglik, xi] = fwdback(prior, transmat, obslik);
    postmix = postmix + sum(post, 2);
    post2   = reshape(post, [nstates 1 len]); % post2(i,1,t) = post(i,t)
  else
    [obslik, obslik2] = mixgauss_prob(Y, mu, Sigma, mixmat);
    [fwd, bwd, post, seq_loglik, xi, post2] = ...
        fwdback(prior, transmat, obslik, 'obslik2', obslik2, 'mixmat', mixmat);
    postmix = postmix + sum(post2, 3);
  end

  loglik = loglik + seq_loglik;
  if verbose, fprintf(1, 'll at ex %d = %f\n', seq, loglik); end

  exp_num_trans   = exp_num_trans + sum(xi, 3);
  exp_num_visits1 = exp_num_visits1 + post(:, 1);

  % Weighted first and second moments for every (state, component) slot.
  % The slots are independent, so the iteration order does not matter.
  for k = 1:nmix
    for i = 1:nstates
      wt = reshape(post2(i, k, :), 1, len); % wt(t) = w(i,k,t,seq)
      Yw = bsxfun(@times, Y, wt);           % Yw(:,t) = wt(t) * Y(:,t)
      m(:, i, k)     = m(:, i, k) + sum(Yw, 2);
      op(:, :, i, k) = op(:, :, i, k) + Yw * Y';
      ip(i, k)       = ip(i, k) + sum(sum(Yw .* Y, 2));
    end
  end
end
if verbose, fprintf(1, '\n'); end