Mercurial > hg > camir-aes2014
comparison toolboxes/FullBNT-1.0.7/bnt/examples/dynamic/skf_data_assoc_gmux.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:e9a9cd732c1e |
---|---|
% We consider a switching Kalman filter of the kind studied
% by Zoubin Ghahramani, i.e., where the switch node determines
% which of the hidden chains we get to observe (data association).
% e.g., for n=2 chains
%
%   X1 -> X1
%   |
%   |  X2 -> X2
%   |  |
%    \ /
%     v
%     Y
%     ^
%     |
%     S
%
% Y is a gmux (multiplexer) node, where S switches in one of the parents.
% We differ from Zoubin by not connecting the S nodes over time (which
% doesn't make sense for data association).
% Indeed, we assume the S nodes are always observed.
%
%
% We will track 2 objects (points) moving in the plane, as in BNT/Kalman/tracking_demo.
% We will alternate between observing them.

nobj = 2;     % number of tracked objects (hidden chains)
N = nobj+2;   % nodes per slice: one X per object, plus S and Y
Xs = 1:nobj;  % node ids of the hidden object states
S = nobj+1;   % node id of the discrete switch (data-association) node
Y = nobj+2;   % node id of the gmux observation node

% Within a slice, every X and S are parents of Y; across slices, each X
% depends only on its own previous value.
intra = zeros(N,N);
inter = zeros(N,N);
intra([Xs S], Y) = 1;
for i=1:nobj
  inter(Xs(i), Xs(i)) = 1;
end

Xsz = 4; % state space = (x y xdot ydot)
Ysz = 2; % observation = (x y)
ns = zeros(1,N);
ns(Xs) = Xsz;
ns(Y) = Ysz;
% S chooses which of the nobj objects Y observes, so it has nobj states
% (the data-association sequence below is sampled from 1..nobj).
% The variable n is not defined anywhere in this script; nobj is intended.
ns(S) = nobj;

bnet = mk_dbn(intra, inter, ns, 'discrete', S, 'observed', [S Y]);
45 | |
% Every object follows the same linear-Gaussian model:
%   X(t+1) = F X(t) + noise(Q)
%   Y(t)   = H X(t) + noise(R)
% F advances the position by the velocity; H reads out the (x, y) position.
F = [eye(2) eye(2)
     zeros(2) eye(2)];
H = eye(Ysz, Xsz);
Q = 1e-3*eye(Xsz);
%R = 1e-3*eye(Ysz);   % lower-noise alternative
R = eye(Ysz);

% Object 1 starts out moving to the right, object 2 moving to the left.
% Only two initial states are listed, so this assumes nobj == 2.
init_state = {[10 10 1 0]', [10 -10 -1 0]'};

% Slice 1: tight Gaussian prior around each object's initial state.
for i = 1:nobj
  bnet.CPD{Xs(i)} = gaussian_CPD(bnet, Xs(i), 'mean', init_state{i}, 'cov', 1e-4*eye(Xsz));
end
bnet.CPD{S} = root_CPD(bnet, S); % always observed
% Y passes the S-selected parent through H with noise R; all objects share H and R.
bnet.CPD{Y} = gmux_CPD(bnet, Y, 'cov', repmat(R, [1 1 nobj]), 'weights', repmat(H, [1 1 nobj]));

% Slice 2: transition CPD for node Xs(i)+N, stored at the equivalence-class
% index that bnet.equiv_class assigns to it.
eclass = bnet.equiv_class;
for i = 1:nobj
  bnet.CPD{eclass(Xs(i), 2)} = gaussian_CPD(bnet, Xs(i)+N, 'mean', zeros(Xsz,1), 'cov', Q, 'weights', F);
end
70 | |
% Observe objects at random: draw S(t) from normalise(ones(1,nobj)) over
% 1..nobj, clamp it as evidence, then sample the rest of the DBN given it.
T = 10;
evidence = cell(N, T);
data_assoc = sample_discrete(normalise(ones(1,nobj)), 1, T);
evidence(S,:) = num2cell(data_assoc);
evidence = sample_dbn(bnet, 'evidence', evidence);

% Plot the data: the true trajectory of each object, plus every observation
% labelled with its time step.
true_state = cell(1,nobj);
for i=1:nobj
  true_state{i} = cell2num(evidence(Xs(i), :)); % true_state{i}(:,t) = [x y xdot ydot]'
end
obs_pos = cell2num(evidence(Y,:)); % obs_pos(:,t) = observed (x, y) at time t
figure(1)
clf
hold on
styles = {'rx', 'go', 'b+', 'k*'};
nstyles = length(styles);
for i=1:nobj
  % Cycle through the styles so that nobj > nstyles does not index out of range.
  plot(true_state{i}(1,:), true_state{i}(2,:), styles{mod(i-1,nstyles)+1});
end
for t=1:T
  text(obs_pos(1,t), obs_pos(2,t), sprintf('%d', t));
end
hold off
relax_axes(0.1)
96 | |
97 | |
% Inference: keep only the observed nodes (S and Y) as evidence, run each
% engine, and collect the marginal mean of every object's state at every step.
ev = cell(N,T);
ev(bnet.observed,:) = evidence(bnet.observed, :);

engines = {};
engines{end+1} = jtree_dbn_inf_engine(bnet);
%engines{end+1} = scg_unrolled_dbn_inf_engine(bnet, T);
engines{end+1} = pearl_unrolled_dbn_inf_engine(bnet);
E = length(engines);

inferred_state = cell(nobj,E); % inferred_state{i,e}(:,t) = m.mu for object i, engine e, time t
for e=1:E
  engines{e} = enter_evidence(engines{e}, ev);
  for i=1:nobj
    % Size the buffer from the state dimension instead of a hard-coded 4.
    inferred_state{i,e} = zeros(Xsz, T);
    for t=1:T
      m = marginal_nodes(engines{e}, Xs(i), t);
      inferred_state{i,e}(:,t) = m.mu;
    end
  end
end
% Display object 1's estimates from the first two engines.
inferred_state{1,1}
inferred_state{1,2}
121 | |
% Plot results: one estimated trajectory per (engine, object) pair, reusing
% the marker styles cyclically, with the time-stamped observations overlaid.
figure(2)
clf
hold on
styles = {'rx', 'go', 'b+', 'k*'};
nstyles = length(styles);
c = 1;
for e = 1:E
  for i = 1:nobj
    est = inferred_state{i,e};            % est(:,t) = estimated [x y xdot ydot]'
    style_idx = 1 + mod(c - 1, nstyles);  % wrap around the style list
    plot(est(1,:), est(2,:), styles{style_idx});
    c = c + 1;
  end
end
for t = 1:T
  label = sprintf('%d', t);
  text(obs_pos(1,t), obs_pos(2,t), label);
end
hold off
relax_axes(0.1)