annotate toolboxes/FullBNT-1.0.7/bnt/examples/dynamic/HHMM/mk_hhmm.m @ 0:cc4b1211e677 tip

initial commit to HG from Changeset: 646 (e263d8a21543) added further path and more save "camirversion.m"
author Daniel Wolff
date Fri, 19 Aug 2016 13:07:06 +0200
parents
children
rev   line source
function [bnet, Qnodes, Fnodes, Onode] = mk_hhmm(varargin)
% MK_HHMM Make a Hierarchical HMM
% function [bnet, Qnodes, Fnodes, Onode] = mk_hhmm(...)
%
% e.g. 3-layer hierarchical HMM where level 1 only connects to level 2
% and the parents of the observed node are levels 2 and 3.
% (This DBN is the same as Fig 10 in my tech report.)
%
% Q1 ----------> Q1
% | \ ^ |
% | v / |
% | F2 ------/ |
% | ^ ^ \ |
% | / | \ |
% | / | ||
% v | vv
% Q2----| --------> Q2
% /| \ | ^|
% / | v | / |
% | | F3 --------/ |
% | | ^ \ |
% | v / v v
% | Q3 -----------> Q3
% | |
% \ |
% v v
% O
%
%
% Optional arguments in name/value format [default value in brackets]
%
% Qsizes - sizes at each level; at least 2 levels are required [ none ]
% allQ - 1 means level i connects to all Q levels below, 0 means just to i+1 [0]
% transprob - transprob{d}(i,k,j) = P(Q(d,t)=j|Q(d,t-1)=i,Q(1:d-1,t)=k)
%             ['unif' for d=1, 'leftright' for d>1]
% startprob - startprob{d}(k,j) = P(Q(d,t)=j|Q(1:d-1,t)=k)
%             ['unif' for d=1, 'leftstart' for d>1]
% termprob - termprob{d}(k,j) = P(F(d,t)=2|Q(1:d-1,t)=k,Q(d,t)=j) for d>1 ['rightstop']
% selfprob - prob of a self transition (termprob default = 1-selfprob) [0.8]
% Osize - size of O node [ none ]
% discrete_obs - 1 means O is tabular_CPD, 0 means gaussian_CPD [0]
% Oargs - cell array of args to pass to the O CPD [ {} ]
% Ops - Q parents of O [Qnodes(end)]
% F1 - 1 means level 1 can finish (restart), else there is no F1->Q1 arc [0]
% clamp1 - passed unchanged as the 'adjustable' option of the slice-1 Q
%          tabular_CPDs (Qt1params) [1]
%   NOTE(review): in BNT's tabular_CPD, adjustable=1 means the params are
%   updated during learning, so despite its name clamp1=1 leaves them
%   learnable and clamp1=0 holds them fixed - confirm which is intended.
%   Note: the Qt1params are startprob, which should be shared with other slices.
%   However, in the current implementation, the Qt1params will only be estimated
%   from the initial state of each sequence.
%
% For d=1, startprob{1}(1,j) is only used in the first slice and
% termprob{1} is ignored (unless F1=1), since we assume the top level never resets.
% Also, transprob{1}(i,j) can be used instead of transprob{1}(i,1,j).
%
% leftstart means the model always starts in state 1.
% rightstop means the model can only finish in its last state (Qsize(d)).
% unif means each state is equally likely to reach any other.
% rnd means the transition/starting probs are random (drawn from rand).
% Any other string for startprob/transprob/termprob is an error.
%
% Q1:QD in slice 1 are of type tabular_CPD
% Q1:QD in slice 2 are of type hhmmQ_CPD.
% F(2:D-1) (and F1, when F1=1) are of type hhmmF_CPD, FD is of type tabular_CPD.

args = varargin;
nargs = length(args);

% Pass 1: options that fix node sizes and the graph topology.
Qsizes = [];
Osize = [];
allQ = 0;
Ops = [];
F1 = 0;
for i=1:2:nargs
  switch args{i}
    case 'Qsizes', Qsizes = args{i+1};
    case 'Osize', Osize = args{i+1};
    case 'allQ', allQ = args{i+1};
    case 'Ops', Ops = args{i+1};
    case 'F1', F1 = args{i+1};
  end
