Mercurial > hg > camir-aes2014
comparison toolboxes/FullBNT-1.0.7/bnt/CPDs/@tabular_CPD/tabular_CPD.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:e9a9cd732c1e |
---|---|
function CPD = tabular_CPD(bnet, self, varargin)
% TABULAR_CPD Make a multinomial conditional prob. distrib. (CPT)
%
% CPD = tabular_CPD(bnet, node) creates a random CPT.
%
% The following arguments can be specified [default in brackets]
%
% CPT - specifies the params ['rnd']
%  - T means use table T; it will be reshaped to the size of node's family.
%  - 'rnd' creates rnd params (drawn from uniform)
%  - 'unif' creates a uniform distribution
% adjustable - 0 means don't adjust the parameters during learning [1]
% clamped - 1 means don't adjust the parameters during learning
%    (the logical negation of 'adjustable') [0]
% prior_type - defines type of prior ['none']
%  - 'none' means do ML estimation
%  - 'dirichlet' means add pseudo-counts to every cell
%  - 'entropic' means use a prior P(theta) propto exp(-H(theta)) (see Brand)
% dirichlet_weight - equivalent sample size (ess) of the dirichlet prior [1]
% dirichlet_type - defines the type of Dirichlet prior ['BDeu']
%  - 'unif' means put dirichlet_weight in every cell
%  - 'BDeu' means we put 'dirichlet_weight/(r q)' in every cell
%    where r = self_sz and q = prod(parent_sz) (see Heckerman)
% trim - 1 means trim redundant params (rows in CPT) when using entropic prior [0]
% entropic_pcases - list of assignments to the parent nodes when we should use
%    the entropic prior; all other cases will be estimated using ML [1:psz]
% sparse - 1 means use 1D sparse array to represent CPT [0]
%
% The old positional syntax tabular_CPD(bnet, i, T) is also accepted; in that
% case T is used as the CPT and any further arguments are ignored.
%
% e.g., tabular_CPD(bnet, i, 'CPT', T)
% e.g., tabular_CPD(bnet, i, 'CPT', 'unif', 'dirichlet_weight', 2, 'dirichlet_type', 'unif')
%
% REFERENCES
% M. Brand - "Structure learning in conditional probability models via an entropic prior
%    and parameter extinction", Neural Computation 11 (1999): 1155--1182
% M. Brand - "Pattern discovery via entropy minimization" [covers annealing]
%    AI & Statistics 1999. Equation numbers refer to this paper, which is available from
%    www.merl.com/reports/docs/TR98-21.pdf
% D. Heckerman, D. Geiger and M. Chickering,
%    "Learning Bayesian networks: the combination of knowledge and statistical data",
%    Microsoft Research Tech Report, 1994


if nargin==0
  % This occurs if we are trying to load an object from a file.
  CPD = init_fields;
  CPD = class(CPD, 'tabular_CPD', discrete_CPD(0, []));
  return;
elseif isa(bnet, 'tabular_CPD')
  % This might occur if we are copying an object.
  CPD = bnet;
  return;
end
CPD = init_fields;

ns = bnet.node_sizes;
ps = parents(bnet.dag, self);
fam_sz = ns([ps self]);
psz = prod(ns(ps)); % number of joint parent configurations (1 when there are no parents)
CPD.sizes = fam_sz;
CPD.leftright = 0;
CPD.sparse = 0;

% set defaults
CPD.CPT = mk_stochastic(myrand(fam_sz));
CPD.adjustable = 1;
CPD.prior_type = 'none';
dirichlet_type = 'BDeu';
dirichlet_weight = 1;
CPD.trim = 0;
CPD.entropic_pcases = 1:psz;

% extract optional args
args = varargin;
% check for old syntax CPD(bnet, i, CPT) as opposed to CPD(bnet, i, 'CPT', CPT)
if ~isempty(args) && ~ischar(args{1})
  CPD.CPT = myreshape(args{1}, fam_sz);
  args = {};
end

