wolffd@0
|
function CPD = update_ess(CPD, fmarginal, evidence, ns, cnodes, hidden_bitv)
% UPDATE_ESS Update the Expected Sufficient Statistics of a Gaussian node
% function CPD = update_ess(CPD, fmarginal, evidence, ns, cnodes, hidden_bitv)
%
% Adds the contribution of one training case to the running sums stored in
% CPD (nsamples, Wsum, WYsum, WYYsum and, when there are continuous
% parents, WXsum, WXXsum, WXYsum).
%
% Inputs:
%   CPD         - structure/object holding the accumulators listed above,
%                 plus 'weights' (whose size gives ss, cpsz, dpsz) and the
%                 flag 'useC' (selects rep_mult instead of repmat/.*).
%   fmarginal   - family marginal; fmarginal.domain lists the parents
%                 followed by this node (last entry); fmarginal.T supplies
%                 the weights in the fully-observed-continuous fast path.
%   evidence    - cell array indexed by node number; evidence{i} is read
%                 for this node and its continuous parents in the fast path.
%   ns, cnodes  - forwarded to add_evidence_to_gmarginal; cnodes is also
%                 used to split the parents into continuous parents (those
%                 in cnodes) and the remaining (discrete) parents.
%   hidden_bitv - vector indexed by node number; nonzero means hidden.
%
% Output:
%   CPD         - the input CPD with updated accumulators.

% Default for hidden_bitv is disabled; callers must pass all six arguments.
%if nargin < 6
%  hidden_bitv = zeros(1, max(fmarginal.domain));
%  hidden_bitv(find(isempty(evidence)))=1;
%end

dom = fmarginal.domain;
self = dom(end);
ps = dom(1:end-1);
hidden_self = hidden_bitv(self);
cps = myintersect(ps, cnodes);
dps = mysetdiff(ps, cps);
hidden_cps = all(hidden_bitv(cps));
hidden_dps = all(hidden_bitv(dps));

CPD.nsamples = CPD.nsamples + 1;
[ss, cpsz, dpsz] = size(CPD.weights); % ss = self size

% Let X be the cts parent (if any), Y be the cts child (self).

if ~hidden_self & (isempty(cps) | ~hidden_cps) & hidden_dps % all cts nodes are observed, all discrete nodes are hidden
  % Since X and Y are observed, SYY = 0, SXX = 0, SXY = 0
  % Since discrete parents are hidden, we do not need to add evidence to w.
  w = fmarginal.T(:);
  CPD.Wsum = CPD.Wsum + w;
  y = evidence{self};
  Cyy = y*y';
  if ~CPD.useC
    W = repmat(w(:)', ss, 1);                           % W(y,i) = w(i)
    W2 = repmat(reshape(w, [1 1 dpsz]), [ss ss 1]);     % W2(y,y',i) = w(i)
    CPD.WYsum = CPD.WYsum + W .* repmat(y(:), 1, dpsz);
    CPD.WYYsum = CPD.WYYsum + W2 .* repmat(reshape(Cyy, [ss ss 1]), [1 1 dpsz]);
  else
    W = w(:)';
    W2 = reshape(W, [1 1 dpsz]);
    CPD.WYsum = CPD.WYsum + rep_mult(W, y(:), size(CPD.WYsum));
    CPD.WYYsum = CPD.WYYsum + rep_mult(W2, Cyy, size(CPD.WYYsum));
  end
  if cpsz > 0 % X exists
    x = cat(1, evidence{cps}); x = x(:);
    Cxx = x*x';
    Cxy = x*y';
    if ~CPD.useC
      % The weight arrays must match the shapes of the X statistics
      % (cpsz rows), not those of Y (ss rows); reusing W/W2 here is only
      % valid when cpsz == ss.
      WX  = repmat(w(:)', cpsz, 1);                         % WX(x,i)    = w(i)
      WXX = repmat(reshape(w, [1 1 dpsz]), [cpsz cpsz 1]);  % WXX(x,x',i) = w(i)
      WXY = repmat(reshape(w, [1 1 dpsz]), [cpsz ss 1]);    % WXY(x,y,i)  = w(i)
      CPD.WXsum = CPD.WXsum + WX .* repmat(x(:), 1, dpsz);
      CPD.WXXsum = CPD.WXXsum + WXX .* repmat(reshape(Cxx, [cpsz cpsz 1]), [1 1 dpsz]);
      CPD.WXYsum = CPD.WXYsum + WXY .* repmat(reshape(Cxy, [cpsz ss 1]), [1 1 dpsz]);
    else
      CPD.WXsum = CPD.WXsum + rep_mult(W, x(:), size(CPD.WXsum));
      CPD.WXXsum = CPD.WXXsum + rep_mult(W2, Cxx, size(CPD.WXXsum));
      CPD.WXYsum = CPD.WXYsum + rep_mult(W2, Cxy, size(CPD.WXYsum));
    end
  end
  return;
end

% general (non-vectorized) case
fullm = add_evidence_to_gmarginal(fmarginal, evidence, ns, cnodes); % slow!

if dpsz == 1 % no discrete parents
  w = 1;
else
  w = fullm.T(:);
end

CPD.Wsum = CPD.Wsum + w;
xi = 1:cpsz;                 % rows of fullm.mu/Sigma belonging to X
yi = (cpsz+1):(cpsz+ss);     % rows of fullm.mu/Sigma belonging to Y
for i=1:dpsz
  muY = fullm.mu(yi, i);
  SYY = fullm.Sigma(yi, yi, i);
  CPD.WYsum(:,i) = CPD.WYsum(:,i) + w(i)*muY;
  CPD.WYYsum(:,:,i) = CPD.WYYsum(:,:,i) + w(i)*(SYY + muY*muY'); % E[X Y] = Cov[X,Y] + E[X] E[Y]
  if cpsz > 0
    muX = fullm.mu(xi, i);
    SXX = fullm.Sigma(xi, xi, i);
    SXY = fullm.Sigma(xi, yi, i);
    CPD.WXsum(:,i) = CPD.WXsum(:,i) + w(i)*muX;
    CPD.WXXsum(:,:,i) = CPD.WXXsum(:,:,i) + w(i)*(SXX + muX*muX');
    CPD.WXYsum(:,:,i) = CPD.WXYsum(:,:,i) + w(i)*(SXY + muX*muY');
  end
end
|
wolffd@0
|
85
|