To check out this repository please hg clone the following URL, or open the URL using EasyMercurial or your preferred Mercurial client.

Statistics Download as Zip
| Branch: | Revision:

root / _FullBNT / BNT / CPDs / @softmax_CPD / update_ess.m @ 8:b5b38998ef3b

History | View | Annotate | Download (4.05 KB)

1
function CPD = update_ess(CPD, fmarginal, evidence, ns, cnodes, hidden_bitv)
2
% UPDATE_ESS Update the Expected Sufficient Statistics of a softmax node
3
% function CPD = update_ess(CPD, fmarginal, evidence, ns, cnodes, hidden_bitv)
4
%
5
% fmarginal = overall posterior distribution of self and its parents
6
% fmarginal(i1,i2...,ik,s)=prob(Pa1=i1,...,Pak=ik, self=s| X)
7
% 
8
% => 1) prob(self|Pa1,...,Pak)=fmarginal/prob(Pa1,...,Pak) with prob(Pa1,...,Pak)=sum{s,fmarginal}
9
%       [self estimation -> CPD.self_vals]
10
% 	  2) prob(Pa1,...,Pak) [WIRLS weights -> CPD.eso_weights]
11
%
12
% Hidden_bitv is ignored
13

    
14
% Written by Pierpaolo Brutti
15

    
16
if ~adjustable_CPD(CPD), return; end
17

    
18
domain     = fmarginal.domain;                              
19
self       = domain(end);          
20
ps         = domain(1:end-1);                                     
21
cnodes     = domain(CPD.cpndx);
22
cps        = myintersect(domain, cnodes);                     
23
dps        = mysetdiff(ps, cps);                            
24
dn_use     = dps;
25
if isempty(evidence{self}) dn_use = [dn_use self]; end % if self is hidden we must consider its dimension  
26
dps_as_cps = domain(CPD.dps_as_cps.ndx);
27
odom       = domain(~isemptycell(evidence(domain))); 
28

    
29
ns = zeros(1, max(domain));
30
ns(domain) = CPD.sizes;     % CPD.sizes = bnet.node_sizes([ps self]);
31
ens = ns;                   % effective node sizes
32
ens(odom) = 1;              
33
dpsize = prod(ns(dps));
34

    
35
% Extract the params compatible with the observations (if any) on the discrete parents (if any)
36
dops = myintersect(dps, odom);
37
dpvals = cat(1, evidence{dops});
38

    
39
subs = ind2subv(ens(dn_use), 1:prod(ens(dn_use)));
40
dpmap = find_equiv_posns(dops, dn_use);
41
if ~isempty(dpmap), subs(:,dpmap) = subs(:,dpmap)+repmat(dpvals(:)',[size(subs,1) 1])-1; end
42
supportedQs = subv2ind(ns(dn_use), subs); subs=subs(1:prod(ens(dps)),1:length(dps));
43
Qarity = prod(ns(dn_use));
44
if isempty(dn_use), Qarity = 1; end   
45

    
46
fullm.T              = zeros(Qarity, 1);
47
fullm.T(supportedQs) = fmarginal.T(:);
48
rs_dim = CPD.sizes;    rs_dim(CPD.cpndx) = 1;           %
49
if ~isempty(evidence{self}), rs_dim(end)=1; end         % reshaping the marginal
50
fullm.T              = reshape(fullm.T, rs_dim);        %
51

    
52
% --------------------------------------------------------------------------------UPDATE--
53

    
54
CPD.nsamples = CPD.nsamples + 1;
55

    
56
% 1) observations vector -> CPD.parents_vals ---------------------------------------------
57
cpvals = cat(1, evidence{cps});
58

    
59
if ~isempty(dps_as_cps),   % ...get in the dp_as_cp parents... 
60
    separator          = CPD.dps_as_cps.separator;
61
    dp_as_cpmap        = find_equiv_posns(dps_as_cps, dps);       
62
    for i=1:dpsize,
63
        dp_as_cpvals=zeros(1,sum(ns(dps_as_cps)));
64
        possible_vals = ind2subv(ns(dps),i);
65
        ll=find(ismember(subs(:,dp_as_cpmap), possible_vals(dp_as_cpmap), 'rows')==1);   
66
        if ~isempty(ll),
67
            where_one = separator + possible_vals(dp_as_cpmap);
68
            dp_as_cpvals(where_one)=1;                            
69
        end
70
        CPD.parent_vals(CPD.nsamples,:,i) = [dp_as_cpvals(:); cpvals(:)]';
71
    end
72
else
73
    CPD.parent_vals(CPD.nsamples,:) = cpvals(:)';
74
end
75

    
76
% 2) weights vector -> CPD.eso_weights ----------------------------------------------------
77
if isempty(evidence{self}),             % self is hidden
78
    pesi=reshape(sum(fullm.T, length(rs_dim)),[dpsize,1]);
79
else
80
    pesi=reshape(fullm.T,[dpsize,1]);
81
end
82
assert(approxeq(sum(pesi),1));          % check
83

    
84
% 3) estimate (if R is hidden) or recover (if R is obs) self'value-------------------------
85
if isempty(evidence{self})                                  % P(self|Pa1,...,Pak)=fmarginal/prob(Pa1,...,Pak)
86
    r=reshape(mk_stochastic(fullm.T), [dpsize ns(self)]);   % matrix size: prod{j,ns(Paj)} x ns(self)      
87
else
88
    r = zeros(dpsize,ns(self));
89
    for i=1:dpsize, if pesi(i)~=0, r(i,evidence{self}) = 1; end; end
90
end
91
for i=1:dpsize, if pesi(i)~=0, assert(approxeq(sum(r(i,:)),1)); end; end     % check
92

    
93
% 4) save the previous values --------------------------------------------------------------
94
for i=1:dpsize
95
    CPD.eso_weights(CPD.nsamples,:,i)=pesi(i);
96
    CPD.self_vals(CPD.nsamples,:,i) = r(i,:); 
97
end