end
if isempty(Qsizes), error('must specify Qsizes'); end
% Osize defaults to [], and "if []" is false, so an omitted Osize must be
% tested for explicitly; otherwise it would only surface later, when
% ns(Onode) is assigned from it.
if isempty(Osize) || all(Osize(:)==0), error('must specify Osize'); end
D = length(Qsizes);
% The construction below always addresses level d+1 from level 1 and
% level d-1 from level D, so a single-level model cannot be built.
if D < 2, error('Qsizes must specify at least 2 levels'); end
Qnodes = 1:D;

if isempty(Ops), Ops = Qnodes(end); end

[intra, inter, Qnodes, Fnodes, Onode] = mk_hhmm_topo(D, allQ, Ops, F1);
ss = length(intra); % number of nodes per slice
names = {};

if F1
  Fnodes_ndx = Fnodes;
else
  Fnodes_ndx = [-1 Fnodes]; % Fnodes(1) is a dummy index
end

% Pass 2: parameter options, with per-level defaults.
discrete_obs = 0;
Oargs = {};
startprob = cell(1,D);
startprob{1} = 'unif';
for d=2:D
  startprob{d} = 'leftstart';
end
transprob = cell(1,D);
transprob{1} = 'unif';
for d=2:D
  transprob{d} = 'leftright';
end
termprob = cell(1,D); % termprob{1} stays [] unless supplied by the caller
for d=2:D
  termprob{d} = 'rightstop';
end
selfprob = 0.8;
clamp1 = 1;

for i=1:2:nargs
  switch args{i}
    case 'discrete_obs', discrete_obs = args{i+1};
    case 'Oargs', Oargs = args{i+1};
    case 'startprob', startprob = args{i+1};
    case 'transprob', transprob = args{i+1};
    case 'termprob', termprob = args{i+1};
    case 'selfprob', selfprob = args{i+1};
    case 'clamp1', clamp1 = args{i+1};
  end
end

ns = zeros(1,ss);
ns(Qnodes) = Qsizes;
ns(Onode) = Osize;
ns(Fnodes) = 2;

dnodes = [Qnodes Fnodes];
if discrete_obs
  dnodes = [dnodes Onode];
end
onodes = Onode;

bnet = mk_dbn(intra, inter, ns, 'observed', onodes, 'discrete', dnodes, 'names', names);
eclass = bnet.equiv_class;

% Expand string shorthands (and termprob{d}(k,j) matrices) into full arrays.
for d=1:D
  Qps = hhmm_Qparents(Qnodes, d, allQ);
  Qpsz = prod(ns(Qps)); % joint size of the parent Q levels (1 when d==1)
  Qsz = ns(Qnodes(d));
  if ischar(startprob{d})
    switch startprob{d}
      case 'unif', startprob{d} = mk_stochastic(ones(Qpsz, Qsz));
      case 'rnd', startprob{d} = mk_stochastic(rand(Qpsz, Qsz));
      case 'leftstart', startprob{d} = zeros(Qpsz, Qsz); startprob{d}(:,1) = 1;
      otherwise, error(['unrecognized startprob ' startprob{d}])
    end
  end
  if ischar(transprob{d})
    switch transprob{d}
      case 'unif', transprob{d} = mk_stochastic(ones(Qsz, Qpsz, Qsz));
      case 'rnd', transprob{d} = mk_stochastic(rand(Qsz, Qpsz, Qsz));
      case 'leftright'
        LR = mk_leftright_transmat(Qsz, selfprob);
        temp = repmat(reshape(LR, [1 Qsz Qsz]), [Qpsz 1 1]); % transprob(k,i,j)
        transprob{d} = permute(temp, [2 1 3]); % now transprob(i,k,j)
      otherwise, error(['unrecognized transprob ' transprob{d}])
    end
  end
  if ischar(termprob{d})
    switch termprob{d}
      case 'unif', termprob{d} = mk_stochastic(ones(Qpsz, Qsz, 2));
      case 'rnd', termprob{d} = mk_stochastic(rand(Qpsz, Qsz, 2));
      case 'rightstop'
        % termprob(k,i,t): may terminate (with prob 1-selfprob) only if
        % i==Qsz; never terminates if i<Qsz.
        stopprob = 1-selfprob;
        termprob{d} = zeros(Qpsz, Qsz, 2);
        termprob{d}(:,Qsz,2) = stopprob;
        termprob{d}(:,Qsz,1) = 1-stopprob;
        termprob{d}(:,1:(Qsz-1),1) = 1;
      otherwise, error(['unrecognized termprob ' termprob{d}])
    end
  elseif d>1
    % Passed in termprob{d}(k,j) = P(F=2|...); add the complementary F=1 slab.
    temp = termprob{d};
    termprob{d} = zeros(Qpsz, Qsz, 2);
    termprob{d}(:,:,2) = temp;
    termprob{d}(:,:,1) = ones(Qpsz,Qsz) - temp;
  end
