function [LL, prior, transmat, mu, Sigma, mixmat] = ...
    mhmm_em(data, prior, transmat, mu, Sigma, mixmat, varargin)
% MHMM_EM Compute the ML parameters of an HMM with (mixtures of) Gaussians output using EM.
% [ll_trace, prior, transmat, mu, sigma, mixmat] = mhmm_em(data, ...
%   prior0, transmat0, mu0, sigma0, mixmat0, ...)
%
% Notation: Q(t) = hidden state, Y(t) = observation, M(t) = mixture variable
%
% INPUTS:
% data{ex}(:,t) or data(:,t,ex) if all sequences have the same length
% prior(i) = Pr(Q(1) = i),
% transmat(i,j) = Pr(Q(t+1)=j | Q(t)=i)
% mu(:,j,k) = E[Y(t) | Q(t)=j, M(t)=k ]
% Sigma(:,:,j,k) = Cov[Y(t) | Q(t)=j, M(t)=k]
% mixmat(j,k) = Pr(M(t)=k | Q(t)=j) : set to [] or ones(Q,1) if only one mixture component
%
% Optional parameters may be passed as 'param_name', param_value pairs.
% Parameter names are shown below; default values in [] - if none, argument is mandatory.
% All optional parameters may be omitted, in which case the defaults are used.
%
% 'max_iter' - max number of EM iterations [10]
% 'thresh' - convergence threshold [1e-4]
% 'verbose' - if 1, print out loglik at every iteration [1]
% 'cov_type' - 'full', 'diag' or 'spherical' ['full']
%
% To clamp some of the parameters, so learning does not change them:
% 'adj_prior' - if 0, do not change prior [1]
% 'adj_trans' - if 0, do not change transmat [1]
% 'adj_mix' - if 0, do not change mixmat [1]
% 'adj_mu' - if 0, do not change mu [1]
% 'adj_Sigma' - if 0, do not change Sigma [1]
%
% If the number of mixture components differs depending on Q, just set the trailing
% entries of mixmat to 0, e.g., 2 components if Q=1, 3 components if Q=2,
% then set mixmat(1,3)=0. In this case, B2(1,3,:)=1.0.
%
% OUTPUTS:
% LL(iter) = log-likelihood computed in the E step of iteration iter
%            (i.e. under the parameters in force at the start of that iteration)
% prior, transmat, mu, Sigma, mixmat = parameters after the last M step

% Catch the old positional syntax. The emptiness guard matters: with no
% optional arguments at all, varargin{1} does not exist and indexing it
% would raise an unrelated error instead of falling back to the defaults.
if ~isempty(varargin) && ~ischar(varargin{1})
  error('optional arguments should be passed as string/value pairs')
end

[max_iter, thresh, verbose, cov_type, adj_prior, adj_trans, adj_mix, adj_mu, adj_Sigma] = ...
    process_options(varargin, 'max_iter', 10, 'thresh', 1e-4, 'verbose', 1, ...
                    'cov_type', 'full', 'adj_prior', 1, 'adj_trans', 1, 'adj_mix', 1, ...
                    'adj_mu', 1, 'adj_Sigma', 1);

previous_loglik = -inf;
loglik = 0;
converged = 0;
num_iter = 1;
LL = [];

if ~iscell(data)
  data = num2cell(data, [1 2]); % each elt of the 3rd dim gets its own cell
end

O = size(data{1},1);   % observation dimensionality
Q = length(prior);     % number of hidden states
if isempty(mixmat)
  mixmat = ones(Q,1);
end
M = size(mixmat,2);    % number of mixture components per state
if M == 1
  % A single component has nothing to re-estimate in the mixing weights.
  adj_mix = 0;
end

while (num_iter <= max_iter) && ~converged
  % E step
  [loglik, exp_num_trans, exp_num_visits1, postmix, m, ip, op] = ...
      ess_mhmm(prior, transmat, mixmat, mu, Sigma, data);

  % M step
  if adj_prior
    prior = normalise(exp_num_visits1);
  end
  if adj_trans
    transmat = mk_stochastic(exp_num_trans);
  end
  if adj_mix
    mixmat = mk_stochastic(postmix);
  end
  if adj_mu || adj_Sigma
    [mu2, Sigma2] = mixgauss_Mstep(postmix, m, op, ip, 'cov_type', cov_type);
    if adj_mu
      mu = reshape(mu2, [O Q M]);
    end
    if adj_Sigma
      Sigma = reshape(Sigma2, [O O Q M]);
    end
  end

  if verbose, fprintf(1, 'iteration %d, loglik = %f\n', num_iter, loglik); end
  num_iter = num_iter + 1;
  converged = em_converged(loglik, previous_loglik, thresh);
  previous_loglik = loglik;
  LL = [LL loglik];
