function [transmat, obsmat, exp_num_trans, exp_num_emit, gamma, ll] = dhmm_em_online(...
    prior, transmat, obsmat, exp_num_trans, exp_num_emit, decay, data, ...
    act, adj_trans, adj_obs, dirichlet, filter_only)
% DHMM_EM_ONLINE Adjust the parameters using a weighted combination of the old and new expected statistics
%
% [transmat, obsmat, exp_num_trans, exp_num_emit, gamma, ll] = dhmm_em_online(...
%    prior, transmat, obsmat, exp_num_trans, exp_num_emit, decay, data, act, ...
%    adj_trans, adj_obs, dirichlet, filter_only)
%
% 0 < decay < 1, with smaller values meaning the past is forgotten more quickly.
% (We need to decay the old ess, since they were based on out-of-date parameters.)
% The other params are as in learn_hmm.
% We do a single forwards-backwards pass on the provided data, initializing with the specified prior.
% (If filter_only = 1, we only do a forwards pass.)
%
% Inputs (as used in this function):
%   prior         - initial state distribution, passed straight to fwdback
%   transmat      - S x S transition matrix, or (when act is non-empty) a cell
%                   array holding one transition matrix per action
%   obsmat        - S x O emission matrix, turned into per-step likelihoods by
%                   multinomial_prob
%   exp_num_trans - accumulated expected transition counts; a matrix, or a cell
%                   array indexed like transmat when act is non-empty
%   exp_num_emit  - S x O accumulated expected emission counts
%   decay         - factor multiplying the old statistics before new ones are added
%   data          - vector of observed symbols (values used as column indices
%                   into obsmat, i.e. 1..O); T = length(data)
%   act           - optional action sequence (default []); act(t) selects the
%                   transition matrix used to reach step t
%   adj_trans     - if nonzero, update the transition statistics/matrix (default 1)
%   adj_obs       - if nonzero, update the emission statistics/matrix (default 1)
%   dirichlet     - pseudo-count added to every entry of exp_num_emit on each
%                   call when adj_obs is set (default 0)
%   filter_only   - passed to fwdback as 'fwd_only' (default 0)
%
% Outputs:
%   transmat, obsmat            - re-estimated parameters (returned unchanged
%                                 when the corresponding update is disabled)
%   exp_num_trans, exp_num_emit - decayed and incremented expected statistics
%   gamma                       - state posteriors from fwdback; gamma(:,t) is
%                                 the S-vector for step t
%   ll                          - likelihood value returned by fwdback
%                                 (presumably the log-likelihood -- confirm)

% Use the 'var' search type so a function or file on the path with the same
% name as an optional argument is not mistaken for a supplied argument.
if ~exist('act', 'var'), act = []; end
if ~exist('adj_trans', 'var'), adj_trans = 1; end
if ~exist('adj_obs', 'var'), adj_obs = 1; end
if ~exist('dirichlet', 'var'), dirichlet = 0; end
if ~exist('filter_only', 'var'), filter_only = 0; end

% The sequence length is needed by both the emission and the transition
% updates, so compute it unconditionally. (Previously it was only set when
% adj_obs was true, so adj_obs=0 with adj_trans=1 hit an undefined T.)
T = length(data);

% Transitions can only be re-estimated when there is at least one
% (Q(t), Q(t+1)) pair, i.e. T > 1. Evaluate once so E and M steps agree.
update_trans = adj_trans && (T > 1);

% E step
olikseq = multinomial_prob(data, obsmat);
if isempty(act)
  [alpha, beta, gamma, ll, xi] = fwdback(prior, transmat, olikseq, 'fwd_only', filter_only);
else
  [alpha, beta, gamma, ll, xi] = fwdback(prior, transmat, olikseq, 'fwd_only', filter_only, ...
                                         'act', act);
end

% Increment ESS
if adj_obs
  [S O] = size(obsmat);
  exp_num_emit = decay*exp_num_emit + dirichlet*ones(S,O);
  if T < O
    % Short sequence: visit each time step once.
    for t=1:T
      o = data(t);
      exp_num_emit(:,o) = exp_num_emit(:,o) + gamma(:,t);
    end
  else
    % Long sequence: visit each symbol once, summing the posteriors of all
    % time steps at which that symbol was observed.
    for o=1:O
      ndx = find(data==o);
      if ~isempty(ndx)
        exp_num_emit(:,o) = exp_num_emit(:,o) + sum(gamma(:, ndx), 2);
      end
    end
  end
end

if update_trans
  if isempty(act)
    exp_num_trans = decay*exp_num_trans + sum(xi,3);
  else
    % act(2) determines Q(2), xi(:,:,1) holds P(Q(1), Q(2))
    A = length(transmat);
    for a=1:A
      ndx = find(act(2:end)==a);
      if ~isempty(ndx)
        % NOTE(review): only actions that occur in this sequence have their
        % statistics decayed; counts for unseen actions are left untouched.
        exp_num_trans{a} = decay*exp_num_trans{a} + sum(xi(:,:,ndx), 3);
      end
    end
  end
end

% M step
if adj_obs
  obsmat = mk_stochastic(exp_num_emit);
end
if update_trans
  if isempty(act)
    transmat = mk_stochastic(exp_num_trans);
  else
    for a=1:A
      transmat{a} = mk_stochastic(exp_num_trans{a});
    end
  end
end