function [transmat, obsmat, exp_num_trans, exp_num_emit, gamma, ll] = dhmm_em_online(...
    prior, transmat, obsmat, exp_num_trans, exp_num_emit, decay, data, ...
    act, adj_trans, adj_obs, dirichlet, filter_only)
% DHMM_EM_ONLINE Adjust the parameters using a weighted combination of the old and new expected statistics
%
% [transmat, obsmat, exp_num_trans, exp_num_emit, gamma, ll] = dhmm_em_online(...
%    prior, transmat, obsmat, exp_num_trans, exp_num_emit, decay, data, act, ...
%    adj_trans, adj_obs, dirichlet, filter_only)
%
% 0 < decay < 1, with smaller values meaning the past is forgotten more quickly.
% (We need to decay the old ess, since they were based on out-of-date parameters.)
% The other params are as in learn_hmm.
% We do a single forwards-backwards pass on the provided data, initializing with the specified prior.
% (If filter_only = 1, we only do a forwards pass.)
%
% Optional arguments and their defaults:
%   act         = []  if non-empty, transmat and exp_num_trans are cell arrays
%                     indexed by action, and act(t) selects the transition
%                     matrix used to move into Q(t)
%   adj_trans   = 1   update exp_num_trans / transmat (only when length(data) > 1)
%   adj_obs     = 1   update exp_num_emit / obsmat
%   dirichlet   = 0   pseudo-count added to every entry of exp_num_emit on each
%                     call (after decaying); it is not applied to the transitions
%   filter_only = 0   passed to fwdback as 'fwd_only'
%
% gamma and ll are returned unchanged from fwdback (gamma has one row per state
% and one column per time step; ll is presumably the log-likelihood -- see fwdback).

% Use the 'var' search type: plain exist('act') also matches files/functions on
% the path, which would skip the default and leave the variable undefined.
if ~exist('act', 'var'), act = []; end
if ~exist('adj_trans', 'var'), adj_trans = 1; end
if ~exist('adj_obs', 'var'), adj_obs = 1; end
if ~exist('dirichlet', 'var'), dirichlet = 0; end
if ~exist('filter_only', 'var'), filter_only = 0; end

% Sequence length. Both the emission and the transition updates depend on it,
% so compute it unconditionally (previously it was only set inside the adj_obs
% branch, so adj_obs = 0 failed with an undefined T).
T = length(data);

% E step: per-state observation likelihoods, then one forwards(-backwards)
% pass starting from the given prior.
olikseq = multinomial_prob(data, obsmat);
if isempty(act)
  [alpha, beta, gamma, ll, xi] = fwdback(prior, transmat, olikseq, 'fwd_only', filter_only);
else
  [alpha, beta, gamma, ll, xi] = fwdback(prior, transmat, olikseq, 'fwd_only', filter_only, ...
                                         'act', act);
end

% Increment ESS
[S, O] = size(obsmat);
if adj_obs
  % Decay the old emission counts and add the Dirichlet pseudo-count everywhere.
  exp_num_emit = decay*exp_num_emit + dirichlet*ones(S,O);
  % Two equivalent ways of adding gamma into the emission counts; loop over
  % whichever is smaller, the number of time steps T or the number of symbols O.
  if T < O
    for t=1:T
      o = data(t);
      exp_num_emit(:,o) = exp_num_emit(:,o) + gamma(:,t);
    end
  else
    for o=1:O
      ndx = find(data==o);
      if ~isempty(ndx)
        exp_num_emit(:,o) = exp_num_emit(:,o) + sum(gamma(:, ndx), 2);
      end
    end
  end
end

% Transition statistics need at least one transition, i.e. T > 1.
if adj_trans & (T > 1)
  if isempty(act)
    exp_num_trans = decay*exp_num_trans + sum(xi,3);
  else
    % act(2) determines Q(2), xi(:,:,1) holds P(Q(1), Q(2))
    A = length(transmat);
    for a=1:A
      ndx = find(act(2:end)==a);
      % NOTE(review): counts for an action that does not occur in act(2:end)
      % are left undecayed in this call; confirm this is intended.
      if ~isempty(ndx)
        exp_num_trans{a} = decay*exp_num_trans{a} + sum(xi(:,:,ndx), 3);
      end
    end
  end
end


% M step: turn the accumulated counts back into parameter matrices via
% mk_stochastic (presumably normalizing each row to sum to 1).

if adj_obs
  obsmat = mk_stochastic(exp_num_emit);
end
if adj_trans & (T > 1)
  if isempty(act)
    transmat = mk_stochastic(exp_num_trans);
  else
    for a=1:A
      transmat{a} = mk_stochastic(exp_num_trans{a});
    end
  end
end