To check out this repository please hg clone the following URL, or open the URL using EasyMercurial or your preferred Mercurial client.

Statistics Download as Zip
| Branch: | Revision:

root / _FullBNT / BNT / CPDs / @mlp_CPD / update_ess.m @ 8:b5b38998ef3b

History | View | Annotate | Download (5.29 KB)

1
function CPD = update_ess(CPD, fmarginal, evidence, ns, cnodes, hidden_bitv)
2
% UPDATE_ESS Update the Expected Sufficient Statistics of a CPD (MLP)
3
% CPD = update_ess(CPD, family_marginal, evidence, node_sizes, cnodes, hidden_bitv)
4
%
5
% fmarginal = overall posterior distribution of self and its parents
6
% fmarginal(i1,i2...,ik,s)=prob(Pa1=i1,...,Pak=ik, self=s| X)
7
% 
8
% => 1) prob(self|Pa1,...,Pak)=fmarginal/prob(Pa1,...,Pak) with prob(Pa1,...,Pak)=sum{s,fmarginal}
9
%       [self estimation -> CPD.self_vals]
10
% 	  2) prob(Pa1,...,Pak) [SCG weights -> CPD.eso_weights]
11
%
12
% Hidden_bitv is ignored
13

    
14
% Written by Pierpaolo Brutti
15

    
16
if ~adjustable_CPD(CPD), return; end
17

    
18
dom = fmarginal.domain;                              
19
cdom = myintersect(dom, cnodes);                     
20
assert(~any(isemptycell(evidence(cdom))));           
21
ns(cdom)=1;
22

    
23
self = dom(end);                                  
24
ps=dom(1:end-1);                                     
25
dpdom=mysetdiff(ps,cdom);                            
26

    
27
dnodes = mysetdiff(1:length(ns), cnodes);            
28

    
29
ddom = myintersect(ps, dnodes);                      %
30
if isempty(evidence{self}),                          % if self is hidden in what follow we must 
31
    ddom = myintersect(dom, dnodes);                 % consider its dimension
32
end                                                  % 
33

    
34
odom = dom(~isemptycell(evidence(dom)));    
35
hdom = dom(isemptycell(evidence(dom)));              % hidden parents in domain
36
 
37
dobs = myintersect(ddom, odom);             
38
dvals = cat(1, evidence{dobs});             
39
ens = ns;                                            % effective node sizes              
40
ens(dobs) = 1;                              
41
                                            
42
dpsz=prod(ns(dpdom));
43
S=prod(ens(ddom));
44
subs = ind2subv(ens(ddom), 1:S);
45
mask = find_equiv_posns(dobs, ddom);
46
for i=1:length(mask),
47
    subs(:,mask(i)) = dvals(i);
48
end
49
supportedQs = subv2ind(ns(ddom), subs);
50

    
51
Qarity = prod(ns(ddom));
52
if isempty(ddom),                      
53
  Qarity = 1;                         
54
end                                
55
fullm.T = zeros(Qarity, 1);
56
fullm.T(supportedQs) = fmarginal.T(:);
57

    
58
% For dynamic (recurrent) net-------------------------------------------------------------
59
% ----------------------------------------------------------------------------------------
60
high=size(evidence,1);                                  % slice height
61
ss_ns=ns(1:high);                                       % single slice nodes sizes
62
pos=self;                                               %
63
slice_num=0;                                            %
64
while pos>high,                                         % 
65
    slice_num=slice_num+1;                              % find active slice
66
    pos=pos-high;                                       % pos=self posistion into a single slice
67
end                                                     %
68

    
69
last_dim=pos-1;                                         % 
70
if isempty(evidence{self}),                             % 
71
    last_dim=pos;                                       %
72
end                                                     % last_dim=last reshaping dimension      
73
reg=dom-slice_num*high;
74
dex=myintersect(reg(find(reg>=0)), [1:last_dim]);       %           
75
rs_dim=ss_ns(dex);                                      % reshaping dimensions
76

    
77
if slice_num>0,
78
    act_slice=[]; past_ancest=[];                       %
79
    act_slice=slice_num*high+[1:high];                  % recover the active slice nodes
80
    % past_ancest=mysetdiff(ddom, act_slice);
81
    past_ancest=mysetdiff(ps, act_slice);               % recover ancestors contained into past slices
82
    app=ns(past_ancest);
83
    rs_dim=[app(:)' rs_dim(:)'];                        %
84
end                                                     %
85
if length(rs_dim)==1, rs_dim=[1 rs_dim]; end            %
86
if size(rs_dim,1)~=1, rs_dim=rs_dim';    end            %
87

    
88
fullm.T=reshape(fullm.T, rs_dim);                       % reshaping the marginal
89

    
90
% ----------------------------------------------------------------------------------------
91
% ----------------------------------------------------------------------------------------
92

    
93
% X = cts parent, R = discrete self
94

    
95
% 1) observations vector -> CPD.parents_vals -------------------------------------------------
96
x = cat(1, evidence{cdom});
97

    
98
% 2) weights vector -> CPD.eso_weights -------------------------------------------------------
99
if isempty(evidence{self}) % R is hidden
100
    sum_over=length(rs_dim);
101
    app=sum(fullm.T, sum_over);    
102
    pesi=reshape(app,[dpsz,1]);
103
    clear app;
104
else
105
    pesi=reshape(fullm.T,[dpsz,1]);
106
end
107

    
108
assert(approxeq(sum(pesi),1));
109

    
110
% 3) estimate (if R is hidden) or recover (if R is obs) self'value----------------------------
111
if isempty(evidence{self})              % R is hidden    
112
    app=mk_stochastic(fullm.T);         % P(self|Pa1,...,Pak)=fmarginal/prob(Pa1,...,Pak)
113
    app=reshape(app,[dpsz ns(self)]);   % matrix size: prod{j,ns(Paj)} x ns(self)      
114
    r=app;
115
    clear app;
116
else
117
    r = zeros(dpsz,ns(self));
118
    for i=1:dpsz
119
        if pesi(i)~=0, r(i,evidence{self}) = 1; end
120
    end
121
end
122
for i=1:dpsz
123
    if pesi(i) ~=0, assert(approxeq(sum(r(i,:)),1)); end
124
end
125

    
126
CPD.nsamples = CPD.nsamples + 1;            
127
CPD.parent_vals(CPD.nsamples,:) = x(:)';
128
for i=1:dpsz
129
    CPD.eso_weights(CPD.nsamples,:,i)=pesi(i);
130
    CPD.self_vals(CPD.nsamples,:,i) = r(i,:); 
131
end