Mercurial > hg > camir-aes2014
comparison toolboxes/FullBNT-1.0.7/HMM/dhmm_em_online_demo.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:e9a9cd732c1e |
---|---|
1 % Example of online EM applied to a simple POMDP with a fixed action sequence | |
2 | |
3 clear all | |
4 | |
5 % Create a really easy model to learn (transition matrices are SxS, one per action 1..A; observation matrix is SxO) | |
6 rand('state', 1); | |
7 O = 2; | |
8 S = 2; | |
9 A = 2; | |
10 prior0 = [1 0]'; | |
11 transmat0 = cell(1,A); | |
12 transmat0{1} = [0.9 0.1; 0.1 0.9]; % long runs of 1s and 2s | |
13 transmat0{2} = [0.1 0.9; 0.9 0.1]; % short runs | |
14 obsmat0 = eye(2); | |
15 | |
16 %prior0 = normalise(rand(S,1)); | |
17 %transmat0 = mk_stochastic(rand(S,S)); | |
18 %obsmat0 = mk_stochastic(rand(S,O)); | |
19 | |
20 T = 10; | |
21 act = [1*ones(1,25) 2*ones(1,25) 1*ones(1,25) 2*ones(1,25)]; | |
22 data = pomdp_sample(prior0, transmat0, obsmat0, act); | |
23 %data = sample_dhmm(prior0, transmat0, obsmat0, T, 1); | |
24 | |
25 % Initial guess of params (random stochastic matrices drawn after reseeding) | |
26 rand('state', 2); % different seed! | |
27 transmat1 = cell(1,A); | |
28 for a=1:A | |
29 transmat1{a} = mk_stochastic(rand(S,S)); | |
30 end | |
31 obsmat1 = mk_stochastic(rand(S,O)); | |
32 prior1 = prior0; % so it labels states the same way | |
33 | |
34 % Uninformative Dirichlet prior (expected sufficient statistics / pseudo counts): every count starts at e | |
35 e = 0.001; | |
36 ess_trans = cell(1,A); | |
37 for a=1:A | |
38 ess_trans{a} = repmat(e, S, S); | |
39 end | |
40 ess_emit = repmat(e, S, O); | |
41 | |
42 % Params: w is the sliding-window length; decay_sched holds the decay value d passed to dhmm_em_online at each step | |
43 w = 2; | |
44 decay_sched = [0.1:0.1:0.9]; | |
45 | |
46 % Initialize: start the window with the first observation, then multiply the prior by that observation's column of obsmat1 and normalise | |
47 LL1 = zeros(1,T); | |
48 t = 1; | |
49 y = data(t); | |
50 data_win = y; | |
51 act_win = [1]; % arbitrary initial value | |
52 [prior1, LL1(1)] = normalise(prior1 .* obsmat1(:,y)); | |
53 | |
54 % Iterate: grow the window until it holds w samples, then slide it by one per step; only act(2:T) of the 100 entries is consumed | |
55 for t=2:T | |
56 y = data(t); | |
57 a = act(t); | |
58 if t <= w | |
59 data_win = [data_win y]; | |
60 act_win = [act_win a]; | |
61 else | |
62 data_win = [data_win(2:end) y]; | |
63 act_win = [act_win(2:end) a]; | |
64 prior1 = gamma(:, 2); | |
65 end | |
66 d = decay_sched(min(t, length(decay_sched))); | |
67 [transmat1, obsmat1, ess_trans, ess_emit, gamma, ll] = dhmm_em_online(... | |
68 prior1, transmat1, obsmat1, ess_trans, ess_emit, d, data_win, act_win); | |
69 bel = gamma(:, end); | |
70 LL1(t) = ll/length(data_win); | |
71 %fprintf('t=%d, ll=%f\n', t, ll); | |
72 end | |
73 | |
74 LL1(1) = LL1(2); % since initial likelihood is for 1 slice, overwrite it with the first windowed value before plotting | |
75 plot(1:T, LL1, 'rx-'); | |
76 | |
77 | |
78 % compare with offline learning (disabled: the block below is guarded by "if 0") | |
79 | |
80 if 0 | |
81 rand('state', 2); % same seed as online learner | |
82 transmat2 = cell(1,A); | |
83 for a=1:A | |
84 transmat2{a} = mk_stochastic(rand(S,S)); | |
85 end | |
86 obsmat2 = mk_stochastic(rand(S,O)); | |
87 prior2 = prior0; | |
88 [LL2, prior2, transmat2, obsmat2] = dhmm_em(data, prior2, transmat2, obsmat2, .... | |
89 'max_iter', 10, 'thresh', 1e-3, 'verbose', 1, 'act', act); | |
90 | |
91 LL2 = LL2 / T | |
92 | |
93 end |