To check out this repository please hg clone the following URL, or open the URL using EasyMercurial or your preferred Mercurial client.

Statistics Download as Zip
| Branch: | Revision:

root / _FullBNT / BNT / CPDs / @gaussian_CPD / Old / update_tied_ess.m @ 8:b5b38998ef3b

History | View | Annotate | Download (3.29 KB)

1
function CPD = update_tied_ess(CPD, domain, engine, evidence, ns, cnodes)
2

    
3
if ~adjustable_CPD(CPD), return; end
4
nCPDs = size(domain, 2);
5
fmarginal = cell(1, nCPDs);
6
for l=1:nCPDs
7
  fmarginal{l} = marginal_family(engine, nodes(l));
8
end
9

    
10
[ss cpsz dpsz] = size(CPD.weights);
11
if const_evidence_pattern(engine)
12
  dom = domain(:,1);
13
  dnodes = mysetdiff(1:length(ns), cnodes);
14
  ddom = myintersect(dom, dnodes);
15
  cdom = myintersect(dom, cnodes);
16
  odom = dom(~isemptycell(evidence(dom)));
17
  hdom = dom(isemptycell(evidence(dom)));
18
  % If all hidden nodes are discrete and all cts nodes are observed 
19
  % (e.g., HMM with Gaussian output)
20
  % we can add the observed evidence in parallel
21
  if mysubset(ddom, hdom) & mysubset(cdom, odom)
22
    [mu, Sigma, T] = add_cts_ev_to_marginals(fmarginal, evidence, ns, cnodes);
23
  else
24
    mu = zeros(ss, dpsz, nCPDs);
25
    Sigma = zeros(ss, ss, dpsz, nCPDs);
26
    T = zeros(dpsz, nCPDs);
27
    for l=1:nCPDs
28
      [mu(:,:,l), Sigma(:,:,:,l), T(:,l)] = add_ev_to_marginals(fmarginal{l}, evidence, ns, cnodes);
29
    end
30
  end
31
end
32
CPD.nsamples = CPD.nsamples + nCPDs;            
33

    
34

    
35
if dpsz == 1 % no discrete parents
36
  w = 1;
37
else
38
  w = fullm.T(:);
39
end
40
CPD.Wsum = CPD.Wsum + w;
41
% Let X be the cts parent (if any), Y be the cts child (self).
42
xi = 1:cpsz;
43
yi = (cpsz+1):(cpsz+ss);
44
for i=1:dpsz
45
  muY = fullm.mu(yi, i);
46
  SYY = fullm.Sigma(yi, yi, i);
47
  CPD.WYsum(:,i) = CPD.WYsum(:,i) + w(i)*muY;
48
  CPD.WYYsum(:,:,i) = CPD.WYYsum(:,:,i) + w(i)*(SYY + muY*muY'); % E[X Y] = Cov[X,Y] + E[X] E[Y]
49
  if cpsz > 0
50
    muX = fullm.mu(xi, i);
51
    SXX = fullm.Sigma(xi, xi, i);
52
    SXY = fullm.Sigma(xi, yi, i);
53
    CPD.WXsum(:,i) = CPD.WXsum(:,i) + w(i)*muX;
54
    CPD.WXYsum(:,:,i) = CPD.WXYsum(:,:,i) + w(i)*(SXY + muX*muY');
55
    CPD.WXXsum(:,:,i) = CPD.WXXsum(:,:,i) + w(i)*(SXX + muX*muX');
56
  end
57
end                
58

    
59

    
60
%%%%%%%%%%%%%
61

    
62
function fullm = add_evidence_to_marginal(fmarginal, evidence, ns, cnodes)
63

    
64

    
65
dom = fmarginal.domain;
66

    
67
% Find out which values of the discrete parents (if any) are compatible with 
68
% the discrete evidence (if any).
69
dnodes = mysetdiff(1:length(ns), cnodes);
70
ddom = myintersect(dom, dnodes);
71
cdom = myintersect(dom, cnodes);
72
odom = dom(~isemptycell(evidence(dom)));
73
hdom = dom(isemptycell(evidence(dom)));
74

    
75
dobs = myintersect(ddom, odom);
76
dvals = cat(1, evidence{dobs});
77
ens = ns; % effective node sizes
78
ens(dobs) = 1;
79
S = prod(ens(ddom));
80
subs = ind2subv(ens(ddom), 1:S);
81
mask = find_equiv_posns(dobs, ddom);
82
subs(mask) = dvals;
83
supportedQs = subv2ind(ns(ddom), subs);
84

    
85
if isempty(ddom)
86
  Qarity = 1;
87
else
88
  Qarity = prod(ns(ddom));
89
end
90
fullm.T = zeros(Qarity, 1);
91
fullm.T(supportedQs) = fmarginal.T(:);
92

    
93
% Now put the hidden cts parts into their right blocks,
94
% leaving the observed cts parts as 0.
95
cobs = myintersect(cdom, odom);
96
chid = myintersect(cdom, hdom);
97
cvals = cat(1, evidence{cobs});
98
n = sum(ns(cdom));
99
fullm.mu = zeros(n,Qarity);
100
fullm.Sigma = zeros(n,n,Qarity);
101

    
102
if ~isempty(chid)
103
  chid_blocks = block(find_equiv_posns(chid, cdom), ns(cdom));
104
end
105
if ~isempty(cobs)
106
  cobs_blocks = block(find_equiv_posns(cobs, cdom), ns(cdom));
107
end
108

    
109
for i=1:length(supportedQs)
110
  Q = supportedQs(i);
111
  if ~isempty(chid)
112
    fullm.mu(chid_blocks, Q) = fmarginal.mu(:, i);
113
    fullm.Sigma(chid_blocks, chid_blocks, Q) = fmarginal.Sigma(:,:,i);
114
  end
115
  if ~isempty(cobs)
116
    fullm.mu(cobs_blocks, Q) = cvals(:);
117
  end
118
end