end


% SLICE 1

for d=1:D
  % clamp1 is forwarded verbatim as 'adjustable' (see the clamp1 note above).
  bnet.CPD{eclass(Qnodes(d),1)} = tabular_CPD(bnet, Qnodes(d), 'CPT', startprob{d}, 'adjustable', clamp1);
end

if F1
  d = 1;
  bnet.CPD{eclass(Fnodes_ndx(d),1)} = hhmmF_CPD(bnet, Fnodes_ndx(d), Qnodes(d), Fnodes_ndx(d+1), ...
                                                'termprob', termprob{d});
end
for d=2:D-1
  Qps = hhmm_Qparents(Qnodes, d, allQ);
  bnet.CPD{eclass(Fnodes_ndx(d),1)} = hhmmF_CPD(bnet, Fnodes_ndx(d), Qnodes(d), Fnodes_ndx(d+1), ...
                                                'Qps', Qps, 'termprob', termprob{d});
end
bnet.CPD{eclass(Fnodes_ndx(D),1)} = tabular_CPD(bnet, Fnodes_ndx(D), 'CPT', termprob{D});

if discrete_obs
  bnet.CPD{eclass(Onode,1)} = tabular_CPD(bnet, Onode, Oargs{:});
else
  bnet.CPD{eclass(Onode,1)} = gaussian_CPD(bnet, Onode, Oargs{:});
end

% SLICE 2

d = 1;
if F1
  bnet.CPD{eclass(Qnodes(d),2)} = hhmmQ_CPD(bnet, Qnodes(d)+ss, 'Fself', Fnodes_ndx(d), ...
                                            'Fbelow', Fnodes_ndx(d+1), ...
                                            'startprob', startprob{d}, 'transprob', transprob{d});
else
  bnet.CPD{eclass(Qnodes(d),2)} = hhmmQ_CPD(bnet, Qnodes(d)+ss, ...
                                            'Fbelow', Fnodes_ndx(d+1), ...
                                            'startprob', startprob{d}, 'transprob', transprob{d});
end
for d=2:D-1
  Qps = hhmm_Qparents(Qnodes, d, allQ) + ss; % since all in slice 2
  bnet.CPD{eclass(Qnodes(d),2)} = hhmmQ_CPD(bnet, Qnodes(d)+ss, 'Fself', Fnodes_ndx(d), ...
                                            'Fbelow', Fnodes_ndx(d+1), 'Qps', Qps, ...
                                            'startprob', startprob{d}, 'transprob', transprob{d});
end
d = D;
Qps = hhmm_Qparents(Qnodes, d, allQ) + ss; % since all in slice 2
bnet.CPD{eclass(Qnodes(d),2)} = hhmmQ_CPD(bnet, Qnodes(d)+ss, 'Fself', Fnodes_ndx(d), ...
                                          'Qps', Qps, ...
                                          'startprob', startprob{d}, 'transprob', transprob{d});


function Qps = hhmm_Qparents(Qnodes, d, allQ)
% HHMM_QPARENTS Slice-1 Q nodes used as the 'Qps' of level d:
% none for d==1, all of levels 1..d-1 if allQ, otherwise just level d-1.
if d == 1
  Qps = [];
elseif allQ
  Qps = Qnodes(1:d-1);
else
  Qps = Qnodes(d-1);
end