end


%%%%%%%%%

function [loglik, exp_num_trans, exp_num_visits1, postmix, m, ip, op] = ...
    ess_mhmm(prior, transmat, mixmat, mu, Sigma, data)
% ESS_MHMM Compute the Expected Sufficient Statistics for a MOG Hidden Markov Model.
%
% data must be a cell array with one O x T matrix per sequence (T may vary).
%
% Outputs:
% loglik = sum over sequences of the log-likelihood returned by fwdback
% exp_num_trans(i,j)   = sum_l sum_{t=2}^T Pr(Q(t-1) = i, Q(t) = j| Obs(l))
% exp_num_visits1(i)   = sum_l Pr(Q(1)=i | Obs(l))
%
% Let w(i,k,t,l) = P(Q(t)=i, M(t)=k | Obs(l))
% where Obs(l) = Obs(:,:,l) = O_1 .. O_T for sequence l
% Then
% postmix(i,k) = sum_l sum_t w(i,k,t,l) (posterior mixing weights/ responsibilities)
% m(:,i,k)   = sum_l sum_t w(i,k,t,l) * Obs(:,t,l)
% ip(i,k) = sum_l sum_t w(i,k,t,l) * Obs(:,t,l)' * Obs(:,t,l)
% op(:,:,i,k) = sum_l sum_t w(i,k,t,l) * Obs(:,t,l) * Obs(:,t,l)'

verbose = 0; % local debugging switch; independent of the caller's 'verbose' option

numex = length(data);
O = size(data{1},1);
Q = length(prior);
M = size(mixmat,2);
exp_num_trans = zeros(Q,Q);
exp_num_visits1 = zeros(Q,1);
postmix = zeros(Q,M);
m = zeros(O,Q,M);
op = zeros(O,O,Q,M);
ip = zeros(Q,M);

mix = (M>1);

loglik = 0;
if verbose, fprintf(1, 'forwards-backwards example # '); end
for ex=1:numex
  if verbose, fprintf(1, '%d ', ex); end
  obs = data{ex};
  T = size(obs,2);
  if mix
    [B, B2] = mixgauss_prob(obs, mu, Sigma, mixmat);
    [alpha, beta, gamma, current_loglik, xi, gamma2] = ...
        fwdback(prior, transmat, B, 'obslik2', B2, 'mixmat', mixmat);
  else
    B = mixgauss_prob(obs, mu, Sigma);
    [alpha, beta, gamma, current_loglik, xi] = fwdback(prior, transmat, B);
  end
  loglik = loglik + current_loglik;
  if verbose, fprintf(1, 'll at ex %d = %f\n', ex, loglik); end

  exp_num_trans = exp_num_trans + sum(xi,3);
  exp_num_visits1 = exp_num_visits1 + gamma(:,1);

  if mix
    postmix = postmix + sum(gamma2,3);
  else
    postmix = postmix + sum(gamma,2);
    % With one component the joint state/mixture posterior is just gamma.
    gamma2 = reshape(gamma, [Q 1 T]); % gamma2(i,m,t) = gamma(i,t)
  end
  for i=1:Q
    for k=1:M
      w = reshape(gamma2(i,k,:), [1 T]); % w(t) = w(i,k,t,l)
      wobs = obs .* repmat(w, [O 1]); % wobs(:,t) = w(t) * obs(:,t)
      m(:,i,k) = m(:,i,k) + sum(wobs, 2); % m(:) = sum_t w(t) obs(:,t)
      op(:,:,i,k) = op(:,:,i,k) + wobs * obs'; % op(:,:) = sum_t w(t) * obs(:,t) * obs(:,t)'
      ip(i,k) = ip(i,k) + sum(sum(wobs .* obs, 2)); % ip = sum_t w(t) * obs(:,t)' * obs(:,t)
    end
  end
end
if verbose, fprintf(1, '\n'); end