function CPD = update_tied_ess(CPD, domain, engine, evidence, ns, cnodes)
% UPDATE_TIED_ESS Accumulate expected sufficient statistics for a Gaussian CPD
% whose parameters are shared (tied) by several families.
% CPD = update_tied_ess(CPD, domain, engine, evidence, ns, cnodes)
%
% Inputs:
%   CPD      - Gaussian CPD; its fields nsamples, Wsum, WYsum, WYYsum (and, when it
%              has cts parents, WXsum, WXYsum, WXXsum) are incremented.
%   domain   - matrix; each column domain(:,l) describes one family that uses this CPD.
%              NOTE(review): assumed to list the family with the child node last --
%              confirm against the caller.
%   engine   - inference engine queried with marginal_family (presumably after the
%              evidence has been entered -- confirm).
%   evidence - cell array; evidence{i} is the observed value of node i, [] if hidden.
%   ns       - ns(i) is the size of node i.
%   cnodes   - indices of the continuous nodes.
%
% Output:
%   CPD      - the CPD with every family's contribution added. Each family counts
%              as one (soft) sample.

if ~adjustable_CPD(CPD), return; end

nCPDs = size(domain, 2);
[ss cpsz dpsz] = size(CPD.weights); % ss = self size

% Let X be the cts parent (if any), Y be the cts child (self).
xi = 1:cpsz;
yi = (cpsz+1):(cpsz+ss);

for l=1:nCPDs
  % NOTE(review): uses the last entry of domain(:,l) as the child node whose
  % family marginal is requested -- confirm this matches the caller's layout.
  fmarginal = marginal_family(engine, domain(end,l));
  % Re-expand the marginal so it is indexed by every discrete configuration
  % and contains the observed cts values.
  fullm = add_evidence_to_marginal(fmarginal, evidence, ns, cnodes);

  if dpsz == 1 % no discrete parents
    w = 1;
  else
    w = fullm.T(:); % posterior weight of each discrete parent configuration
  end
  CPD.Wsum = CPD.Wsum + w;

  for i=1:dpsz
    muY = fullm.mu(yi, i);
    SYY = fullm.Sigma(yi, yi, i);
    CPD.WYsum(:,i) = CPD.WYsum(:,i) + w(i)*muY;
    CPD.WYYsum(:,:,i) = CPD.WYYsum(:,:,i) + w(i)*(SYY + muY*muY'); % E[Y Y'] = Cov[Y] + E[Y] E[Y]'
    if cpsz > 0
      muX = fullm.mu(xi, i);
      SXX = fullm.Sigma(xi, xi, i);
      SXY = fullm.Sigma(xi, yi, i);
      CPD.WXsum(:,i) = CPD.WXsum(:,i) + w(i)*muX;
      CPD.WXYsum(:,:,i) = CPD.WXYsum(:,:,i) + w(i)*(SXY + muX*muY'); % E[X Y'] = Cov[X,Y] + E[X] E[Y]'
      CPD.WXXsum(:,:,i) = CPD.WXXsum(:,:,i) + w(i)*(SXX + muX*muX');
    end
  end
end

CPD.nsamples = CPD.nsamples + nCPDs;

|
%%%%%%%%%%%%%

function fullm = add_evidence_to_marginal(fmarginal, evidence, ns, cnodes)
% ADD_EVIDENCE_TO_MARGINAL Expand a family marginal back to the full size of its nodes.
% fullm = add_evidence_to_marginal(fmarginal, evidence, ns, cnodes)
%
% This code treats fmarginal.T as having one entry per discrete configuration that is
% compatible with the discrete evidence. It treats fmarginal.mu / fmarginal.Sigma as
% covering only the hidden cts nodes of fmarginal.domain.
%
% Output fields:
%   fullm.domain       - copy of fmarginal.domain.
%   fullm.T(Q)         - weight of configuration Q of the family's discrete nodes;
%                        0 for configurations incompatible with the discrete evidence.
%   fullm.mu(:,Q)      - stacked means of the family's cts nodes; observed cts nodes
%                        get their observed value (0 where Q is unsupported).
%   fullm.Sigma(:,:,Q) - covariance of the family's cts nodes; rows/columns that
%                        involve observed cts nodes stay 0.

dom = fmarginal.domain;
fullm.domain = dom;

% Find out which values of the discrete parents (if any) are compatible with
% the discrete evidence (if any).
dnodes = mysetdiff(1:length(ns), cnodes);
ddom = myintersect(dom, dnodes);
cdom = myintersect(dom, cnodes);
odom = dom(~isemptycell(evidence(dom)));
hdom = dom(isemptycell(evidence(dom)));

dobs = myintersect(ddom, odom);
dvals = cat(1, evidence{dobs});
ens = ns; % effective node sizes: observed discrete nodes collapse to size 1
ens(dobs) = 1;
S = prod(ens(ddom));
subs = ind2subv(ens(ddom), 1:S);
mask = find_equiv_posns(dobs, ddom);
% Clamp the columns of the observed discrete nodes to their observed values in
% every row. A plain subs(mask) = dvals would index subs linearly and overwrite
% the leading entries of its first column instead.
subs(:,mask) = repmat(dvals(:)', size(subs,1), 1);
supportedQs = subv2ind(ns(ddom), subs);

if isempty(ddom)
  Qarity = 1;
else
  Qarity = prod(ns(ddom));
end
fullm.T = zeros(Qarity, 1);
fullm.T(supportedQs) = fmarginal.T(:);

% Now put the hidden cts parts into their right blocks. Observed cts nodes
% get their observed value as the mean, and their covariance entries stay 0.
cobs = myintersect(cdom, odom);
chid = myintersect(cdom, hdom);
cvals = cat(1, evidence{cobs});
n = sum(ns(cdom));
fullm.mu = zeros(n,Qarity);
fullm.Sigma = zeros(n,n,Qarity);

if ~isempty(chid)
  chid_blocks = block(find_equiv_posns(chid, cdom), ns(cdom));
end
if ~isempty(cobs)
  cobs_blocks = block(find_equiv_posns(cobs, cdom), ns(cdom));
end

for i=1:length(supportedQs)
  Q = supportedQs(i);
  if ~isempty(chid)
    fullm.mu(chid_blocks, Q) = fmarginal.mu(:, i);
    fullm.Sigma(chid_blocks, chid_blocks, Q) = fmarginal.Sigma(:,:,i);
  end
  if ~isempty(cobs)
    fullm.mu(cobs_blocks, Q) = cvals(:);
  end
end