function bnet = mk_square_hhmm(discrete_obs, true_params, topright)
% MK_SQUARE_HHMM Make a 3-level HHMM that traces out a square.
% bnet = mk_square_hhmm(discrete_obs, true_params, topright)
%
% The model is described by the following grammar:
%
% Square -> CLK | CCK   % clockwise or counterclockwise
% CLK -> LR UD RL DU    start at top left          (1 2 3 4)
% CCK -> RL UD LR DU    if start at top right      (3 2 1 4)
% CCK -> UD LR DU RL    if start at top left       (2 1 4 3)
%
% LR = left-right, UD = up-down, RL = right-left, DU = down-up.
% LR, UD, RL, DU are sub-HMMs; they are states 1..4 of level 2.
%
% Inputs:
%   discrete_obs - nonzero: discrete observations drawn from the 8-symbol
%                  alphabet 'LlUuRrDd'; zero: 2-D Gaussian observations.
%   true_params  - nonzero: use the hand-specified level-2 parameters and
%                  noise-free observation parameters; zero: random ('rnd')
%                  level-2 parameters and noisy observation parameters
%                  (ready for learning).
%   topright     - (optional, default 1) if nonzero, counter-clockwise
%                  starts at the top right; otherwise at the top left.
%
% Output:
%   bnet - the network returned by mk_hhmm.
%
% Level 1 (CLK vs CCK) is always uniform and level 3 is always
% left-to-right, regardless of true_params.
%
% For discrete observations, each sub-HMM is 2-state left-to-right:
% LR emits L then l, UD emits U then u, etc.
%
% For continuous observations, each sub-HMM state emits a 2-D vector in
% its direction (LR = east, UD = south, RL = west, DU = north), plus noise.
% In principle the sub-HMMs could have a single state here, but the code
% below currently uses 2 bottom-level states for both observation types
% (see Qsizes). Since nothing forces us to stay in LR as long as in RL,
% the sides may have different lengths, so the result is not really a
% square!
%
% This example was inspired by Ivanov and Bobick.

if nargin < 3, topright = 1; end

% States per level: 2 top-level (CLK/CCK), 4 sides, states per side.
% The "if 1" deliberately forces 2 bottom-level states even for continuous
% observations; the 1-state alternative is kept for reference.
if 1 % discrete_obs
  Qsizes = [2 4 2];
else
  Qsizes = [2 4 1];
end

% Depth of the hierarchy. Named so it cannot be clobbered by the 'D'
% (down) observation-symbol index used further below.
nlevels = 3;
Qnodes = 1:nlevels;
startprob = cell(1, nlevels);
transprob = cell(1, nlevels);
termprob = cell(1, nlevels);

% LEVEL 1

startprob{1} = 'unif';
transprob{1} = 'unif';

% LEVEL 2

if true_params
  % startprob{2}(q1, q2): which side each top-level pattern starts on.
  startprob{2} = zeros(2, 4);
  startprob{2}(1, :) = [1 0 0 0];   % CLK starts with LR
  if topright
    startprob{2}(2, :) = [0 0 1 0]; % CCK starts with RL (top right)
  else
    startprob{2}(2, :) = [0 1 0 0]; % CCK starts with UD (top left)
  end

  % transprob{2}(from, q1, to): side-to-side transitions given q1.
  transprob{2} = zeros(4, 2, 4);

  % CLK: 1 -> 2 -> 3 -> 4
  transprob{2}(:,1,:) = [0 1 0 0
                         0 0 1 0
                         0 0 0 1
                         0 0 0 1]; % 4->e
  if topright
    % CCK from top right: 3 -> 2 -> 1 -> 4
    transprob{2}(:,2,:) = [0 0 0 1
                           1 0 0 0
                           0 1 0 0
                           0 0 0 1]; % 4->e
  else
    % CCK from top left: 2 -> 1 -> 4 -> 3
    transprob{2}(:,2,:) = [0 0 0 1
                           1 0 0 0
                           0 0 1 0   % 3->e
                           0 0 1 0];
  end

  %termprob{2} = 'rightstop';
  % termprob{2}(q1, q2): nonzero only on the last side of each pattern.
  % NOTE(review): presumably pfin is the per-step probability of
  % finishing once in that state - confirm against mk_hhmm.
  termprob{2} = zeros(2,4);
  pfin = 0.8;
  termprob{2}(1,:) = [0 0 0 pfin];   % finish in state 4 (DU)
  if topright
    termprob{2}(2,:) = [0 0 0 pfin]; % finish in state 4 (DU)
  else
    termprob{2}(2,:) = [0 0 pfin 0]; % finish in state 3 (RL)
  end
