Mercurial > hg > camir-aes2014
comparison toolboxes/FullBNT-1.0.7/bnt/CPDs/@mlp_CPD/update_ess.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:e9a9cd732c1e |
---|---|
1 function CPD = update_ess(CPD, fmarginal, evidence, ns, cnodes, hidden_bitv) | |
2 % UPDATE_ESS Update the Expected Sufficient Statistics of a CPD (MLP) | |
3 % CPD = update_ess(CPD, family_marginal, evidence, node_sizes, cnodes, hidden_bitv) | |
4 % | |
5 % fmarginal = overall posterior distribution of self and its parents | |
6 % fmarginal(i1,i2...,ik,s)=prob(Pa1=i1,...,Pak=ik, self=s| X) | |
7 % | |
8 % => 1) prob(self|Pa1,...,Pak)=fmarginal/prob(Pa1,...,Pak) with prob(Pa1,...,Pak)=sum{s,fmarginal} | |
9 % [self estimation -> CPD.self_vals] | |
10 % 2) prob(Pa1,...,Pak) [SCG weights -> CPD.eso_weights] | |
11 % | |
12 % hidden_bitv is accepted for interface compatibility but is never used in this function | |
13 | |
14 % Written by Pierpaolo Brutti | |
15 | |
16 if ~adjustable_CPD(CPD), return; end | |
17 | |
18 dom = fmarginal.domain; | |
19 cdom = myintersect(dom, cnodes); | |
20 assert(~any(isemptycell(evidence(cdom)))); | |
21 ns(cdom)=1; | |
22 | |
23 self = dom(end); | |
24 ps=dom(1:end-1); | |
25 dpdom=mysetdiff(ps,cdom); | |
26 | |
27 dnodes = mysetdiff(1:length(ns), cnodes); | |
28 | |
29 ddom = myintersect(ps, dnodes); % parents of self that are not in cnodes | |
30 if isempty(evidence{self}), % if self is hidden, in what follows we must | |
31 ddom = myintersect(dom, dnodes); % also consider its dimension | |
32 end % | |
33 | |
34 odom = dom(~isemptycell(evidence(dom))); | |
35 hdom = dom(isemptycell(evidence(dom))); % hidden nodes in domain (taken from dom, so may include self) | |
36 | |
37 dobs = myintersect(ddom, odom); | |
38 dvals = cat(1, evidence{dobs}); | |
39 ens = ns; % effective node sizes (observed nodes of ddom are set to size 1 below) | |
40 ens(dobs) = 1; | |
41 | |
42 dpsz=prod(ns(dpdom)); | |
43 S=prod(ens(ddom)); | |
44 subs = ind2subv(ens(ddom), 1:S); | |
45 mask = find_equiv_posns(dobs, ddom); | |
46 for i=1:length(mask), | |
47 subs(:,mask(i)) = dvals(i); | |
48 end | |
49 supportedQs = subv2ind(ns(ddom), subs); | |
50 | |
51 Qarity = prod(ns(ddom)); | |
52 if isempty(ddom), | |
53 Qarity = 1; | |
54 end | |
55 fullm.T = zeros(Qarity, 1); | |
56 fullm.T(supportedQs) = fmarginal.T(:); | |
57 | |
58 % For dynamic (recurrent) net------------------------------------------------------------- | |
59 % ---------------------------------------------------------------------------------------- | |
60 high=size(evidence,1); % slice height (number of rows of evidence) | |
61 ss_ns=ns(1:high); % node sizes of a single slice | |
62 pos=self; % start from self's index in the unrolled net | |
63 slice_num=0; % number of whole slices preceding self's slice | |
64 while pos>high, % | |
65 slice_num=slice_num+1; % find active slice | |
66 pos=pos-high; % pos = self's position within a single slice | |
67 end % | |
68 | |
69 last_dim=pos-1; % self observed: stop just before self | |
70 if isempty(evidence{self}), % | |
71 last_dim=pos; % self hidden: include self's own dimension | |
72 end % last_dim = last reshaping dimension | |
73 reg=dom-slice_num*high; | |
74 dex=myintersect(reg(find(reg>=0)), [1:last_dim]); % slice-relative indices of domain nodes in 1..last_dim | |
75 rs_dim=ss_ns(dex); % reshaping dimensions | |
76 | |
77 if slice_num>0, | |
78 act_slice=[]; past_ancest=[]; % | |
79 act_slice=slice_num*high+[1:high]; % recover the active slice nodes | |
80 % past_ancest=mysetdiff(ddom, act_slice); | |
81 past_ancest=mysetdiff(ps, act_slice); % parents of self that are not in the active slice | |
82 app=ns(past_ancest); | |
83 rs_dim=[app(:)' rs_dim(:)']; % sizes of those parents come first | |
84 end % | |
85 if length(rs_dim)==1, rs_dim=[1 rs_dim]; end % reshape needs at least two dimensions | |
86 if size(rs_dim,1)~=1, rs_dim=rs_dim'; end % force rs_dim to be a row vector | |
87 | |
88 fullm.T=reshape(fullm.T, rs_dim); % reshaping the marginal | |
89 | |
90 % ---------------------------------------------------------------------------------------- | |
91 % ---------------------------------------------------------------------------------------- | |
92 | |
93 % X = continuous (cts) parents, R = discrete self | |
94 | |
95 % 1) observations vector -> CPD.parent_vals -------------------------------------------------- | |
96 x = cat(1, evidence{cdom}); | |
97 | |
98 % 2) weights vector -> CPD.eso_weights ------------------------------------------------------- | |
99 if isempty(evidence{self}) % R is hidden | |
100 sum_over=length(rs_dim); | |
101 app=sum(fullm.T, sum_over); | |
102 pesi=reshape(app,[dpsz,1]); | |
103 clear app; | |
104 else | |
105 pesi=reshape(fullm.T,[dpsz,1]); | |
106 end | |
107 | |
108 assert(approxeq(sum(pesi),1)); | |
109 | |
110 % 3) estimate (if R is hidden) or recover (if R is observed) self's value -------------------- | |
111 if isempty(evidence{self}) % R is hidden | |
112 app=mk_stochastic(fullm.T); % P(self|Pa1,...,Pak)=fmarginal/prob(Pa1,...,Pak) | |
113 app=reshape(app,[dpsz ns(self)]); % matrix size: prod{j,ns(Paj)} x ns(self) | |
114 r=app; | |
115 clear app; | |
116 else | |
117 r = zeros(dpsz,ns(self)); | |
118 for i=1:dpsz | |
119 if pesi(i)~=0, r(i,evidence{self}) = 1; end | |
120 end | |
121 end | |
122 for i=1:dpsz | |
123 if pesi(i) ~=0, assert(approxeq(sum(r(i,:)),1)); end | |
124 end | |
125 | |
126 CPD.nsamples = CPD.nsamples + 1; | |
127 CPD.parent_vals(CPD.nsamples,:) = x(:)'; | |
128 for i=1:dpsz | |
129 CPD.eso_weights(CPD.nsamples,:,i)=pesi(i); | |
130 CPD.self_vals(CPD.nsamples,:,i) = r(i,:); | |
131 end |