% We consider a switching Kalman filter of the kind studied
% by Zoubin Ghahramani, i.e., where the switch node determines
% which of the hidden chains we get to observe (data association).
% e.g., for n=2 chains
%
% X1 -> X1
% | X2 -> X2
% \ |
%  v
%  Y
%  ^
%  |
%  S
%
% Y is a gmux (multiplexer) node, where S switches in one of the parents.
% We differ from Zoubin by not connecting the S nodes over time (which
% doesn't make sense for data association).
% Indeed, we assume the S nodes are always observed.
%
%
% We will track 2 objects (points) moving in the plane, as in BNT/Kalman/tracking_demo.
% At each time step we observe exactly one of them, chosen uniformly at random
% (see data_assoc below).

nobj = 2;
N = nobj+2;
Xs = 1:nobj;  % hidden state nodes, one per object
S = nobj+1;   % discrete switch node: selects which parent of Y is observed
Y = nobj+2;   % observation node

% Within a slice, every X and the switch S feed into Y;
% across slices, each object's state depends only on its own previous state.
intra = zeros(N,N);
inter = zeros(N,N);
intra([Xs S], Y) = 1;
for i=1:nobj
  inter(Xs(i), Xs(i)) = 1;
end

Xsz = 4; % state space = (x y xdot ydot)
Ysz = 2;
ns = zeros(1,N);
ns(Xs) = Xsz;
ns(Y) = Ysz;
% S chooses one of the nobj objects, so it has nobj states.
% (The original used the undefined variable 'n' here.)
ns(S) = nobj;

bnet = mk_dbn(intra, inter, ns, 'discrete', S, 'observed', [S Y]);

% For each object, we have
% X(t+1) = F X(t) + noise(Q)
% Y(t) = H X(t) + noise(R)
F = [1 0 1 0; 0 1 0 1; 0 0 1 0; 0 0 0 1];
H = [1 0 0 0; 0 1 0 0];
Q = 1e-3*eye(Xsz);
%R = 1e-3*eye(Ysz);
R = eye(Ysz);

% We initialise object 1 moving to the right, and object 2 moving to the left
% (Here, we assume nobj=2)
init_state{1} = [10 10 1 0]';
init_state{2} = [10 -10 -1 0]';

% slice 1
for i=1:nobj
  bnet.CPD{Xs(i)} = gaussian_CPD(bnet, Xs(i), 'mean', init_state{i}, 'cov', 1e-4*eye(Xsz));
end
bnet.CPD{S} = root_CPD(bnet, S); % always observed
% Every object shares the same observation model (H, R), so replicate it once per parent.
bnet.CPD{Y} = gmux_CPD(bnet, Y, 'cov', repmat(R, [1 1 nobj]), 'weights', repmat(H, [1 1 nobj]));

% slice 2
eclass = bnet.equiv_class;
for i=1:nobj
  bnet.CPD{eclass(Xs(i), 2)} = gaussian_CPD(bnet, Xs(i)+N, 'mean', zeros(Xsz,1), 'cov', Q, 'weights', F);
end

% Observe objects at random: clamp S to a uniformly drawn object index at each
% step, then sample the remaining nodes conditioned on that choice.
T = 10;
evidence = cell(N, T);
data_assoc = sample_discrete(normalise(ones(1,nobj)), 1, T);
evidence(S,:) = num2cell(data_assoc);
evidence = sample_dbn(bnet, 'evidence', evidence);

% plot the data
true_state = cell(1,nobj);
for i=1:nobj
  true_state{i} = cell2num(evidence(Xs(i), :)); % true_state{i}(:,t) = [x y xdot ydot]'
end
obs_pos = cell2num(evidence(Y,:));
figure(1)
clf
hold on
styles = {'rx', 'go', 'b+', 'k*'};
nstyles = length(styles);
for i=1:nobj
  % wrap the style index so more than nstyles objects do not index past the end
  plot(true_state{i}(1,:), true_state{i}(2,:), styles{mod(i-1,nstyles)+1});
end
for t=1:T
  text(obs_pos(1,t), obs_pos(2,t), sprintf('%d', t));
end
hold off
relax_axes(0.1)


% Inference: hand the engines only the nodes declared observed (S and Y).
ev = cell(N,T);
ev(bnet.observed,:) = evidence(bnet.observed, :);

engines = {};
engines{end+1} = jtree_dbn_inf_engine(bnet);
%engines{end+1} = scg_unrolled_dbn_inf_engine(bnet, T);
engines{end+1} = pearl_unrolled_dbn_inf_engine(bnet);
E = length(engines);

inferred_state = cell(nobj,E); % inferred_state{i,e}(:,t)
for e=1:E
  engines{e} = enter_evidence(engines{e}, ev);
  for i=1:nobj
    inferred_state{i,e} = zeros(Xsz, T);
    for t=1:T
      m = marginal_nodes(engines{e}, Xs(i), t);
      inferred_state{i,e}(:,t) = m.mu;
    end
  end
end
inferred_state{1,1}
inferred_state{1,2}

% Plot results
figure(2)
clf
hold on
styles = {'rx', 'go', 'b+', 'k*'};
nstyles = length(styles);
c = 1;
for e=1:E
  for i=1:nobj
    plot(inferred_state{i,e}(1,:), inferred_state{i,e}(2,:), styles{mod(c-1,nstyles)+1});
    c = c + 1;
  end
end
for t=1:T
  text(obs_pos(1,t), obs_pos(2,t), sprintf('%d', t));
end
hold off
relax_axes(0.1)