else
  % In the unsupervised case, it is essential that we break symmetry
  % in the initial param estimates.
  %startprob{2} = 'unif';
  %transprob{2} = 'unif';
  %termprob{2} = 'unif';
  startprob{2} = 'rnd';
  transprob{2} = 'rnd';
  termprob{2} = 'rnd';
end

% LEVEL 3

% The "1 |" makes this branch unconditional; the random alternative is
% kept for reference.
if 1 | true_params
  startprob{3} = 'leftstart';
  transprob{3} = 'leftright';
  termprob{3} = 'rightstop';
else
  % If we want to be able to run a base-level model backwards...
  startprob{3} = 'rnd';
  transprob{3} = 'rnd';
  termprob{3} = 'rnd';
end


% OBS LEVEL

if discrete_obs
  % Initialise observations of lowest level primitives in a way which we
  % can interpret: symbol k of chars is observation value k.
  chars = ['L', 'l', 'U', 'u', 'R', 'r', 'D', 'd'];
  sym = @(c) find(chars == c);
  Osize = length(chars);

  if true_params
    p = 1; % makes each state fully observed
  else
    p = 0.9;
  end

  % obsprob(Q2, Q3, O): every symbol starts with weight 1-p, the intended
  % symbol of each (side, state) pair gets weight p, then mk_stochastic
  % normalises.
  obsprob = (1-p)*ones([4 2 Osize]);
  obsprob(1, 1, sym('L')) = p;
  obsprob(1, 2, sym('l')) = p;
  obsprob(2, 1, sym('U')) = p;
  obsprob(2, 2, sym('u')) = p;
  obsprob(3, 1, sym('R')) = p;
  obsprob(3, 2, sym('r')) = p;
  obsprob(4, 1, sym('D')) = p;
  obsprob(4, 2, sym('d')) = p;
  obsprob = mk_stochastic(obsprob);
  Oargs = {'CPT', obsprob};
else
  % Initialise means of lowest level primitives in a way which we can
  % interpret. These means are little vectors in the east, south, west,
  % north directions
  % (left-right=east, up-down=south, right-left=west, down-up=north).
  Osize = 2;
  mu = zeros(2, Qsizes(2), Qsizes(3));
  scale = 3;
  if true_params
    noise = 0;
  else
    noise = 0.5*scale;
  end
  % Column q2 is the unit direction of side q2, scaled.
  directions = scale * [1 0 -1 0
                        0 -1 0 1];
  % NOTE(review): rand is uniform on [0,1), so this "noise" only shifts
  % means towards +x/+y; randn may have been intended. Left unchanged to
  % preserve behavior.
  for q2 = 1:Qsizes(2)
    for q3 = 1:Qsizes(3)
      mu(:, q2, q3) = directions(:, q2) + noise*rand(2,1);
    end
  end
  % Same covariance scale*I for every (Q2, Q3) state.
  Sigma = repmat(reshape(scale*eye(2), [2 2 1 1]), [1 1 Qsizes(2) Qsizes(3)]);
  Oargs = {'mean', mu, 'cov', Sigma, 'cov_type', 'diag'};
end

if discrete_obs
  selfprob = 0.5;
else
  selfprob = 0.95;
  % If less than this, it won't look like a square
  % because it doesn't spend enough time in each state.
  % Unfortunately, the variance on durations (lengths of each side)
  % is very large.
end
bnet = mk_hhmm('Qsizes', Qsizes, 'Osize', Osize, 'discrete_obs', discrete_obs, ...
               'Oargs', Oargs, 'Ops', Qnodes(2:3), 'selfprob', selfprob, ...
               'startprob', startprob, 'transprob', transprob, 'termprob', termprob);