function [bnet, Qnodes, Fnodes, Onode] = mk_hhmm(varargin)
% MK_HHMM Make a Hierarchical HMM
% function [bnet, Qnodes, Fnodes, Onode] = mk_hhmm(...)
%
% e.g. 3-layer hierarchical HMM where level 1 only connects to level 2
% and the parents of the observed node are levels 2 and 3.
% (This DBN is the same as Fig 10 in my tech report.)
% (The column alignment of the sketch below is approximate.)
%
%     Q1 ----------> Q1
%     |  \       ^   |
%     |   v     /    |
%     |   F2 --/     |
%     |  ^  ^  \     |
%     | /   |   \    |
%     |/    |    ||
%     v     |    vv
%     Q2----| ------> Q2
%    /|  \  |     ^  |
%   / |   v |    /   |
%  |  |   F3 ---/    |
%  |  |  ^   \       |
%  |  v /     v      v
%  |  Q3 ----------> Q3
%  |  |
%   \ |
%    vv
%    O
%
%
% Optional arguments in name/value format [default value in brackets]
%
% Qsizes      - sizes at each level [ none - must be specified ]
% allQ        - 1 means level i connects to all Q levels below, 0 means just to i+1 [0]
% transprob   - transprob{d}(i,k,j) = P(Q(d,t)=j|Q(d,t-1)=i,Q(1:d-1,t)=k)
%               ['unif' for d=1, 'leftright' for d>1]
% startprob   - startprob{d}(k,j) = P(Q(d,t)=j|Q(1:d-1,t)=k)
%               ['unif' for d=1, 'leftstart' for d>1]
% termprob    - termprob{d}(k,j) = P(F(d,t)=2|Q(1:d-1,t)=k,Q(d,t)=j) for d>1 ['rightstop']
% selfprob    - prob of a self transition (termprob default = 1-selfprob) [0.8]
% Osize       - size of O node [ none - must be specified ]
% discrete_obs - 1 means O is tabular_CPD, 0 means gaussian_CPD [0]
% Oargs       - cell array of args to pass to the O CPD  [ {} ]
% Ops         - Q parents of O [Qnodes(end)]
% F1          - 1 means level 1 can finish (restart), else there is no F1->Q1 arc [0]
% clamp1      - 1 means we clamp the params of the Q nodes in slice 1 (Qt1params) [1]
%   Note: the Qt1params are startprob, which should be shared with other slices.
%   However, in the current implementation, the Qt1params will only be estimated
%   from the initial state of each sequence.
%
% For d=1, startprob{1}(1,j) is only used in the first slice and
% termprob{1} is ignored, since we assume the top level never resets.
% Also, transprob{1}(i,j) can be used instead of transprob{1}(i,1,j).
%
% startprob{d}, transprob{d} and termprob{d} may each be a numeric table or
% one of the following strings (an unrecognized string raises an error):
%   leftstart (startprob) - the model always starts in state 1.
%   leftright (transprob) - built by mk_leftright_transmat(Qsize(d), selfprob).
%   rightstop (termprob)  - the model can only finish in its last state (Qsize(d)).
%   unif - each state is equally likely to reach any other
%   rnd  - the transition/starting probs are random (drawn from rand)
%
% Q1:QD in slice 1 are of type tabular_CPD
% Q1:QD in slice 2 are of type hhmmQ_CPD.
% F(2:D-1) is of type hhmmF_CPD, FD is of type tabular_CPD.

args = varargin;
nargs = length(args);

% ---- first pass over the options: node sizes and topology ----
Qsizes = [];
Osize = [];
allQ = 0;
Ops = [];
F1 = 0;
for i=1:2:nargs
  switch args{i},
   case 'Qsizes', Qsizes = args{i+1};
   case 'Osize',  Osize = args{i+1};
   case 'allQ',   allQ = args{i+1};
   case 'Ops',    Ops = args{i+1};
   case 'F1',     F1 = args{i+1};
  end
end
if isempty(Qsizes), error('must specify Qsizes'); end
% Osize defaults to [], and "if []==0" is false, so the emptiness test must be
% explicit; otherwise "ns(Onode) = []" below would silently delete an element.
if isempty(Osize), error('must specify Osize'); end
if Osize==0, error('must specify Osize'); end
D = length(Qsizes);
Qnodes = 1:D; % provisional numbering, only used to default Ops; replaced below

