diff toolboxes/FullBNT-1.0.7/HMM/dhmm_em_online.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/toolboxes/FullBNT-1.0.7/HMM/dhmm_em_online.m	Tue Feb 10 15:05:51 2015 +0000
@@ -0,0 +1,80 @@
+function [transmat, obsmat, exp_num_trans, exp_num_emit, gamma, ll] = dhmm_em_online(...
+    prior, transmat, obsmat, exp_num_trans, exp_num_emit, decay, data, ...
+    act, adj_trans, adj_obs, dirichlet, filter_only)
+% ONLINE_EM Adjust the parameters using a weighted combination of the old and new expected statistics
+%
+% [transmat, obsmat, exp_num_trans, exp_num_emit, gamma, ll] = online_em(...
+%    prior, transmat, obsmat, exp_num_trans, exp_num_emit, decay, data, act, ...
+%    adj_trans, adj_obs, dirichlet, filter_only)
+% 
+% 0 < decay < 1, with smaller values meaning the past is forgotten more quickly.
+% (We need to decay the old ess, since they were based on out-of-date parameters.)
+% The other params are as in learn_hmm.
+% We do a single forwards-backwards pass on the provided data, initializing with the specified prior.
+% (If filter_only = 1, we only do a forwards pass.)
+
+if ~exist('act'), act = []; end
+if ~exist('adj_trans'), adj_trans = 1; end
+if ~exist('adj_obs'), adj_obs = 1; end
+if ~exist('dirichlet'), dirichlet = 0; end
+if ~exist('filter_only'), filter_only = 0; end
+
+% E step
+olikseq = multinomial_prob(data, obsmat);
+if isempty(act)
+  [alpha, beta, gamma, ll, xi] = fwdback(prior, transmat, olikseq, 'fwd_only', filter_only);
+else
+  [alpha, beta, gamma, ll, xi] = fwdback(prior, transmat, olikseq, 'fwd_only', filter_only, ...
+					 'act', act);
+end
+
+% Increment ESS
+[S O] = size(obsmat);
+if adj_obs
+  exp_num_emit = decay*exp_num_emit + dirichlet*ones(S,O);
+  T = length(data);
+  if T < O
+    for t=1:T
+      o = data(t);
+      exp_num_emit(:,o) = exp_num_emit(:,o) + gamma(:,t);
+    end
+  else
+    for o=1:O
+      ndx = find(data==o);
+      if ~isempty(ndx)
+	exp_num_emit(:,o) = exp_num_emit(:,o) + sum(gamma(:, ndx), 2);
+      end
+    end
+  end
+end
+
+if adj_trans & (T > 1)
+  if isempty(act)
+    exp_num_trans = decay*exp_num_trans + sum(xi,3);
+  else
+    % act(2) determines Q(2), xi(:,:,1) holds P(Q(1), Q(2))
+    A = length(transmat);
+    for a=1:A
+      ndx = find(act(2:end)==a);
+      if ~isempty(ndx)
+	exp_num_trans{a} = decay*exp_num_trans{a} + sum(xi(:,:,ndx), 3);
+      end
+    end
+  end
+end
+
+
+% M step
+
+if adj_obs
+  obsmat = mk_stochastic(exp_num_emit);
+end
+if adj_trans & (T>1)
+  if isempty(act)
+    transmat = mk_stochastic(exp_num_trans);
+  else
+    for a=1:A
+      transmat{a} = mk_stochastic(exp_num_trans{a});
+    end
+  end
+end