annotate toolboxes/FullBNT-1.0.7/HMM/dhmm_em_online_demo.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
rev   line source
% Example of online EM applied to a simple POMDP with fixed action seq
%
% A known "true" model (prior0, transmat0{a}, obsmat0) is sampled under a
% fixed action sequence to produce an observation sequence. A randomly
% initialised model is then refined one time step at a time by
% dhmm_em_online, which sees only a sliding window of the last w
% observations/actions plus decayed expected sufficient statistics.
% The per-step value ll/length(window) returned by the learner is plotted.
%
% Needs these BNT functions on the path: pomdp_sample, mk_stochastic,
% normalise, dhmm_em_online (and dhmm_em for the disabled offline check).

clear all

% ---- True model: a really easy model to learn ---------------------------
rand('state', 1);
O = 2;   % number of observation symbols
S = 2;   % number of hidden states
A = 2;   % number of actions
prior0 = [1 0]';
transmat0 = cell(1,A);             % transmat0{a}: S x S transition matrix used under action a
transmat0{1} = [0.9 0.1; 0.1 0.9]; % long runs of 1s and 2s
transmat0{2} = [0.1 0.9; 0.9 0.1]; % short runs
obsmat0 = eye(2);                  % S x O; state i always emits symbol i

%prior0 = normalise(rand(S,1));
%transmat0 = mk_stochastic(rand(S,S));
%obsmat0 = mk_stochastic(rand(S,O));

T = 10;  % number of time steps processed by the online learner
act = [1*ones(1,25) 2*ones(1,25) 1*ones(1,25) 2*ones(1,25)]; % 100 actions: blocks of 25
% NOTE(review): presumably pomdp_sample returns one observation per action,
% i.e. length(data) == length(act); the check below guards T either way.
data = pomdp_sample(prior0, transmat0, obsmat0, act);
%data = sample_dhmm(prior0, transmat0, obsmat0, T, 1);

% ---- Initial guess of params --------------------------------------------
rand('state', 2); % different seed!
transmat1 = cell(1,A);
for a=1:A
  transmat1{a} = mk_stochastic(rand(S,S));
end
obsmat1 = mk_stochastic(rand(S,O));
prior1 = prior0; % so it labels states the same way

% Uninformative Dirichlet prior (expected sufficient statistics / pseudo counts)
e = 0.001;
ess_trans = cell(1,A);
for a=1:A
  ess_trans{a} = repmat(e, S, S);
end
ess_emit = repmat(e, S, O);

% ---- Params ---------------------------------------------------------------
% w: sliding-window length. Must be >= 2, because once the window slides the
% new prior is read from column 2 of the previous window's posterior.
w = 2;
% Decay passed to dhmm_em_online at step t is decay_sched(min(t, end)); the
% loop starts at t = 2, so decay_sched(1) is never used and the decay is held
% at the last entry once t exceeds length(decay_sched).
decay_sched = 0.1:0.1:0.9;

if w < 2
  error('window length w must be at least 2');
end
if T < 2 || T > length(data)
  error('T must satisfy 2 <= T <= length(data)');
end

% ---- Initialize -----------------------------------------------------------
LL1 = zeros(1,T);
t = 1;
y = data(t);
data_win = y;
act_win = 1; % arbitrary initial value
% Condition the prior on the first observation. The second output of
% normalise lands in LL1(1), which is overwritten after the loop.
[prior1, LL1(1)] = normalise(prior1 .* obsmat1(:,y));

% ---- Iterate --------------------------------------------------------------
for t=2:T
  y = data(t);
  a = act(t);
  if t <= w
    % window still filling up: just append
    data_win = [data_win y];
    act_win = [act_win a];
  else
    % slide the window: drop the oldest slice, append the newest
    data_win = [data_win(2:end) y];
    act_win = [act_win(2:end) a];
    % The old 2nd slice is now the 1st, so its posterior from the previous
    % call becomes the prior for the new window. (Named gamma_win, not gamma,
    % so the built-in gamma function is not shadowed.)
    prior1 = gamma_win(:, 2);
  end
  d = decay_sched(min(t, length(decay_sched)));
  [transmat1, obsmat1, ess_trans, ess_emit, gamma_win, ll] = dhmm_em_online(...
      prior1, transmat1, obsmat1, ess_trans, ess_emit, d, data_win, act_win);
  % presumably the posterior over the newest slice (columns of gamma_win
  % index window slices); kept for inspection, not used below
  bel = gamma_win(:, end);
  LL1(t) = ll/length(data_win); % normalise by window length
  %fprintf('t=%d, ll=%f\n', t, ll);
end

LL1(1) = LL1(2); % since initial likelihood is for 1 slice
plot(1:T, LL1, 'rx-');


% ---- compare with offline learning (disabled; change 0 to 1 to run) -------
% NOTE: dhmm_em below is given the whole sequence (length(data) slices),
% whereas the online learner above only processed the first T slices.
if 0
  rand('state', 2); % same seed as online learner
  transmat2 = cell(1,A);
  for a=1:A
    transmat2{a} = mk_stochastic(rand(S,S));
  end
  obsmat2 = mk_stochastic(rand(S,O));
  prior2 = prior0;
  [LL2, prior2, transmat2, obsmat2] = dhmm_em(data, prior2, transmat2, obsmat2, ...
      'max_iter', 10, 'thresh', 1e-3, 'verbose', 1, 'act', act);

  % Normalise by the number of slices actually fed to dhmm_em (not T), so
  % the values are per slice like LL1. Deliberately unterminated: displays LL2.
  LL2 = LL2 / length(data)
end