To check out this repository please hg clone the following URL, or open the URL using EasyMercurial or your preferred Mercurial client.

Statistics Download as Zip
| Branch: | Revision:

root / _FullBNT / BNT / CPDs / @tabular_CPD / bayes_update_params.m @ 8:b5b38998ef3b

History | View | Annotate | Download (1.85 KB)

1
function CPD = bayes_update_params(CPD, self_ev, pev)
2
% UPDATE_PARAMS_COMPLETE Bayesian parameter updating given completely observed data (tabular)
3
% CPD = update_params_complete(CPD, self_ev, pev)
4
%
5
% self_ev(m) is the evidence on this node in case m.
6
% pev(i,m) is the evidence on the i'th parent in case m (if there are any parents).
7
% These can be arrays or cell arrays.
8
%
9
% We update the Dirichlet pseudo counts and set the CPT to the mean of the posterior.
10

    
11
if iscell(self_ev), usecell = 1; else usecell = 0; end
12

    
13
ncases = length(self_ev);
14
sz = CPD.sizes;
15
nparents = length(sz)-1;
16
assert(nparents == size(pev,1));
17

    
18
if ncases == 0 | ~adjustable_CPD(CPD)
19
  return;
20
elseif ncases == 1 % speedup the sequential learning case by avoiding normalization of the whole array
21
  if usecell
22
    x = cat(1, pev{:})';
23
    y = self_ev{1};
24
  else
25
    x = pev(:)';
26
    y = self_ev;
27
  end
28
  switch nparents
29
   case 0,
30
    CPD.dirichlet(y) = CPD.dirichlet(y)+1;
31
    CPD.CPT = CPD.dirichlet / sum(CPD.dirichlet);
32
   case 1,
33
    CPD.dirichlet(x(1), y) = CPD.dirichlet(x(1), y)+1;
34
    CPD.CPT(x(1), :) = CPD.dirichlet(x(1), :) ./ sum(CPD.dirichlet(x(1), :));
35
   case 2,
36
    CPD.dirichlet(x(1), x(2), y) = CPD.dirichlet(x(1), x(2), y)+1;
37
    CPD.CPT(x(1), x(2), :) = CPD.dirichlet(x(1), x(2), :) ./ sum(CPD.dirichlet(x(1), x(2), :));
38
   case 3,
39
    CPD.dirichlet(x(1), x(2), x(3), y) = CPD.dirichlet(x(1), x(2), x(3), y)+1;
40
    CPD.CPT(x(1), x(2), x(3), :) = CPD.dirichlet(x(1), x(2), x(3), :) ./ sum(CPD.dirichlet(x(1), x(2), x(3), :));
41
   otherwise,
42
    ind = subv2ind(sz, [x y]);
43
    CPD.dirichlet(ind) = CPD.dirichlet(ind) + 1;
44
    CPD.CPT = mk_stochastic(CPD.dirichlet);
45
  end
46
else  
47
  if usecell
48
    data = [cell2num(pev); cell2num(self_ev)]; 
49
  else
50
    data = [pev; self_ev];
51
  end
52
  counts = compute_counts(data, sz);
53
  CPD.dirichlet = CPD.dirichlet + counts;
54
  CPD.CPT = mk_stochastic(CPD.dirichlet);
55
end