% Example of online EM applied to a simple POMDP with fixed action seq
%
% A small model (S hidden states, O observation symbols, A actions) is
% sampled under a fixed action sequence. Its transition and observation
% matrices are then re-estimated with dhmm_em_online over a sliding window
% of the w most recent time slices, and the per-slice log-likelihood
% reported for each window is plotted against time.
%
% Needs pomdp_sample, mk_stochastic, normalise and dhmm_em_online on the
% path (plus dhmm_em for the disabled offline comparison at the bottom).

clear all

% Create a really easy model to learn
rand('state', 1);
O = 2;  % number of observation symbols (columns of obsmat)
S = 2;  % number of hidden states
A = 2;  % number of actions (one transition matrix per action)
prior0 = [1 0]';
transmat0 = cell(1,A);
transmat0{1} = [0.9 0.1; 0.1 0.9]; % long runs of 1s and 2s
transmat0{2} = [0.1 0.9; 0.9 0.1]; % short runs
obsmat0 = eye(2);

%prior0 = normalise(rand(S,1));
%transmat0 = mk_stochastic(rand(S,S));
%obsmat0 = mk_stochastic(rand(S,O));

% Only the first T slices of data/act are processed by the learner below,
% even though the action sequence itself is longer.
T = 10;
act = [1*ones(1,25) 2*ones(1,25) 1*ones(1,25) 2*ones(1,25)];
data = pomdp_sample(prior0, transmat0, obsmat0, act);
%data = sample_dhmm(prior0, transmat0, obsmat0, T, 1);

% Initial guess of params
rand('state', 2); % different seed!
transmat1 = cell(1,A);
for a=1:A
  transmat1{a} = mk_stochastic(rand(S,S));
end
obsmat1 = mk_stochastic(rand(S,O));
prior1 = prior0; % so it labels states the same way

% Uninformative Dirichlet prior (expected sufficient statistics / pseudo counts)
e = 0.001;
ess_trans = cell(1,A);
for a=1:A
  ess_trans{a} = repmat(e, S, S);
end
ess_emit = repmat(e, S, O);

% Params
w = 2;                       % maximum number of slices kept in the window
decay_sched = [0.1:0.1:0.9]; % decay value per time step; the last entry is reused once t runs past the end

% Initialize
LL1 = zeros(1,T);
t = 1;
y = data(t);
data_win = y;
act_win = [1]; % arbitrary initial value
[prior1, LL1(1)] = normalise(prior1 .* obsmat1(:,y));

% Iterate
for t=2:T
  y = data(t);
  a = act(t);
  if t <= w
    % Window not yet full: just grow it.
    data_win = [data_win y];
    act_win = [act_win a];
  else
    % Window full: drop the oldest slice and append the newest one.
    data_win = [data_win(2:end) y];
    act_win = [act_win(2:end) a];
    % Column 2 of the previous result corresponds to the slice that has
    % just become the first slice of the window, so it is reused as the
    % prior. (gam is the fifth output of dhmm_em_online -- presumably the
    % per-slice state posteriors; confirm against dhmm_em_online.)
    prior1 = gam(:, 2);
  end
  d = decay_sched(min(t, length(decay_sched)));
  % The result is stored in 'gam' rather than 'gamma' so the script does not
  % leave a variable that shadows MATLAB's builtin gamma function.
  [transmat1, obsmat1, ess_trans, ess_emit, gam, ll] = dhmm_em_online(...
      prior1, transmat1, obsmat1, ess_trans, ess_emit, d, data_win, act_win);
  bel = gam(:, end); % last column, i.e. the most recent slice in the window
  LL1(t) = ll/length(data_win); % normalise by window length
  %fprintf('t=%d, ll=%f\n', t, ll);
end

LL1(1) = LL1(2); % since initial likelihood is for 1 slice
plot(1:T, LL1, 'rx-');


% compare with offline learning

if 0
  rand('state', 2); % same seed as online learner
  transmat2 = cell(1,A);
  for a=1:A
    transmat2{a} = mk_stochastic(rand(S,S));
  end
  obsmat2 = mk_stochastic(rand(S,O));
  prior2 = prior0;
  % Fit on the same T slices the online learner processed, so that dividing
  % by T below really yields a per-slice value comparable with LL1.
  [LL2, prior2, transmat2, obsmat2] = dhmm_em(data(1:T), prior2, transmat2, obsmat2, ...
      'max_iter', 10, 'thresh', 1e-3, 'verbose', 1, 'act', act(1:T));

  LL2 = LL2 / T  % no semicolon: display the result

end