diff toolboxes/FullBNT-1.0.7/bnt/examples/dynamic/HHMM/mk_hhmm.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/toolboxes/FullBNT-1.0.7/bnt/examples/dynamic/HHMM/mk_hhmm.m	Tue Feb 10 15:05:51 2015 +0000
@@ -0,0 +1,258 @@
+function [bnet, Qnodes, Fnodes, Onode] = mk_hhmm(varargin)
+% MK_HHMM Make a Hierarchical HMM
+% function [bnet, Qnodes, Fnodes, Onode] = mk_hhmm(...)
+%
+% e.g. 3-layer hierarchical HMM where level 1 only connects to level 2
+% and the parents of the observed node are levels 2 and 3.
+% (This DBN is the same as Fig 10 in my tech report.)
+%
+%   Q1 ---------->   Q1
+%   |  \            ^ |
+%   |   v          /  |
+%   |    F2 ------/   |
+%   |   ^ ^        \  |
+%   |  /  |         \ |
+%   | /   |          ||
+%   v     |          vv
+%   Q2----| --------> Q2
+%  /| \   |          ^|
+% / |  v  |         / |
+% | |   F3 --------/  |
+% | |   ^          \  |
+% | v  /            v v
+% | Q3 ----------->  Q3
+% |  |    
+% \  | 
+%  v v    
+%   O
+%
+%
+% Optional arguments in name/value format [default value in brackets]
+%
+% Qsizes      - sizes at each level [ none ]
+% allQ       - 1 means level i connects to all Q levels below, 0 means just to i+1 [0]
+% transprob  - transprob{d}(i,k,j) = P(Q(d,t)=j|Q(d,t-1)=i,Q(1:d-1,t)=k)  ['leftright']
+% startprob  - startprob{d}(k,j) = P(Q(d,t)=j|Q(1:d-1,t)=k)  ['leftstart']
+% termprob   - termprob{d}(k,j) = P(F(d,t)=2|Q(1:d-1,t)=k,Q(d,t)=j) for d>1 ['rightstop']
+% selfprop   - prob of a self transition (termprob default = 1-selfprop) [0.8]
+% Osize       - size of O node
+% discrete_obs - 1 means O is tabular_CPD, 0 means gaussian_CPD [0]
+% Oargs       - cell array of args to pass to the O CPD  [ {} ]
+% Ops         - Q parents of O [Qnodes(end)]
+% F1          - 1 means level 1 can finish (restart), else there is no F1->Q1 arc [0]
+% clamp1     - 1 means we clamp the params of the Q nodes in slice 1 (Qt1params) [1]
+%   Note: the Qt1params are startprob, which should be shared with other slices.
+%   However, in the current implementation, the Qt1params will only be estimated
+%   from the initial state of each sequence.
+%
+% For d=1, startprob{1}(1,j) is only used in the first slice and
+% termprob{1} is ignored, since we assume the top level never resets.
+% Also, transprob{1}(i,j) can be used instead of transprob{1}(i,1,j).
+%
+% leftstart means the model always starts in state 1.
+% rightstop means the model can only finish in its last state (Qsize(d)).
+% unif means each state is equally like to reach any other
+% rnd means the transition/starting probs are random (drawn from rand)
+%
+% Q1:QD in slice 1 are of type tabular_CPD
+% Q1:QD in slice 2 are of type hhmmQ_CPD.
+% F(2:D-1) is of type hhmmF_CPD, FD is of type tabular_CPD.
+
+args = varargin;
+nargs = length(args);
+
+% get sizes of nodes and topology
+Qsizes = [];
+Osize = [];
+allQ = 0;
+Ops = [];
+F1 = 0;
+for i=1:2:nargs
+  switch args{i},
+   case 'Qsizes', Qsizes = args{i+1}; 
+   case 'Osize',  Osize = args{i+1}; 
+   case 'allQ',   allQ = args{i+1}; 
+   case 'Ops',    Ops = args{i+1}; 
+   case 'F1',     F1 = args{i+1}; 
+  end
+end
+if isempty(Qsizes), error('must specify Qsizes'); end
+if Osize==0, error('must specify Osize'); end
+D = length(Qsizes);
+Qnodes = 1:D;
+
+if isempty(Ops), Ops = Qnodes(end); end
+
+
+[intra, inter, Qnodes, Fnodes, Onode] = mk_hhmm_topo(D, allQ, Ops, F1);
+ss = length(intra);
+names = {};
+
+if F1
+  Fnodes_ndx = Fnodes;
+else
+  Fnodes_ndx = [-1 Fnodes]; % Fnodes(1) is a dummy index
+end
+		
+% set default params
+discrete_obs = 0;
+Oargs = {};
+startprob = cell(1,D);
+startprob{1} = 'unif';
+for d=2:D
+  startprob{d} = 'leftstart';
+end
+transprob = cell(1,D);
+transprob{1} = 'unif';
+for d=2:D
+  transprob{d} = 'leftright';
+end
+termprob = cell(1,D);
+for d=2:D
+  termprob{d} = 'rightstop';
+end
+selfprob = 0.8;
+clamp1 = 1;
+
+for i=1:2:nargs
+  switch args{i},
+   case 'discrete_obs', discrete_obs = args{i+1}; 
+   case 'Oargs',        Oargs = args{i+1};
+   case 'startprob',    startprob = args{i+1};
+   case 'transprob',    transprob = args{i+1};
+   case 'termprob',     termprob = args{i+1};
+   case 'selfprob',     selfprob = args{i+1};
+   case 'clamp1',       clamp1 = args{i+1};
+  end
+end
+
+ns = zeros(1,ss);
+ns(Qnodes) = Qsizes;
+ns(Onode) = Osize;
+ns(Fnodes) = 2;
+
+dnodes = [Qnodes Fnodes];
+if discrete_obs
+  dnodes = [dnodes Onode];
+end
+onodes = [Onode];
+
+bnet = mk_dbn(intra, inter, ns, 'observed', onodes, 'discrete', dnodes, 'names', names);
+eclass = bnet.equiv_class;
+
+for d=1:D
+  if d==1
+    Qps = [];
+  elseif allQ
+    Qps = Qnodes(1:d-1);
+  else
+    Qps = Qnodes(d-1);
+  end
+  Qpsz = prod(ns(Qps));
+  Qsz = ns(Qnodes(d));
+  if isstr(startprob{d})
+    switch startprob{d}
+     case 'unif', startprob{d} = mk_stochastic(ones(Qpsz, Qsz));
+     case 'rnd', startprob{d} = mk_stochastic(rand(Qpsz, Qsz));
+     case 'leftstart', startprob{d} = zeros(Qpsz, Qsz); startprob{d}(:,1) = 1;
+    end
+  end
+  if isstr(transprob{d})
+    switch transprob{d}
+     case 'unif', transprob{d} = mk_stochastic(ones(Qsz, Qpsz, Qsz));
+     case 'rnd', transprob{d} = mk_stochastic(rand(Qsz, Qpsz, Qsz));
+     case 'leftright',
+      LR = mk_leftright_transmat(Qsz, selfprob);
+      temp = repmat(reshape(LR, [1 Qsz Qsz]), [Qpsz 1 1]); % transprob(k,i,j)
+      transprob{d} = permute(temp, [2 1 3]); % now transprob(i,k,j)
+    end
+  end
+  if isstr(termprob{d})
+    switch termprob{d}
+     case 'unif', termprob{d} = mk_stochastic(ones(Qpsz, Qsz, 2));
+     case 'rnd', termprob{d} = mk_stochastic(rand(Qpsz, Qsz, 2));
+     case 'rightstop',
+      %termprob(k,i,t) Might terminate if i=Qsz; will not terminate if i<Qsz
+      stopprob = 1-selfprob;
+      termprob{d} = zeros(Qpsz, Qsz, 2);
+      termprob{d}(:,Qsz,2) = stopprob;
+      termprob{d}(:,Qsz,1) = 1-stopprob;
+      termprob{d}(:,1:(Qsz-1),1) = 1;
+     otherwise, error(['unrecognized termprob ' termprob{d}])
+    end
+  elseif d>1 % passed in termprob{d}(k,j)
+    temp = termprob{d};
+    termprob{d} = zeros(Qpsz, Qsz, 2);
+    termprob{d}(:,:,2) = temp;
+    termprob{d}(:,:,1) = ones(Qpsz,Qsz) - temp;
+  end
+end
+
+
+% SLICE 1
+
+for d=1:D
+  bnet.CPD{eclass(Qnodes(d),1)} = tabular_CPD(bnet, Qnodes(d), 'CPT', startprob{d}, 'adjustable', clamp1);
+end
+
+if F1
+  d = 1;
+  bnet.CPD{eclass(Fnodes_ndx(d),1)}  = hhmmF_CPD(bnet, Fnodes_ndx(d), Qnodes(d), Fnodes_ndx(d+1), ...
+						 'termprob', termprob{d});
+end
+for d=2:D-1
+  if allQ
+    Qps = Qnodes(1:d-1);
+  else
+    Qps = Qnodes(d-1);
+  end
+  bnet.CPD{eclass(Fnodes_ndx(d),1)}  = hhmmF_CPD(bnet, Fnodes_ndx(d), Qnodes(d), Fnodes_ndx(d+1), ...
+						 'Qps', Qps, 'termprob', termprob{d});
+end
+bnet.CPD{eclass(Fnodes_ndx(D),1)}  = tabular_CPD(bnet, Fnodes_ndx(D), 'CPT', termprob{D});
+
+if discrete_obs
+  bnet.CPD{eclass(Onode,1)} = tabular_CPD(bnet, Onode, Oargs{:});
+else
+  bnet.CPD{eclass(Onode,1)} = gaussian_CPD(bnet, Onode, Oargs{:});
+end
+
+% SLICE 2
+
+%for d=1:D
+%  bnet.CPD{eclass(Qnodes(d),2)} = hhmmQ_CPD(bnet, Qnodes(d)+ss, Qnodes, d, D, ...
+%					    'startprob', startprob{d}, 'transprob', transprob{d}, ...
+%					    'allQ', allQ);
+%end
+
+d = 1;
+if F1
+  bnet.CPD{eclass(Qnodes(d),2)} = hhmmQ_CPD(bnet, Qnodes(d)+ss, 'Fself', Fnodes_ndx(d), ...
+					    'Fbelow', Fnodes_ndx(d+1), ...
+					    'startprob', startprob{d}, 'transprob', transprob{d});
+else
+    bnet.CPD{eclass(Qnodes(d),2)} = hhmmQ_CPD(bnet, Qnodes(d)+ss, ...
+					    'Fbelow', Fnodes_ndx(d+1), ...
+					    'startprob', startprob{d}, 'transprob', transprob{d});
+end
+for d=2:D-1
+  if allQ
+    Qps = Qnodes(1:d-1);
+  else
+    Qps = Qnodes(d-1);
+  end
+  Qps = Qps + ss; % since all in slice 2
+  bnet.CPD{eclass(Qnodes(d),2)} = hhmmQ_CPD(bnet, Qnodes(d)+ss, 'Fself', Fnodes_ndx(d), ...
+					    'Fbelow', Fnodes_ndx(d+1), 'Qps', Qps, ...
+					    'startprob', startprob{d}, 'transprob', transprob{d});
+end
+d = D;
+if allQ
+  Qps = Qnodes(1:d-1);
+else
+  Qps = Qnodes(d-1);
+end
+Qps = Qps + ss; % since all in slice 2
+bnet.CPD{eclass(Qnodes(d),2)} = hhmmQ_CPD(bnet, Qnodes(d)+ss, 'Fself', Fnodes_ndx(d), ...
+					  'Qps', Qps, ...
+					  'startprob', startprob{d}, 'transprob', transprob{d});