annotate toolboxes/FullBNT-1.0.7/bnt/CPDs/@hhmmQ_CPD/Old/update_ess2.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
rev   line source
wolffd@0 1 function CPD = update_ess2(CPD, fmarginal, evidence, ns, cnodes, hidden_bitv)
wolffd@0 2 % UPDATE_ESS2 Update the Expected Sufficient Statistics of a hhmm Q node.
wolffd@0 3 % function CPD = update_ess2(CPD, fmarginal, evidence, ns, cnodes, hidden_bitv)
wolffd@0 4 % Reads hor_counts (and, in most branches, ver_counts) out of fmarginal.T and passes them to update_ess_simple for CPD.sub_CPD_trans / CPD.sub_CPD_start. cnodes is not referenced in this function.
wolffd@0 5 % Figure out the node numbers associated with each parent
wolffd@0 6 dom = fmarginal.domain;
wolffd@0 7 self = dom(end); % by assumption
wolffd@0 8 old_self = dom(CPD.old_self_ndx);
wolffd@0 9 Fself = dom(CPD.Fself_ndx); % NOTE(review): not referenced again in this function
wolffd@0 10 Fbelow = dom(CPD.Fbelow_ndx); % NOTE(review): not referenced again in this function
wolffd@0 11 Qps = dom(CPD.Qps_ndx);
wolffd@0 12
wolffd@0 13 Qsz = CPD.Qsz;
wolffd@0 14 Qpsz = CPD.Qpsz;
wolffd@0 15
wolffd@0 16
wolffd@0 17 fmarg = add_ev_to_dmarginal(fmarginal, evidence, ns); % NOTE(review): fmarg is never read below; all later code indexes fmarginal.T directly
wolffd@0 18
wolffd@0 19
wolffd@0 20
wolffd@0 21 % hor_counts(old_self, Qps, self),
wolffd@0 22 % fmarginal(old_self, Fbelow, Fself, Qps, self)
wolffd@0 23 % hor_counts(i,k,j) = fmarginal(i,2,1,k,j) % below has finished, self has not
wolffd@0 24 % ver_counts(i,k,j) = fmarginal(i,2,2,k,j) % below has finished, and so has self (reset)
wolffd@0 25 % Since any of i,j,k may be observed, we write
wolffd@0 26 % hor_counts(counts_ndx{:}) = fmarginal(fmarg_ndx{:})
wolffd@0 27 % where e.g., counts_ndx = {1, ':', 2} if Qps is hidden but we observe old_self=1, self=2.
wolffd@0 28 % To create this counts_ndx, we write counts_ndx = mk_multi_ndx(3, counts_obs_dim, obs_val)
wolffd@0 29 % where counts_obs_dim = [1 3], obs_val = [1 2] specifies the values of dimensions 1 and 3.
wolffd@0 30 % NOTE(review): counts_obs_dim, fmarg_obs_dim, obs_val and effQsz are built below but never read later in this function.
wolffd@0 31 counts_obs_dim = [];
wolffd@0 32 fmarg_obs_dim = [];
wolffd@0 33 obs_val = [];
wolffd@0 34 if hidden_bitv(self)
wolffd@0 35 effQsz = Qsz;
wolffd@0 36 else
wolffd@0 37 effQsz = 1;
wolffd@0 38 counts_obs_dim = [counts_obs_dim 3];
wolffd@0 39 fmarg_obs_dim = [fmarg_obs_dim 5];
wolffd@0 40 obs_val = [obs_val evidence{self}];
wolffd@0 41 end
wolffd@0 42 % From here on dom, self, old_self, Qps, Qsz and Qpsz are recomputed, replacing the values assigned above.
wolffd@0 43 % e.g., D=4, d=3, Qps = all Qs above, so dom = [Q3(t-1) F4(t-1) F3(t-1) Q1(t) Q2(t) Q3(t)].
wolffd@0 44 % so self = Q3(t), old_self = Q3(t-1), CPD.Qps = [1 2], Qps = [Q1(t) Q2(t)]
wolffd@0 45 dom = fmarginal.domain;
wolffd@0 46 self = dom(end);
wolffd@0 47 old_self = dom(1);
wolffd@0 48 Qps = dom(length(dom)-length(CPD.Qps):end-1); % the length(CPD.Qps) entries immediately before self
wolffd@0 49
wolffd@0 50 Qsz = CPD.Qsizes(CPD.d);
wolffd@0 51 Qpsz = prod(CPD.Qsizes(CPD.Qps));
wolffd@0 52
wolffd@0 53 % If some of the Q nodes are observed (which happens during supervised training)
wolffd@0 54 % the counts will only be non-zero in positions
wolffd@0 55 % consistent with the evidence. We put the computed marginal responsibilities
wolffd@0 56 % into the appropriate slots of the big counts array.
wolffd@0 57 % (Recall that observed discrete nodes only have a single effective value.)
wolffd@0 58 % (A more general, but much slower, way is to call add_evidence_to_dmarginal.)
wolffd@0 59 % We assume the F nodes are never observed.
wolffd@0 60
wolffd@0 61 obs_self = ~hidden_bitv(self);
wolffd@0 62 obs_Qps = (~isempty(Qps)) & (~any(hidden_bitv(Qps))); % we assume that all or none of the Q parents are observed
wolffd@0 63
wolffd@0 64 if obs_self
wolffd@0 65 self_val = evidence{self};
wolffd@0 66 oldself_val = evidence{old_self}; % assumes old_self is observed whenever self is - TODO confirm
wolffd@0 67 end
wolffd@0 68
wolffd@0 69 if obs_Qps
wolffd@0 70 Qps_val = subv2ind(Qpsz, cat(1, evidence{Qps}));
wolffd@0 71 if Qps_val == 0
wolffd@0 72 keyboard % NOTE(review): leftover debugging hook; stops for interactive inspection instead of raising an error
wolffd@0 73 end
wolffd@0 74 end
wolffd@0 75
wolffd@0 76 if CPD.d==1 % no Qps from above
wolffd@0 77 if ~CPD.F1toQ1 % no F from self
wolffd@0 78 % marg(Q1(t-1), F2(t-1), Q1(t))
wolffd@0 79 % F2(t-1) P(Q1(t)=j | Q1(t-1)=i)
wolffd@0 80 % 1 delta(i,j)
wolffd@0 81 % 2 transprob(i,j)
wolffd@0 82 if obs_self
wolffd@0 83 hor_counts = zeros(Qsz, Qsz);
wolffd@0 84 hor_counts(oldself_val, self_val) = fmarginal.T(2); % linear index 2; presumably the F2(t-1)=2 entry when both Q dims have size 1 - confirm
wolffd@0 85 else
wolffd@0 86 marg = reshape(fmarginal.T, [Qsz 2 Qsz]);
wolffd@0 87 hor_counts = squeeze(marg(:,2,:));
wolffd@0 88 end % NOTE(review): ver_counts is not assigned in this branch; it is read at the end only if CPD.sub_CPD_start is non-empty
wolffd@0 89 else
wolffd@0 90 % marg(Q1(t-1), F2(t-1), F1(t-1), Q1(t))
wolffd@0 91 % F2(t-1) F1(t-1) P(Qd(t)=j| Qd(t-1)=i)
wolffd@0 92 % ------------------------------------------------------
wolffd@0 93 % 1 1 delta(i,j)
wolffd@0 94 % 2 1 transprob(i,j)
wolffd@0 95 % 1 2 impossible
wolffd@0 96 % 2 2 startprob(j)
wolffd@0 97 if obs_self
wolffd@0 98 marg = myreshape(fmarginal.T, [1 2 2 1]);
wolffd@0 99 hor_counts = zeros(Qsz, Qsz);
wolffd@0 100 hor_counts(oldself_val, self_val) = marg(1,2,1,1);
wolffd@0 101 ver_counts = zeros(Qsz, 1);
wolffd@0 102 %ver_counts(self_val) = marg(1,2,2,1);
wolffd@0 103 ver_counts(self_val) = marg(1,2,2,1) + marg(1,1,2,1);
wolffd@0 104 else
wolffd@0 105 marg = reshape(fmarginal.T, [Qsz 2 2 Qsz]);
wolffd@0 106 hor_counts = squeeze(marg(:,2,1,:));
wolffd@0 107 %ver_counts = squeeze(sum(marg(:,2,2,:),1)); % sum over i
wolffd@0 108 ver_counts = squeeze(sum(marg(:,2,2,:),1)) + squeeze(sum(marg(:,1,2,:),1)); % sum i,b
wolffd@0 109 end
wolffd@0 110 end % F1toQ1
wolffd@0 111 else % d ~= 1
wolffd@0 112 if CPD.d < CPD.D % general case
wolffd@0 113 % marg(Qd(t-1), Fd+1(t-1), Fd(t-1), Qps(t), Qd(t))
wolffd@0 114 % Fd+1(t-1) Fd(t-1) P(Qd(t)=j| Qd(t-1)=i, Qps(t)=k)
wolffd@0 115 % ------------------------------------------------------
wolffd@0 116 % 1 1 delta(i,j)
wolffd@0 117 % 2 1 transprob(i,k,j)
wolffd@0 118 % 1 2 impossible
wolffd@0 119 % 2 2 startprob(k,j)
wolffd@0 120 if obs_Qps & obs_self
wolffd@0 121 marg = myreshape(fmarginal.T, [1 2 2 1 1]);
wolffd@0 122 k = 1;
wolffd@0 123 hor_counts = zeros(Qsz, Qpsz, Qsz);
wolffd@0 124 hor_counts(oldself_val, Qps_val, self_val) = marg(1, 2,1, k,1);
wolffd@0 125 ver_counts = zeros(Qpsz, Qsz);
wolffd@0 126 %ver_counts(Qps_val, self_val) = marg(1, 2,2, k,1);
wolffd@0 127 ver_counts(Qps_val, self_val) = marg(1, 2,2, k,1) + marg(1, 1,2, k,1);
wolffd@0 128 elseif obs_Qps & ~obs_self
wolffd@0 129 marg = myreshape(fmarginal.T, [Qsz 2 2 1 Qsz]);
wolffd@0 130 k = 1;
wolffd@0 131 hor_counts = zeros(Qsz, Qpsz, Qsz);
wolffd@0 132 hor_counts(:, Qps_val, :) = marg(:, 2,1, k,:);
wolffd@0 133 ver_counts = zeros(Qpsz, Qsz);
wolffd@0 134 %ver_counts(Qps_val, :) = sum(marg(:, 2,2, k,:), 1);
wolffd@0 135 ver_counts(Qps_val, :) = sum(marg(:, 2,2, k,:), 1) + sum(marg(:, 1,2, k,:), 1);
wolffd@0 136 elseif ~obs_Qps & obs_self
wolffd@0 137 error('not yet implemented')
wolffd@0 138 else % everything is hidden
wolffd@0 139 marg = reshape(fmarginal.T, [Qsz 2 2 Qpsz Qsz]);
wolffd@0 140 hor_counts = squeeze(marg(:,2,1,:,:)); % i,k,j
wolffd@0 141 %ver_counts = squeeze(sum(marg(:,2,2,:,:),1)); % sum over i
wolffd@0 142 ver_counts = squeeze(sum(marg(:,2,2,:,:),1)) + squeeze(sum(marg(:,1,2,:,:),1)); % sum over i,b
wolffd@0 143 end
wolffd@0 144 else % d == D, so no F from below
wolffd@0 145 % marg(QD(t-1), FD(t-1), Qps(t), QD(t))
wolffd@0 146 % FD(t-1) P(QD(t)=j | QD(t-1)=i, Qps(t)=k)
wolffd@0 147 % 1 transprob(i,k,j)
wolffd@0 148 % 2 startprob(k,j)
wolffd@0 149 if obs_Qps & obs_self
wolffd@0 150 marg = myreshape(fmarginal.T, [1 2 1 1]);
wolffd@0 151 k = 1;
wolffd@0 152 hor_counts = zeros(Qsz, Qpsz, Qsz);
wolffd@0 153 hor_counts(oldself_val, Qps_val, self_val) = marg(1, 1, k,1);
wolffd@0 154 ver_counts = zeros(Qpsz, Qsz);
wolffd@0 155 ver_counts(Qps_val, self_val) = marg(1, 2, k,1);
wolffd@0 156 elseif obs_Qps & ~obs_self
wolffd@0 157 marg = myreshape(fmarginal.T, [Qsz 2 1 Qsz]);
wolffd@0 158 k = 1;
wolffd@0 159 hor_counts = zeros(Qsz, Qpsz, Qsz);
wolffd@0 160 hor_counts(:, Qps_val, :) = marg(:, 1, k,:);
wolffd@0 161 ver_counts = zeros(Qpsz, Qsz);
wolffd@0 162 ver_counts(Qps_val, :) = sum(marg(:, 2, k, :), 1);
wolffd@0 163 elseif ~obs_Qps & obs_self
wolffd@0 164 error('not yet implemented')
wolffd@0 165 else % everything is hidden
wolffd@0 166 marg = reshape(fmarginal.T, [Qsz 2 Qpsz Qsz]);
wolffd@0 167 hor_counts = squeeze(marg(:,1,:,:));
wolffd@0 168 ver_counts = squeeze(sum(marg(:,2,:,:),1)); % sum over i
wolffd@0 169 end
wolffd@0 170 end
wolffd@0 171 end
wolffd@0 172 % Pass the extracted counts to the sub-CPDs: hor_counts -> sub_CPD_trans, ver_counts -> sub_CPD_start (only when it is non-empty).
wolffd@0 173 CPD.sub_CPD_trans = update_ess_simple(CPD.sub_CPD_trans, hor_counts);
wolffd@0 174
wolffd@0 175 if ~isempty(CPD.sub_CPD_start)
wolffd@0 176 CPD.sub_CPD_start = update_ess_simple(CPD.sub_CPD_start, ver_counts);
wolffd@0 177 end
wolffd@0 178