% Every option needs a value; an unpaired trailing name would otherwise fail
% with an opaque index-out-of-bounds error inside the loop below.
if mod(length(args), 2) ~= 0
  error('optional arguments must be given as name/value pairs');
end

for i=1:2:length(args)
  switch args{i},
   case 'CPT',
    T = args{i+1};
    if ischar(T)
      switch T
       case 'unif', CPD.CPT = mk_stochastic(myones(fam_sz));
       case 'rnd',  CPD.CPT = mk_stochastic(myrand(fam_sz));
       otherwise,   error(['invalid CPT ' T]);
      end
    else
      CPD.CPT = myreshape(T, fam_sz);
    end
   case 'prior_type', CPD.prior_type = args{i+1};
   case 'dirichlet_type', dirichlet_type = args{i+1};
   case 'dirichlet_weight', dirichlet_weight = args{i+1};
   case 'adjustable', CPD.adjustable = args{i+1};
   case 'clamped', CPD.adjustable = ~args{i+1};
   case 'trim', CPD.trim = args{i+1};
   case 'entropic_pcases', CPD.entropic_pcases = args{i+1};
   case 'sparse', CPD.sparse = args{i+1};
   otherwise, error(['invalid argument name: ' args{i}]);
  end
end

% Build the Dirichlet pseudo-count table (if any) from the chosen prior.
switch CPD.prior_type
 case 'dirichlet',
  switch dirichlet_type
   case 'unif', CPD.dirichlet = dirichlet_weight * myones(fam_sz);
   case 'BDeu', CPD.dirichlet = (dirichlet_weight/psz) * mk_stochastic(myones(fam_sz));
   otherwise, error(['invalid dirichlet_type ' dirichlet_type])
  end
 case {'entropic', 'none'}
  CPD.dirichlet = [];
 otherwise, error(['invalid prior_type ' CPD.prior_type])
end

% fields to do with learning
if ~CPD.adjustable
  CPD.counts = [];
  CPD.nparams = 0;
  CPD.nsamples = [];
else
  % one count per CPT cell, stored as a column vector
  CPD.counts = zeros(numel(CPD.CPT), 1);
  parent_sz = fam_sz(1:end-1); % vector of parent sizes (empty if no parents)
  self_sz = fam_sz(end);
  if CPD.leftright
    % For each of the Qps parent contexts, we specify Q elements on the diagonal.
    % NOTE(review): leftright is always 0 in this constructor, so this branch is
    % unreachable here. Qps/Q were previously undefined; they are derived from
    % the family sizes following the comment above - confirm if this is re-enabled.
    Qps = prod(parent_sz);
    Q = self_sz;
    CPD.nparams = Qps * Q;
  else
    % sum-to-1 constraint reduces the effective arity of the node by 1
    CPD.nparams = prod([parent_sz self_sz-1]);
  end
  CPD.nsamples = 0;
end

CPD.trimmed_trans = [];

% sparse CPT: store the table as a 1D sparse column vector
if CPD.sparse
  CPD.CPT = sparse(CPD.CPT(:));
end

CPD = class(CPD, 'tabular_CPD', discrete_CPD(~CPD.adjustable, fam_sz));
151 | |
152 | |
153 %%%%%%%%%%% | |
154 | |
function CPD = init_fields()
% INIT_FIELDS Return a 1x1 tabular_CPD field struct whose fields are all empty.
% The field order is fixed in this one place so that an object loaded from
% a file and an object built from scratch share the same layout
% (Matlab requires identical field order for class objects).

CPD = struct( ...
  'CPT',             [], ...
  'sizes',           [], ...
  'prior_type',      [], ...
  'dirichlet',       [], ...
  'adjustable',      [], ...
  'counts',          [], ...
  'nparams',         [], ...
  'nsamples',        [], ...
  'trim',            [], ...
  'trimmed_trans',   [], ...
  'leftright',       [], ...
  'entropic_pcases', [], ...
  'sparse',          []);
173 |