if isempty(Ops), Ops = Qnodes(end); end


[intra, inter, Qnodes, Fnodes, Onode] = mk_hhmm_topo(D, allQ, Ops, F1);
ss = length(intra); % number of nodes per slice
names = {};

% Fnodes_ndx(d) is the F node of level d; pad when level 1 has no F node.
if F1
  Fnodes_ndx = Fnodes;
else
  Fnodes_ndx = [-1 Fnodes]; % Fnodes(1) is a dummy index
end

% ---- second pass over the options: parameters (defaults first) ----
discrete_obs = 0;
Oargs = {};
startprob = cell(1,D);
startprob{1} = 'unif';
for d=2:D
  startprob{d} = 'leftstart';
end
transprob = cell(1,D);
transprob{1} = 'unif';
for d=2:D
  transprob{d} = 'leftright';
end
termprob = cell(1,D); % termprob{1} stays [] unless the caller supplies it
for d=2:D
  termprob{d} = 'rightstop';
end
selfprob = 0.8;
clamp1 = 1;

for i=1:2:nargs
  switch args{i},
   case 'discrete_obs', discrete_obs = args{i+1};
   case 'Oargs',        Oargs = args{i+1};
   case 'startprob',    startprob = args{i+1};
   case 'transprob',    transprob = args{i+1};
   case 'termprob',     termprob = args{i+1};
   case 'selfprob',     selfprob = args{i+1};
   case 'clamp1',       clamp1 = args{i+1};
  end
end

ns = zeros(1,ss);
ns(Qnodes) = Qsizes;
ns(Onode) = Osize;
ns(Fnodes) = 2;

dnodes = [Qnodes Fnodes];
if discrete_obs
  dnodes = [dnodes Onode];
end
onodes = [Onode];

bnet = mk_dbn(intra, inter, ns, 'observed', onodes, 'discrete', dnodes, 'names', names);
eclass = bnet.equiv_class;

% ---- expand string specifications into numeric tables, level by level ----
for d=1:D
  if d==1
    Qps = [];
  elseif allQ
    Qps = Qnodes(1:d-1);
  else
    Qps = Qnodes(d-1);
  end
  Qpsz = prod(ns(Qps)); % joint size of the Q parents (1 when there are none)
  Qsz = ns(Qnodes(d));
  if ischar(startprob{d})
    switch startprob{d}
     case 'unif', startprob{d} = mk_stochastic(ones(Qpsz, Qsz));
     case 'rnd',  startprob{d} = mk_stochastic(rand(Qpsz, Qsz));
     case 'leftstart', startprob{d} = zeros(Qpsz, Qsz); startprob{d}(:,1) = 1;
     otherwise, error(['unrecognized startprob ' startprob{d}]);
    end
  end
  if ischar(transprob{d})
    switch transprob{d}
     case 'unif', transprob{d} = mk_stochastic(ones(Qsz, Qpsz, Qsz));
     case 'rnd',  transprob{d} = mk_stochastic(rand(Qsz, Qpsz, Qsz));
     case 'leftright',
      LR = mk_leftright_transmat(Qsz, selfprob);
      temp = repmat(reshape(LR, [1 Qsz Qsz]), [Qpsz 1 1]); % transprob(k,i,j)
      transprob{d} = permute(temp, [2 1 3]); % now transprob(i,k,j)
     otherwise, error(['unrecognized transprob ' transprob{d}]);
    end
  end
  if ischar(termprob{d})
    switch termprob{d}
     case 'unif', termprob{d} = mk_stochastic(ones(Qpsz, Qsz, 2));
     case 'rnd',  termprob{d} = mk_stochastic(rand(Qpsz, Qsz, 2));
     case 'rightstop',
      % termprob(k,i,t): might terminate if i=Qsz; will not terminate if i<Qsz.
      % NOTE(review): the body of this case was missing from the file and has
      % been rebuilt from the header ("can only finish in its last state",
      % "termprob default = 1-selfprob") - confirm against the upstream copy.
      stopprob = 1-selfprob;
      termprob{d} = zeros(Qpsz, Qsz, 2);
      termprob{d}(:,Qsz,2) = stopprob;       % last state: finish (F=2)
      termprob{d}(:,Qsz,1) = 1-stopprob;     % last state: continue (F=1)
      termprob{d}(:,1:(Qsz-1),1) = 1;        % other states never finish
     otherwise, error(['unrecognized termprob ' termprob{d}]);
    end
  elseif d>1 % passed in termprob{d}(k,j)
    % Caller gave P(F=2|k,j); expand it to a table over both values of F.
    temp = termprob{d};
    termprob{d} = zeros(Qpsz, Qsz, 2);
    termprob{d}(:,:,2) = temp;
    termprob{d}(:,:,1) = ones(Qpsz,Qsz) - temp;
  end
