To check out this repository please hg clone the following URL, or open the URL using EasyMercurial or your preferred Mercurial client.

Statistics Download as Zip
| Branch: | Revision:

root / _FullBNT / BNT / CPDs / @gaussian_CPD / update_ess.m @ 8:b5b38998ef3b

History | View | Annotate | Download (3.06 KB)

1
function CPD = update_ess(CPD, fmarginal, evidence, ns, cnodes, hidden_bitv)
2
% UPDATE_ESS Update the Expected Sufficient Statistics of a Gaussian node
3
% function CPD = update_ess(CPD, fmarginal, evidence, ns, cnodes, hidden_bitv)
4

    
5
%if nargin < 6
6
%  hidden_bitv = zeros(1, max(fmarginal.domain));
7
%  hidden_bitv(find(isempty(evidence)))=1;
8
%end
9

    
10
dom = fmarginal.domain;
11
self = dom(end);
12
ps = dom(1:end-1);
13
cps = myintersect(ps, cnodes);
14
dps = mysetdiff(ps, cps);
15

    
16
CPD.nsamples = CPD.nsamples + 1;            
17
[ss cpsz dpsz] = size(CPD.weights); % ss = self size
18
[ss dpsz] = size(CPD.mean);
19

    
20
% Let X be the cts parent (if any), Y be the cts child (self).
21

    
22
if ~hidden_bitv(self) & ~any(hidden_bitv(cps)) & all(hidden_bitv(dps))
23
  % Speedup for the common case that all cts nodes are observed, all discrete nodes are hidden
24
  % Since X and Y are observed, SYY = 0, SXX = 0, SXY = 0
25
  % Since discrete parents are hidden, we do not need to add evidence to w.
26
  w = fmarginal.T(:);
27
  CPD.Wsum = CPD.Wsum + w;
28
  y = evidence{self};
29
  Cyy = y*y';
30
  if ~CPD.useC
31
     WY = repmat(w(:)',ss,1); % WY(y,i) = w(i)
32
     WYY = repmat(reshape(WY, [ss 1 dpsz]), [1 ss 1]); % WYY(y,y',i) = w(i)
33
     %CPD.WYsum = CPD.WYsum +  WY .* repmat(y(:), 1, dpsz);
34
     CPD.WYsum = CPD.WYsum +  y(:) * w(:)';
35
     CPD.WYYsum = CPD.WYYsum + WYY  .* repmat(reshape(Cyy, [ss ss 1]), [1 1 dpsz]);
36
  else
37
     W = w(:)';
38
     W2 = reshape(W, [1 1 dpsz]);
39
     CPD.WYsum = CPD.WYsum +  rep_mult(W, y(:), size(CPD.WYsum)); 
40
     CPD.WYYsum = CPD.WYYsum + rep_mult(W2, Cyy, size(CPD.WYYsum));
41
  end
42
  if cpsz > 0 % X exists
43
    x = cat(1, evidence{cps}); x = x(:);
44
    Cxx = x*x';
45
    Cxy = x*y';
46
    WX = repmat(w(:)',cpsz,1); % WX(x,i) = w(i)
47
    WXX = repmat(reshape(WX, [cpsz 1 dpsz]), [1 cpsz 1]); % WXX(x,x',i) = w(i)
48
    WXY = repmat(reshape(WX, [cpsz 1 dpsz]), [1 ss 1]); % WXY(x,y,i) = w(i)
49
    if ~CPD.useC
50
      CPD.WXsum = CPD.WXsum + WX .* repmat(x(:), 1, dpsz);
51
      CPD.WXXsum = CPD.WXXsum + WXX .* repmat(reshape(Cxx, [cpsz cpsz 1]), [1 1 dpsz]);
52
      CPD.WXYsum = CPD.WXYsum + WXY .* repmat(reshape(Cxy, [cpsz ss 1]), [1 1 dpsz]);
53
    else
54
      CPD.WXsum = CPD.WXsum + rep_mult(W, x(:), size(CPD.WXsum));
55
      CPD.WXXsum = CPD.WXXsum + rep_mult(W2, Cxx, size(CPD.WXXsum));
56
      CPD.WXYsum = CPD.WXYsum + rep_mult(W2, Cxy, size(CPD.WXYsum));
57
    end
58
  end
59
  return;
60
end
61

    
62
% general (non-vectorized) case
63
fullm = add_evidence_to_gmarginal(fmarginal, evidence, ns, cnodes); % slow!
64

    
65
if dpsz == 1 % no discrete parents
66
  w = 1;
67
else
68
  w = fullm.T(:);
69
end
70

    
71
CPD.Wsum = CPD.Wsum + w;
72
xi = 1:cpsz;
73
yi = (cpsz+1):(cpsz+ss);
74
for i=1:dpsz
75
  muY = fullm.mu(yi, i);
76
  SYY = fullm.Sigma(yi, yi, i);
77
  CPD.WYsum(:,i) = CPD.WYsum(:,i) + w(i)*muY;
78
  CPD.WYYsum(:,:,i) = CPD.WYYsum(:,:,i) + w(i)*(SYY + muY*muY'); % E[X Y] = Cov[X,Y] + E[X] E[Y]
79
  if cpsz > 0
80
    muX = fullm.mu(xi, i);
81
    SXX = fullm.Sigma(xi, xi, i);
82
    SXY = fullm.Sigma(xi, yi, i);
83
    CPD.WXsum(:,i) = CPD.WXsum(:,i) + w(i)*muX;
84
    CPD.WXXsum(:,:,i) = CPD.WXXsum(:,:,i) + w(i)*(SXX + muX*muX');
85
    CPD.WXYsum(:,:,i) = CPD.WXYsum(:,:,i) + w(i)*(SXY + muX*muY');
86
  end
87
end                
88