To check out this repository, please hg clone the following URL, or open the URL using EasyMercurial or your preferred Mercurial client.

Statistics Download as Zip
| Branch: | Revision:

root / _FullBNT / BNT / CPDs / @tabular_CPD / tabular_CPD.m @ 8:b5b38998ef3b

History | View | Annotate | Download (5.45 KB)

1
function CPD = tabular_CPD(bnet, self, varargin)
% TABULAR_CPD Make a multinomial conditional prob. distrib. (CPT)
%
% CPD = tabular_CPD(bnet, node) creates a random CPT.
% CPD = tabular_CPD(bnet, node, T) is the old syntax for
%   tabular_CPD(bnet, node, 'CPT', T); it is used whenever the first
%   optional argument is not a string, and any further arguments are ignored.
%
% The following arguments can be specified [default in brackets]
%
% CPT - specifies the params ['rnd']
%   - T means use table T; it will be reshaped to the size of node's family.
%   - 'rnd' creates rnd params (drawn from uniform)
%   - 'unif' creates a uniform distribution
% adjustable - 0 means don't adjust the parameters during learning [1]
% clamped - logical inverse of adjustable (clamped = 1 sets adjustable = 0)
% prior_type - defines type of prior ['none']
%  - 'none' means do ML estimation
%  - 'dirichlet' means add pseudo-counts to every cell
%  - 'entropic' means use a prior P(theta) propto exp(-H(theta)) (see Brand)
% dirichlet_weight - equivalent sample size (ess) of the dirichlet prior [1]
% dirichlet_type - defines the type of Dirichlet prior ['BDeu']
%  - 'unif' means put dirichlet_weight in every cell
%  - 'BDeu' means we put 'dirichlet_weight/(r q)' in every cell
%    where r = self_sz and q = prod(parent_sz) (see Heckerman)
% trim - 1 means trim redundant params (rows in CPT) when using entropic prior [0]
% entropic_pcases - list of assignments to the parents nodes when we should use 
%      the entropic prior; all other cases will be estimated using ML [1:psz]
% sparse - 1 means use 1D sparse array to represent CPT [0]
%
% e.g., tabular_CPD(bnet, i, 'CPT', T)
% e.g., tabular_CPD(bnet, i, 'CPT', 'unif', 'dirichlet_weight', 2, 'dirichlet_type', 'unif')
%
% An error is raised for an unknown argument name, an argument name without
% a value, or an invalid 'CPT' string, prior_type or dirichlet_type.
%
% REFERENCES
% M. Brand - "Structure learning in conditional probability models via an entropic  prior
%   and parameter extinction", Neural Computation 11 (1999): 1155--1182
% M. Brand - "Pattern discovery via entropy minimization" [covers annealing]
%   AI & Statistics 1999. Equation numbers refer to this paper, which is available from
%   www.merl.com/reports/docs/TR98-21.pdf
% D. Heckerman, D. Geiger and M. Chickering, 
%   "Learning Bayesian networks: the combination of knowledge and statistical data",
%   Microsoft Research Tech Report, 1994


if nargin==0
  % This occurs if we are trying to load an object from a file.
  CPD = init_fields;
  CPD = class(CPD, 'tabular_CPD', discrete_CPD(0, []));
  return;
elseif isa(bnet, 'tabular_CPD')
  % This might occur if we are copying an object.
  CPD = bnet;
  return;
end
% Start from the canonical field template so the field order is always the same.
CPD = init_fields;

ns = bnet.node_sizes;
ps = parents(bnet.dag, self);
fam_sz = ns([ps self]);   % sizes of the parents followed by the node itself
psz = prod(ns(ps));       % number of joint parent configurations (1 if no parents)
CPD.sizes = fam_sz;
CPD.leftright = 0;
CPD.sparse = 0;

% set defaults
% (the random default is drawn even if a CPT is supplied below)
CPD.CPT = mk_stochastic(myrand(fam_sz));
CPD.adjustable = 1;
CPD.prior_type = 'none';
dirichlet_type = 'BDeu';
dirichlet_weight = 1;
CPD.trim = 0;
CPD.entropic_pcases = 1:psz;

