Mercurial > hg > camir-aes2014
comparison toolboxes/FullBNT-1.0.7/bnt/examples/dynamic/HHMM/Square/Old/sample_square_hhmm.m @ 0:e9a9cd732c1e tip
first hg version after svn
| author | wolffd |
|---|---|
| date | Tue, 10 Feb 2015 15:05:51 +0000 |
| parents | |
| children | |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:e9a9cd732c1e |
---|---|
1 | |
% Put both legacy generators (uniform and normal) into the same known state
% so that every random draw made later in this script is reproducible.
seed = 0;
randn('state', seed);
rand('state', seed);
5 | |
% Model configuration.
discrete_obs = 1;  % 1: discrete symbol observations; 0: 2-D Gaussian observations
topright = 0;      % selects the alternative level-2 path used when Q1 = 2

% Number of states at levels 1, 2 and 3 of the hierarchy.
Qsizes = [2 4 2];
D = numel(Qsizes);   % depth of the hierarchy (3)
Qnodes = 1:D;

% One start / transition / termination spec per level, filled in below.
[startprob, transprob, termprob] = deal(cell(1, D));
15 | |
% LEVEL 1
% The top level's start and transition specs are given by the named topology
% 'ergodic', which mk_hhmm expands itself (presumably into unconstrained
% distributions over the Qsizes(1) states -- confirm in mk_hhmm).
transprob{1} = 'ergodic';
startprob{1} = 'ergodic';
20 | |
% LEVEL 2
% Four level-2 states. Every distribution below is conditioned on the
% level-1 state q1 in {1, 2}; `topright` swaps in an alternative path for q1 = 2.

% startprob{2}(q1, :) is the level-2 start distribution under parent q1.
startprob{2} = [1 0 0 0
                0 1 0 0];
if topright
  startprob{2}(2, :) = [0 0 1 0];
end

% transprob{2}(i, q1, j): row i = current Q2 state, j = next Q2 state
% (row/column reading taken from the "4->e" style annotations -- confirm
% against mk_hhmm). Each q1 page is written as a 4x4 (i, j) matrix, stacked
% along dimension 3, then permuted into (i, q1, j) order.
if topright
  transprob{2} = permute(cat(3, ...
      [0 1 0 0; 0 0 1 0; 0 0 0 1; 0 0 0 1], ...   % q1 = 1: 1->2->3->4, 4->e
      [0 0 0 1; 1 0 0 0; 0 1 0 0; 0 0 0 1]), ...  % q1 = 2: 3->2->1->4, 4->e
      [1 3 2]);
else
  transprob{2} = permute(cat(3, ...
      [0 1 0 0; 0 0 1 0; 0 0 0 1; 0 0 0 1], ...   % q1 = 1: 1->2->3->4, 4->e
      [0 0 0 1; 1 0 0 0; 0 0 1 0; 0 0 1 0]), ...  % q1 = 2: 2->1->4->3, 3->e
      [1 3 2]);
end

% termprob{2}(q1, q2, f): page f = 2 is the probability of finishing level 2
% from state q2 under parent q1, and page f = 1 is its complement. Exactly one
% state per q1 may finish, with probability pfin.
%termprob{2} = 'rightstop';
pfin = 0.8;
if topright
  termprob{2} = cat(3, 1 - pfin*[0 0 0 1; 0 0 0 1], ...
                       pfin*[0 0 0 1; 0 0 0 1]);     % both q1 finish in state 4 (DU)
else
  termprob{2} = cat(3, 1 - pfin*[0 0 0 1; 0 0 1 0], ...
                       pfin*[0 0 0 1; 0 0 1 0]);     % q1 = 2 finishes in state 3 (RL)
end
61 | |
% LEVEL 3
% The bottom level uses mk_hhmm's named topologies. Going by the names, it
% starts at the left, moves left-to-right and stops from the right (confirm
% the exact semantics in mk_hhmm).
termprob{3} = 'rightstop';
transprob{3} = 'leftright';
startprob{3} = 'leftstart';
67 | |
68 | |
% OBS LEVEL
% The observation distribution is indexed by (Q2, Q3).
if discrete_obs
  % Discrete case: Q2 selects a letter pair and Q3 selects upper case (1) or
  % lower case (2). That symbol is emitted with probability 1.
  chars = ['L', 'l', 'U', 'u', 'R', 'r', 'D', 'd'];
  Osize = length(chars);

  % symbols(q2, q3) is the symbol emitted in state (q2, q3). Indices are looked
  % up in chars instead of being kept in single-letter variables: a variable
  % named D would overwrite the hierarchy depth D defined earlier.
  symbols = ['L' 'l'
             'U' 'u'
             'R' 'r'
             'D' 'd'];
  obsprob = zeros([4 2 Osize]);   % Q2 Q3 O
  for q2 = 1:size(symbols, 1)
    for q3 = 1:size(symbols, 2)
      obsprob(q2, q3, chars == symbols(q2, q3)) = 1.0;
    end
  end

  Oargs = {'CPT', obsprob};
