annotate toolboxes/FullBNT-1.0.7/HMM/mhmm_em.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
rev   line source
function [LL, prior, transmat, mu, Sigma, mixmat] = ...
    mhmm_em(data, prior, transmat, mu, Sigma, mixmat, varargin)
% MHMM_EM Compute the ML parameters of an HMM with (mixtures of) Gaussians output using EM.
% [LL, prior, transmat, mu, Sigma, mixmat] = mhmm_em(data, ...
%     prior0, transmat0, mu0, Sigma0, mixmat0, ...)
%
% Notation: Q(t) = hidden state, Y(t) = observation, M(t) = mixture variable
%
% INPUTS:
% data{ex}(:,t) or data(:,t,ex) if all sequences have the same length
% prior(i) = Pr(Q(1) = i),
% transmat(i,j) = Pr(Q(t+1)=j | Q(t)=i)
% mu(:,j,k) = E[Y(t) | Q(t)=j, M(t)=k ]
% Sigma(:,:,j,k) = Cov[Y(t) | Q(t)=j, M(t)=k]
% mixmat(j,k) = Pr(M(t)=k | Q(t)=j) : set to [] or ones(Q,1) if only one mixture component
%
% OUTPUTS:
% LL(i) - total log-likelihood of all sequences computed in the E step of
%         iteration i, i.e. under the parameters *before* that iteration's M step
% prior, transmat, mu, Sigma, mixmat - the re-estimated parameters
%         (unchanged for any parameter clamped via the adj_* options)
%
% Optional parameters may be passed as 'param_name', param_value pairs,
% or omitted entirely.
% Parameter names are shown below; default values in [] - if none, argument is mandatory.
%
% 'max_iter' - max number of EM iterations [10]
% 'thresh' - convergence threshold passed to em_converged [1e-4]
% 'verbose' - if 1, print out loglik at every iteration [1]
% 'cov_type' - 'full', 'diag' or 'spherical' ['full']
%
% To clamp some of the parameters, so learning does not change them:
% 'adj_prior' - if 0, do not change prior [1]
% 'adj_trans' - if 0, do not change transmat [1]
% 'adj_mix' - if 0, do not change mixmat [1]
% 'adj_mu' - if 0, do not change mu [1]
% 'adj_Sigma' - if 0, do not change Sigma [1]
%
% If the number of mixture components differs depending on Q, just set the trailing
% entries of mixmat to 0, e.g., 2 components if Q=1, 3 components if Q=2,
% then set mixmat(1,3)=0. In this case, B2(1,3,:)=1.0.

% Catch the old positional syntax. Guard on isempty so that a call with no
% optional arguments at all does not fail on varargin{1}.
if ~isempty(varargin) && ~ischar(varargin{1})
  error('optional arguments should be passed as string/value pairs')
end

[max_iter, thresh, verbose, cov_type, adj_prior, adj_trans, adj_mix, adj_mu, adj_Sigma] = ...
    process_options(varargin, 'max_iter', 10, 'thresh', 1e-4, 'verbose', 1, ...
                    'cov_type', 'full', 'adj_prior', 1, 'adj_trans', 1, 'adj_mix', 1, ...
                    'adj_mu', 1, 'adj_Sigma', 1);

previous_loglik = -inf;
converged = 0;
num_iter = 1;
LL = [];

if ~iscell(data)
  data = num2cell(data, [1 2]); % each elt of the 3rd dim gets its own cell
end

O = size(data{1},1);
Q = length(prior);
if isempty(mixmat)
  mixmat = ones(Q,1);
end
M = size(mixmat,2);
if M == 1
  % A single component per state: mixmat is all ones and has nothing to learn.
  adj_mix = 0;
end

while (num_iter <= max_iter) && ~converged
  % E step: expected sufficient statistics under the current parameters.
  [loglik, exp_num_trans, exp_num_visits1, postmix, m, ip, op] = ...
      ess_mhmm(prior, transmat, mixmat, mu, Sigma, data);

  % M step: re-estimate only the parameters that are not clamped.
  if adj_prior
    prior = normalise(exp_num_visits1);
  end
  if adj_trans
    transmat = mk_stochastic(exp_num_trans);
  end
  if adj_mix
    mixmat = mk_stochastic(postmix);
  end
  if adj_mu || adj_Sigma
    [mu2, Sigma2] = mixgauss_Mstep(postmix, m, op, ip, 'cov_type', cov_type);
    % mixgauss_Mstep indexes components by a single (state,mixture) index;
    % reshape back to the per-state, per-component layout used here.
    if adj_mu
      mu = reshape(mu2, [O Q M]);
    end
    if adj_Sigma
      Sigma = reshape(Sigma2, [O O Q M]);
    end
  end

  if verbose, fprintf(1, 'iteration %d, loglik = %f\n', num_iter, loglik); end
  num_iter = num_iter + 1;
  converged = em_converged(loglik, previous_loglik, thresh);
  previous_loglik = loglik;
  LL = [LL loglik];