% extract optional args
args = varargin;
% check for old syntax CPD(bnet, i, CPT) as opposed to CPD(bnet, i, 'CPT', CPT)
if ~isempty(args) && ~ischar(args{1})
  CPD.CPT = myreshape(args{1}, fam_sz);
  args = {};
end

% Every option name must be followed by a value; fail with a clear message
% instead of an index-out-of-bounds error inside the loop.
if mod(length(args), 2) ~= 0
  error('optional arguments must be specified as name/value pairs');
end

for i=1:2:length(args)
  name = args{i};
  value = args{i+1};
  switch name,
   case 'CPT',
    if ischar(value)
      switch value
       case 'unif', CPD.CPT = mk_stochastic(myones(fam_sz));
       case 'rnd',  CPD.CPT = mk_stochastic(myrand(fam_sz));
       otherwise,   error(['invalid CPT ' value]);
      end
    else
      CPD.CPT = myreshape(value, fam_sz);
    end
   case 'prior_type', CPD.prior_type = value;
   case 'dirichlet_type', dirichlet_type = value;
   case 'dirichlet_weight', dirichlet_weight = value;
   case 'adjustable', CPD.adjustable = value;
   case 'clamped', CPD.adjustable = ~value;
   case 'trim', CPD.trim = value;
   case 'entropic_pcases', CPD.entropic_pcases = value;
   case 'sparse', CPD.sparse = value;
   otherwise, error(['invalid argument name: ' name]);
  end
end

% Build the pseudo-count table only when a Dirichlet prior was requested.
switch CPD.prior_type
 case 'dirichlet',
  switch dirichlet_type
   case 'unif', CPD.dirichlet = dirichlet_weight * myones(fam_sz);
   case 'BDeu', CPD.dirichlet = (dirichlet_weight/psz) * mk_stochastic(myones(fam_sz));
   otherwise, error(['invalid dirichlet_type ' dirichlet_type])
  end
 case {'entropic', 'none'}
  CPD.dirichlet = [];
 otherwise, error(['invalid prior_type ' CPD.prior_type])
end

% fields to do with learning
if ~CPD.adjustable
  CPD.counts = [];
  CPD.nparams = 0;
  CPD.nsamples = [];
else
  % counts are stored as a column vector with one entry per CPT cell
  CPD.counts = zeros(numel(CPD.CPT), 1);
  self_sz = fam_sz(end);
  if CPD.leftright
    % For each of the psz parent contexts, we specify self_sz elements on
    % the diagonal.
    % NOTE(review): leftright is always 0 in this constructor, so this
    % branch is currently unreachable; confirm the formula before enabling.
    CPD.nparams = psz * self_sz;
  else
    % sum-to-1 constraint reduces the effective arity of the node by 1
    CPD.nparams = psz * (self_sz - 1);
  end
  CPD.nsamples = 0;
end

CPD.trimmed_trans = [];

%sparse CPT
if CPD.sparse
   CPD.CPT = sparse(CPD.CPT(:));
end

CPD = class(CPD, 'tabular_CPD', discrete_CPD(~CPD.adjustable, fam_sz));
151

    
152

    
153
%%%%%%%%%%%
154

    
155
function CPD = init_fields()
% INIT_FIELDS Return a 1x1 struct holding every tabular_CPD field, all empty.
% The fields are always created in this one fixed order, whether the object
% is being loaded from a file or built from scratch. (Matlab requires this.)

CPD = struct( ...
    'CPT',             [], ...
    'sizes',           [], ...
    'prior_type',      [], ...
    'dirichlet',       [], ...
    'adjustable',      [], ...
    'counts',          [], ...
    'nparams',         [], ...
    'nsamples',        [], ...
    'trim',            [], ...
    'trimmed_trans',   [], ...
    'leftright',       [], ...
    'entropic_pcases', [], ...
    'sparse',          []);
173