end


% SLICE 1

for d=1:D
  % NOTE(review): the header describes clamp1 as *clamping* these parameters,
  % yet it is forwarded under the name 'adjustable' - confirm the intended
  % polarity against tabular_CPD.
  bnet.CPD{eclass(Qnodes(d),1)} = tabular_CPD(bnet, Qnodes(d), 'CPT', startprob{d}, 'adjustable', clamp1);
end

if F1
  d = 1;
  bnet.CPD{eclass(Fnodes_ndx(d),1)} = hhmmF_CPD(bnet, Fnodes_ndx(d), Qnodes(d), Fnodes_ndx(d+1), ...
						'termprob', termprob{d});
end
for d=2:D-1
  if allQ
    Qps = Qnodes(1:d-1);
  else
    Qps = Qnodes(d-1);
  end
  bnet.CPD{eclass(Fnodes_ndx(d),1)} = hhmmF_CPD(bnet, Fnodes_ndx(d), Qnodes(d), Fnodes_ndx(d+1), ...
						'Qps', Qps, 'termprob', termprob{d});
end
% The bottom-level F node has no F node below it, so it is a plain table.
bnet.CPD{eclass(Fnodes_ndx(D),1)} = tabular_CPD(bnet, Fnodes_ndx(D), 'CPT', termprob{D});

if discrete_obs
  bnet.CPD{eclass(Onode,1)} = tabular_CPD(bnet, Onode, Oargs{:});
else
  bnet.CPD{eclass(Onode,1)} = gaussian_CPD(bnet, Onode, Oargs{:});
end

% SLICE 2

%for d=1:D
%  bnet.CPD{eclass(Qnodes(d),2)} = hhmmQ_CPD(bnet, Qnodes(d)+ss, Qnodes, d, D, ...
%					    'startprob', startprob{d}, 'transprob', transprob{d}, ...
%					    'allQ', allQ);
%end

% Top level: no Q parents; 'Fself' is passed only when level 1 has an F node.
d = 1;
if F1
  bnet.CPD{eclass(Qnodes(d),2)} = hhmmQ_CPD(bnet, Qnodes(d)+ss, 'Fself', Fnodes_ndx(d), ...
					    'Fbelow', Fnodes_ndx(d+1), ...
					    'startprob', startprob{d}, 'transprob', transprob{d});
else
  bnet.CPD{eclass(Qnodes(d),2)} = hhmmQ_CPD(bnet, Qnodes(d)+ss, ...
					    'Fbelow', Fnodes_ndx(d+1), ...
					    'startprob', startprob{d}, 'transprob', transprob{d});
end
% Intermediate levels: F node at this level and at the level below.
for d=2:D-1
  if allQ
    Qps = Qnodes(1:d-1);
  else
    Qps = Qnodes(d-1);
  end
  Qps = Qps + ss; % since all in slice 2
  bnet.CPD{eclass(Qnodes(d),2)} = hhmmQ_CPD(bnet, Qnodes(d)+ss, 'Fself', Fnodes_ndx(d), ...
					    'Fbelow', Fnodes_ndx(d+1), 'Qps', Qps, ...
					    'startprob', startprob{d}, 'transprob', transprob{d});
end
% Bottom level: there is no level below, hence no 'Fbelow'.
d = D;
if allQ
  Qps = Qnodes(1:d-1);
else
  Qps = Qnodes(d-1);
end
Qps = Qps + ss; % since all in slice 2
bnet.CPD{eclass(Qnodes(d),2)} = hhmmQ_CPD(bnet, Qnodes(d)+ss, 'Fself', Fnodes_ndx(d), ...
					  'Qps', Qps, ...
					  'startprob', startprob{d}, 'transprob', transprob{d});