end
wolffd@0 99
wolffd@0 100
wolffd@0 101 %%%%%%%%%
wolffd@0 102
function [loglik, exp_num_trans, exp_num_visits1, postmix, m, ip, op] = ...
    ess_mhmm(prior, transmat, mixmat, mu, Sigma, data)
% ESS_MHMM Compute the Expected Sufficient Statistics for a MOG Hidden Markov Model.
%
% Inputs follow the conventions of mhmm_em; data is a cell array whose
% element data{l} holds sequence l with one observation per column.
%
% Outputs:
% loglik = sum over sequences of the log-likelihood returned by fwdback
% exp_num_trans(i,j) = sum_l sum_{t=2}^T Pr(Q(t-1) = i, Q(t) = j| Obs(l))
% exp_num_visits1(i) = sum_l Pr(Q(1)=i | Obs(l))
%
% With w(i,k,t,l) = P(Q(t)=i, M(t)=k | Obs(l)), where Obs(l) = O_1 .. O_T
% for sequence l:
% postmix(i,k) = sum_l sum_t w(i,k,t,l) (posterior mixing weights/ responsibilities)
% m(:,i,k) = sum_l sum_t w(i,k,t,l) * Obs(:,t,l)
% ip(i,k) = sum_l sum_t w(i,k,t,l) * Obs(:,t,l)' * Obs(:,t,l)
% op(:,:,i,k) = sum_l sum_t w(i,k,t,l) * Obs(:,t,l) * Obs(:,t,l)'

verbose = 0;

num_seqs  = length(data);
obs_dim   = size(data{1},1);
n_states  = length(prior);
n_comps   = size(mixmat,2);
has_mix   = (n_comps > 1);

% Accumulators, summed over every sequence.
exp_num_trans   = zeros(n_states, n_states);
exp_num_visits1 = zeros(n_states, 1);
postmix         = zeros(n_states, n_comps);
m               = zeros(obs_dim, n_states, n_comps);
op              = zeros(obs_dim, obs_dim, n_states, n_comps);
ip              = zeros(n_states, n_comps);
loglik          = 0;

if verbose, fprintf(1, 'forwards-backwards example # '); end
for seq_idx = 1:num_seqs
  if verbose, fprintf(1, '%d ', seq_idx); end
  seq = data{seq_idx};
  n_steps = size(seq,2);

  % Forward-backward pass. The mixture case also requests the joint
  % state/component posteriors gamma2 from fwdback.
  if ~has_mix
    B = mixgauss_prob(seq, mu, Sigma);
    [alpha, beta, gamma, seq_loglik, xi] = fwdback(prior, transmat, B);
  else
    [B, B2] = mixgauss_prob(seq, mu, Sigma, mixmat);
    [alpha, beta, gamma, seq_loglik, xi, gamma2] = ...
        fwdback(prior, transmat, B, 'obslik2', B2, 'mixmat', mixmat);
  end
  loglik = loglik + seq_loglik;
  if verbose, fprintf(1, 'll at ex %d = %f\n', seq_idx, loglik); end

  exp_num_trans   = exp_num_trans + sum(xi,3);
  exp_num_visits1 = exp_num_visits1 + gamma(:,1);

  if ~has_mix
    postmix = postmix + sum(gamma,2);
    % Single component: view gamma as gamma2(i,1,t) = gamma(i,t).
    gamma2 = reshape(gamma, [n_states 1 n_steps]);
  else
    postmix = postmix + sum(gamma2,3);
  end

  % Each (state, component) pair owns independent accumulators, so the
  % visiting order of the pairs does not affect the result.
  for k = 1:n_comps
    for s = 1:n_states
      weights = reshape(gamma2(s,k,:), [1 n_steps]);   % weights(t) = w(s,k,t,l)
      weighted = bsxfun(@times, seq, weights);         % weighted(:,t) = weights(t) * seq(:,t)
      m(:,s,k)    = m(:,s,k) + sum(weighted, 2);
      op(:,:,s,k) = op(:,:,s,k) + weighted * seq';
      ip(s,k)     = ip(s,k) + sum(sum(weighted .* seq, 2));
    end
  end
end
if verbose, fprintf(1, '\n'); end