To check out this repository please hg clone the following URL, or open the URL using EasyMercurial or your preferred Mercurial client.

Statistics Download as Zip
| Branch: | Revision:

root / _FullBNT / BNT / CPDs / @tabular_CPD / Old / tabular_CPD.m @ 8:b5b38998ef3b

History | View | Annotate | Download (5.83 KB)

1
function CPD = tabular_CPD(bnet, self, varargin)
% TABULAR_CPD Make a multinomial conditional prob. distrib. (CPT)
%
% CPD = tabular_CPD(bnet, node) creates a random CPT.
% CPD = tabular_CPD(bnet, node, T) is the old syntax for
%   tabular_CPD(bnet, node, 'CPT', T); any further arguments are ignored.
% CPD = tabular_CPD (no arguments) returns an empty object (used when
%   loading an object from a file).
% CPD = tabular_CPD(obj), where obj is already a tabular_CPD, returns obj.
%
% The following arguments can be specified as name/value pairs
% [default in brackets]
%
% CPT - specifies the params ['rnd']
%   - T means use table T; it will be reshaped to the size of node's family.
%   - 'rnd' creates rnd params (drawn from uniform)
%   - 'unif' creates a uniform distribution
%   - 'leftright' only transitions from i to i/i+1 are allowed, for each
%       non-self parent context.
%       The non-self parents are all parents except old_self.
% selfprob - The prob of transition from i to i if CPT = 'leftright' [0.1]
% old_self - id of the node corresponding to self in the previous slice
%       [self - bnet.nnodes_per_slice]
% adjustable - 0 means don't adjust the parameters during learning [1]
% clamped - logical inverse of adjustable: 1 means don't adjust the
%       parameters during learning [0]
% prior_type - defines type of prior ['none']
%  - 'none' means do ML estimation
%  - 'dirichlet' means add pseudo-counts to every cell
%  - 'entropic' means use a prior P(theta) propto exp(-H(theta)) (see Brand)
% dirichlet_weight - equivalent sample size (ess) of the dirichlet prior [1]
% dirichlet_type - defines the type of Dirichlet prior ['BDeu']
%  - 'unif' means put dirichlet_weight in every cell
%  - 'BDeu' means we put 'dirichlet_weight/(r q)' in every cell
%    where r = self_sz and q = prod(parent_sz) (see Heckerman)
% trim - 1 means trim redundant params (rows in CPT) when using entropic prior [0]
%
% e.g., tabular_CPD(bnet, i, 'CPT', T)
% e.g., tabular_CPD(bnet, i, 'CPT', 'unif', 'dirichlet_weight', 2, 'dirichlet_type', 'unif')
%
% REFERENCES
% M. Brand - "Structure learning in conditional probability models via an entropic prior
%   and parameter extinction", Neural Computation 11 (1999): 1155--1182
% M. Brand - "Pattern discovery via entropy minimization" [covers annealing]
%   AI & Statistics 1999. Equation numbers refer to this paper, which is available from
%   www.merl.com/reports/docs/TR98-21.pdf
% D. Heckerman, D. Geiger and M. Chickering,
%   "Learning Bayesian networks: the combination of knowledge and statistical data",
%   Microsoft Research Tech Report, 1994

if nargin==0
  % This occurs if we are trying to load an object from a file.
  CPD = init_fields;
  CPD = class(CPD, 'tabular_CPD', discrete_CPD(0, []));
  return;
elseif isa(bnet, 'tabular_CPD')
  % This might occur if we are copying an object.
  CPD = bnet;
  return;
end

% Create every field up front so the field order is always the same.
CPD = init_fields;

ns = bnet.node_sizes;
ps = parents(bnet.dag, self);
fam_sz = ns([ps self]);          % sizes of the parents followed by the node itself
parent_sz = fam_sz(1:end-1);     % empty for a node without parents
self_sz = fam_sz(end);
CPD.sizes = fam_sz;
CPD.leftright = 0;

% set defaults
CPD.CPT = mk_stochastic(myrand(fam_sz));
CPD.adjustable = 1;
CPD.prior_type = 'none';
dirichlet_type = 'BDeu';
dirichlet_weight = 1;
CPD.trim = 0;
selfprob = 0.1;

% extract optional args
args = varargin;
% check for old syntax CPD(bnet, i, CPT) as opposed to CPD(bnet, i, 'CPT', CPT)
if ~isempty(args) && ~ischar(args{1})
  CPD.CPT = myreshape(args{1}, fam_sz);
  args = {};
end