else
  % Continuous case: a 2-D Gaussian with a small fixed covariance. Its mean
  % depends only on Q2 (both Q3 states share it), is scaled by `scale`, and
  % gets uniform jitter of size `noise`.
  Osize = 2;
  noise = 0;
  scale = 10;

  % Column q2 is the unscaled mean for Q2 = q2.
  dirs = [1  0 -1  0
          0 -1  0  1];
  mu = zeros(2, 4, 2);
  % q2 outer / q3 inner reproduces the original order of rand(2,1) calls. The
  % RNG stream consumed here, even when noise = 0, is therefore unchanged, and
  % so are the later samples.
  for q2 = 1:4
    for q3 = 1:2
      mu(:, q2, q3) = scale*dirs(:, q2) + noise*rand(2,1);
    end
  end

  Sigma = repmat(reshape(0.01*eye(2), [2 2 1 1]), [1 1 4 2]);
  Oargs = {'mean', mu, 'cov', Sigma};
end
111 | |
% Build the 3-level HHMM as a DBN. 'Ops' lists the Q nodes the observation
% depends on (presumably its parents -- this matches obsprob/mu being indexed
% by (Q2, Q3); confirm in mk_hhmm).
% Osize is passed as-is: the original had a stray transpose (Osize').
bnet = mk_hhmm('Qsizes', Qsizes, 'Osize', Osize, 'discrete_obs', discrete_obs, ...
               'Oargs', Oargs, 'Ops', Qnodes(2:3), ...
               'startprob', startprob, 'transprob', transprob, 'termprob', termprob);
115 | |
% Length passed to sample_dbn. The loop below reads the actual length back
% from the evidence, so this is presumably an upper bound -- confirm in
% sample_dbn.
if discrete_obs
  Tmax = 30;     % discrete sequences are printed as text
else
  Tmax = 200;    % continuous sequences are plotted as a trajectory
end

% Continuous evidence is requested as a cell array (the plotting code below
% unpacks it with cell2num).
usecell = ~discrete_obs;

% Node indices within one DBN slice (presumably mk_hhmm's layout for this
% model -- confirm against mk_hhmm).
Q1 = 1;
Q2 = 2;
Q3 = 3;
F3 = 4;
F2 = 5;
Onode = 6;
Qnodes = [Q1 Q2 Q3];
Fnodes = [F2 F3];
124 | |
% Draw three sequences from the model and show each one.
for seqi = 1:3
  % 'stop_sampling_F2' is passed to sample_dbn, presumably so sampling can end
  % early on the level-2 finish node (confirm in sample_dbn). The sampled
  % length is therefore read back and echoed.
  evidence = sample_dbn(bnet, Tmax, usecell, 'stop_sampling_F2');
  T = size(evidence, 2)

  if discrete_obs
    % Discrete symbols: print the hierarchical parse as text.
    pretty_print_hhmm_parse(evidence, Qnodes, Fnodes, Onode, chars);
    continue;
  end

  % Continuous observations: treat each one as a 2-D step, accumulate the
  % steps into a path starting at the origin, and plot the path point by
  % point. The colour changes wherever F3 flags a boundary.
  pos = zeros(2, T+1);
  delta = cell2num(evidence(Onode, :));
  clf
  hold on
  cols = {'r', 'g', 'k', 'b'};
  % Subtracting 1 turns F3 into a 0/1 boundary flag (assumes F3 is encoded as
  % 1/2, as termprob{2}'s size-2 last dimension suggests -- confirm).
  boundary = cell2num(evidence(F3, :)) - 1;
  coli = 1;
  for t = 2:T+1
    pos(:, t) = pos(:, t-1) + delta(:, t-1);
    plot(pos(1, t), pos(2, t), [cols{coli} '.']);
    if boundary(t-1)
      coli = mod(coli, numel(cols)) + 1;   % next colour, wrapping around
    end
  end
  %plot(pos(1,:), pos(2,:), '.')
  %pretty_print_hhmm_parse(evidence, Qnodes, Fnodes, Onode, []);
  pause
end
151 | |
% Debug aid: look up the CPD used by node Q2 via bnet.equiv_class (indexed
% here as (node, 2), presumably node x slice -- confirm), and expose its
% fields as the plain struct S for inspection at the prompt.
eclass = bnet.equiv_class;
S = struct(bnet.CPD{eclass(Q2, 2)});
154 | |
155 | |
156 | |
157 | |
158 | |
159 | |
160 |