Mercurial > hg > camir-aes2014
comparison toolboxes/FullBNT-1.0.7/bnt/CPDs/@gaussian_CPD/update_ess.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:e9a9cd732c1e |
---|---|
function CPD = update_ess(CPD, fmarginal, evidence, ns, cnodes, hidden_bitv)
% UPDATE_ESS Update the Expected Sufficient Statistics of a Gaussian node
% function CPD = update_ess(CPD, fmarginal, evidence, ns, cnodes, hidden_bitv)
%
% Adds one case's contribution to the ESS fields of CPD: nsamples, Wsum,
% WYsum and WYYsum, plus WXsum, WXXsum and WXYsum when the node has
% continuous parents.
%
% Inputs:
%   CPD         - Gaussian CPD whose ESS fields are accumulated.
%   fmarginal   - family marginal; fmarginal.domain lists the parents
%                 followed by this node (self = last entry of the domain).
%   evidence    - cell array indexed by node number; observed values of
%                 self and of the continuous parents are read from it.
%   ns          - node sizes; only forwarded to add_evidence_to_gmarginal.
%   cnodes      - indices of the continuous nodes.
%   hidden_bitv - hidden_bitv(i) is true when node i is hidden. Optional:
%                 when omitted, node i is treated as hidden iff
%                 evidence{i} is empty.
%
% Output:
%   CPD         - the same CPD with its ESS fields updated.
%
% Notation: X = the continuous parents (if any), Y = this continuous node.

if nargin < 6
  % Test each evidence entry separately. (An earlier, commented-out
  % default used isempty(evidence), which tests the whole cell array.)
  hidden_bitv = cellfun('isempty', evidence);
end

dom = fmarginal.domain;
self = dom(end);
ps = dom(1:end-1);
cps = myintersect(ps, cnodes);
dps = mysetdiff(ps, cps);

CPD.nsamples = CPD.nsamples + 1;
[ss cpsz dpsz] = size(CPD.weights); % ss = self size, cpsz = length of the stacked cts-parent vector
[ss dpsz] = size(CPD.mean);         % dpsz = number of columns of CPD.mean (one per discrete-parent value)

if ~hidden_bitv(self) && ~any(hidden_bitv(cps)) && all(hidden_bitv(dps))
  % Speedup for the common case: all cts nodes observed, all discrete nodes hidden.
  % Since X and Y are observed, SYY = 0, SXX = 0, SXY = 0.
  % Since the discrete parents are hidden, no evidence needs to be added to w.
  w = fmarginal.T(:);
  CPD.Wsum = CPD.Wsum + w;
  y = evidence{self};
  y = y(:); % force a column so that y*y' is ss-by-ss (x is normalised the same way below)
  Cyy = y*y';
  if ~CPD.useC
    % WYsum(:,i) += w(i)*y ;  WYYsum(:,:,i) += w(i)*y*y'
    CPD.WYsum = CPD.WYsum + y * w';
    CPD.WYYsum = CPD.WYYsum + reshape(Cyy(:) * w', [ss ss dpsz]);
  else
    % NOTE(review): rep_mult is assumed to form the same weighted products
    % as the branch above (likely a compiled helper) -- confirm.
    W = w';
    W2 = reshape(W, [1 1 dpsz]);
    CPD.WYsum = CPD.WYsum + rep_mult(W, y, size(CPD.WYsum));
    CPD.WYYsum = CPD.WYYsum + rep_mult(W2, Cyy, size(CPD.WYYsum));
  end
  if cpsz > 0 % X exists
    x = cat(1, evidence{cps}); x = x(:);
    Cxx = x*x';
    Cxy = x*y';
    if ~CPD.useC
      % WXsum(:,i) += w(i)*x ; WXXsum(:,:,i) += w(i)*x*x' ; WXYsum(:,:,i) += w(i)*x*y'
      CPD.WXsum = CPD.WXsum + x * w';
      CPD.WXXsum = CPD.WXXsum + reshape(Cxx(:) * w', [cpsz cpsz dpsz]);
      CPD.WXYsum = CPD.WXYsum + reshape(Cxy(:) * w', [cpsz ss dpsz]);
    else
      CPD.WXsum = CPD.WXsum + rep_mult(W, x, size(CPD.WXsum));
      CPD.WXXsum = CPD.WXXsum + rep_mult(W2, Cxx, size(CPD.WXXsum));
      CPD.WXYsum = CPD.WXYsum + rep_mult(W2, Cxy, size(CPD.WXYsum));
    end
  end
  return;
end

% General (non-vectorized) case: fold all evidence into the marginal first.
fullm = add_evidence_to_gmarginal(fmarginal, evidence, ns, cnodes); % slow!

if dpsz == 1 % no discrete parents
  w = 1;
else
  w = fullm.T(:);
end

CPD.Wsum = CPD.Wsum + w;
% Within fullm.mu / fullm.Sigma, the first cpsz entries are X, the next ss are Y.
xi = 1:cpsz;
yi = (cpsz+1):(cpsz+ss);
for i=1:dpsz
  muY = fullm.mu(yi, i);
  SYY = fullm.Sigma(yi, yi, i);
  CPD.WYsum(:,i) = CPD.WYsum(:,i) + w(i)*muY;
  CPD.WYYsum(:,:,i) = CPD.WYYsum(:,:,i) + w(i)*(SYY + muY*muY'); % E[Y Y'] = Cov[Y,Y] + E[Y] E[Y]'
  if cpsz > 0
    muX = fullm.mu(xi, i);
    SXX = fullm.Sigma(xi, xi, i);
    SXY = fullm.Sigma(xi, yi, i);
    CPD.WXsum(:,i) = CPD.WXsum(:,i) + w(i)*muX;
    CPD.WXXsum(:,:,i) = CPD.WXXsum(:,:,i) + w(i)*(SXX + muX*muX'); % E[X X'] = Cov[X,X] + E[X] E[X]'
    CPD.WXYsum(:,:,i) = CPD.WXYsum(:,:,i) + w(i)*(SXY + muX*muY'); % E[X Y'] = Cov[X,Y] + E[X] E[Y]'
  end
end