% Fail early with a readable message instead of an index-out-of-range
% error when a name is given without its value.
if mod(length(args), 2) ~= 0
  error('optional arguments must be specified as name/value pairs');
end

% old_self and selfprob are consumed while the 'leftright' CPT is built,
% so read them first: they may appear after 'CPT' in the argument list.
old_self = [];
for i=1:2:length(args)
  switch args{i},
   case 'old_self', old_self = args{i+1};
   case 'selfprob', selfprob = args{i+1};
  end
end

for i=1:2:length(args)
  switch args{i},
   case 'CPT',
    T = args{i+1};
    if ischar(T)
      switch T
       case 'unif', CPD.CPT = mk_stochastic(myones(fam_sz));
       case 'rnd',  CPD.CPT = mk_stochastic(myrand(fam_sz));
       case 'leftright',
        % we just initialise the CPT to leftright - this structure will
        % be maintained by EM, assuming we don't use a prior...
        CPD.leftright = 1;
        if isempty(old_self) % we assume the network is a DBN
          old_self = self - bnet.nnodes_per_slice;
        end
        other_ps = mysetdiff(ps, old_self);
        Qps = prod(ns(other_ps));   % number of non-self parent contexts
        Q = ns(self);
        LR = mk_leftright_transmat(Q, selfprob);
        % replicate the Q x Q transition matrix once per parent context
        transprob = repmat(reshape(LR, [1 Q Q]), [Qps 1 1]); % transprob(k,i,j)
        transprob = permute(transprob, [2 1 3]); % now transprob(i,k,j)
        CPD.CPT = myreshape(transprob, fam_sz);
       otherwise,   error(['invalid CPT ' T]);
      end
    else
      CPD.CPT = myreshape(T, fam_sz);
    end

   case 'prior_type', CPD.prior_type = args{i+1};
   case 'dirichlet_type', dirichlet_type = args{i+1};
   case 'dirichlet_weight', dirichlet_weight = args{i+1};
   case 'adjustable', CPD.adjustable = args{i+1};
   case 'clamped', CPD.adjustable = ~args{i+1};
   case 'trim', CPD.trim = args{i+1};
   case {'old_self', 'selfprob'},
    % already read in by the first pass above
   otherwise, error(['invalid argument name: ' args{i}]);
  end
end

switch CPD.prior_type
 case 'dirichlet',
  switch dirichlet_type
   case 'unif', CPD.dirichlet = dirichlet_weight * myones(fam_sz);
   case 'BDeu',
    % BDeu spreads the equivalent sample size evenly over all r*q cells.
    % NOTE(review): assumes mk_stochastic normalises over the last (self)
    % dimension, so that it contributes the 1/r factor - confirm.
    q = prod(parent_sz);
    CPD.dirichlet = (dirichlet_weight / q) * mk_stochastic(myones(fam_sz));
   otherwise, error(['invalid dirichlet_type ' dirichlet_type])
  end
 case {'entropic', 'none'}
  CPD.dirichlet = [];
 otherwise, error(['invalid prior_type ' CPD.prior_type])
end

% fields to do with learning
if ~CPD.adjustable
  CPD.counts = [];
  CPD.nparams = 0;
  CPD.nsamples = [];
else
  CPD.counts = zeros(size(CPD.CPT));
  if CPD.leftright
    % For each of the Qps contexts, we specify Q elements on the diagonal
    CPD.nparams = Qps * Q;
  else
    % sum-to-1 constraint reduces the effective arity of the node by 1
    CPD.nparams = prod([parent_sz self_sz-1]);
  end
  CPD.nsamples = 0;
end

CPD.trimmed_trans = zeros(prod(parent_sz), self_sz); % must declare before reading

CPD = class(CPD, 'tabular_CPD', discrete_CPD(~CPD.adjustable, fam_sz));
165

    
166

    
167
%%%%%%%%%%%
168

    
169
function CPD = init_fields()
% INIT_FIELDS Return a scalar struct holding every tabular_CPD field, all empty.
%
% The fields are always created in the order listed below, regardless of
% whether the object is being loaded from a file or built from scratch.
% (Matlab requires the field order to be the same in both cases.)

CPD = struct( ...
    'CPT',           [], ...
    'sizes',         [], ...
    'prior_type',    [], ...
    'dirichlet',     [], ...
    'adjustable',    [], ...
    'counts',        [], ...
    'nparams',       [], ...
    'nsamples',      [], ...
    'trim',          [], ...
    'trimmed_trans', [], ...
    'leftright',     []);
185

    
186