% Example of online EM applied to a simple POMDP with fixed action seq
%
% Samples observations from a small known model (S states, O observation
% symbols, A actions), then re-estimates the transition and emission
% parameters with dhmm_em_online over a sliding window of the w most recent
% slices. The value returned as ll by each online step, divided by the
% window length, is stored in LL1 and plotted.

clear all

% Create a really easy model to learn
rand('state', 1);
O = 2;   % number of observation symbols
S = 2;   % number of hidden states
A = 2;   % number of actions
prior0 = [1 0]';
transmat0 = cell(1,A);   % one SxS transition matrix per action
transmat0{1} = [0.9 0.1; 0.1 0.9]; % long runs of 1s and 2s
transmat0{2} = [0.1 0.9; 0.9 0.1]; % short runs
obsmat0 = eye(2);

%prior0 = normalise(rand(S,1));
%transmat0 = mk_stochastic(rand(S,S));
%obsmat0 = mk_stochastic(rand(S,O));

% NOTE(review): act has 100 entries, but the online loop below only consumes
% the first T slices of data/act -- confirm that T = 10 is intended.
T = 10;
act = [1*ones(1,25) 2*ones(1,25) 1*ones(1,25) 2*ones(1,25)];
data = pomdp_sample(prior0, transmat0, obsmat0, act);
%data = sample_dhmm(prior0, transmat0, obsmat0, T, 1);

% Initial guess of params
rand('state', 2); % different seed!
transmat1 = cell(1,A);
for a=1:A
  transmat1{a} = mk_stochastic(rand(S,S));
end
obsmat1 = mk_stochastic(rand(S,O));
prior1 = prior0; % so it labels states the same way

% Uninformative Dirichlet prior (expected sufficient statistics / pseudo counts)
e = 0.001;
ess_trans = cell(1,A);
for a=1:A
  ess_trans{a} = repmat(e, S, S);
end
ess_emit = repmat(e, S, O);

% Params
w = 2;                       % maximum window length, in slices
decay_sched = 0.1:0.1:0.9;   % decay value passed to dhmm_em_online at each step

% Initialize
LL1 = zeros(1,T);
t = 1;
y = data(t);
data_win = y;
act_win = [1]; % arbitrary initial value
% LL1(1) receives normalise's second output here; it is overwritten after the
% loop because it covers a single slice only.
[prior1, LL1(1)] = normalise(prior1 .* obsmat1(:,y));

% Iterate
for t=2:T
  y = data(t);
  a = act(t);
  if t <= w
    % Window not yet full: append the new slice.
    data_win = [data_win y];
    act_win = [act_win a];
  else
    % Window full: drop the oldest slice and append the new one.
    data_win = [data_win(2:end) y];
    act_win = [act_win(2:end) a];
    % Column 2 of the previous step's posterior is presumably the slice that
    % is now first in the window -- verify against dhmm_em_online.
    prior1 = gam(:, 2);
  end
  % NOTE(review): t starts at 2, so decay_sched(1) is never used -- confirm.
  d = decay_sched(min(t, length(decay_sched)));
  % The fifth output is named gam rather than gamma so that it does not
  % shadow MATLAB's builtin gamma function in the workspace.
  [transmat1, obsmat1, ess_trans, ess_emit, gam, ll] = dhmm_em_online(...
      prior1, transmat1, obsmat1, ess_trans, ess_emit, d, data_win, act_win);
  bel = gam(:, end);   % last column of gam; kept in the workspace for inspection
  LL1(t) = ll/length(data_win);
  %fprintf('t=%d, ll=%f\n', t, ll);
end

LL1(1) = LL1(2); % since initial likelihood is for 1 slice
plot(1:T, LL1, 'rx-');


% compare with offline learning

if 0
rand('state', 2); % same seed as online learner
transmat2 = cell(1,A);
for a=1:A
  transmat2{a} = mk_stochastic(rand(S,S));
end
obsmat2 = mk_stochastic(rand(S,O));
prior2 = prior0;
[LL2, prior2, transmat2, obsmat2] = dhmm_em(data, prior2, transmat2, obsmat2, ...
    'max_iter', 10, 'thresh', 1e-3, 'verbose', 1, 'act', act);

% dhmm_em is given all of data, so normalise by its length (not by T) to get
% a per-slice value comparable with LL1. No semicolon: the result is displayed.
LL2 = LL2 / length(data)

end