Mercurial > hg > camir-aes2014
comparison toolboxes/FullBNT-1.0.7/bnt/CPDs/@gaussian_CPD/gaussian_CPD.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:e9a9cd732c1e |
---|---|
function CPD = gaussian_CPD(bnet, self, varargin)
% GAUSSIAN_CPD Make a conditional linear Gaussian distrib.
%
% CPD = gaussian_CPD(bnet, node, ...) will create a CPD with random parameters,
% where node is the number of a node in this equivalence class.
%
% To define this CPD precisely, call the continuous (cts) parents (if any) X,
% the discrete parents (if any) Q, and this node Y. Then the distribution on Y is:
% - no parents: Y ~ N(mu, Sigma)
% - cts parents : Y|X=x ~ N(mu + W x, Sigma)
% - discrete parents: Y|Q=i ~ N(mu(i), Sigma(i))
% - cts and discrete parents: Y|X=x,Q=i ~ N(mu(i) + W(i) x, Sigma(i))
%
% The list below gives optional arguments [default value in brackets].
% (Let ns(i) be the size of node i, X = sum(ns(X)), Y = ns(Y) and Q = prod(ns(Q)).)
% Parameters will be reshaped to the right size if necessary.
%
% mean - mu(:,i) is the mean given Q=i [ randn(Y,Q) ]
% cov - Sigma(:,:,i) is the covariance given Q=i [ repmat(100*eye(Y,Y), [1 1 Q]) ]
% weights - W(:,:,i) is the regression matrix given Q=i [ randn(Y,X,Q) ]
% cov_type - if 'diag', Sigma(:,:,i) is diagonal [ 'full' ]
% tied_cov - if 1, we constrain Sigma(:,:,i) to be the same for all i [0]
% clamp_mean - if 1, we do not adjust mu(:,i) during learning [0]
% clamp_cov - if 1, we do not adjust Sigma(:,:,i) during learning [0]
% clamp_weights - if 1, we do not adjust W(:,:,i) during learning [0]
% cov_prior_weight - weight given to I prior for estimating Sigma [0.01]
% cov_prior_entropic - if 1, we also use an entropic prior for Sigma [0]
%
% Any cov_type other than 'full' or 'diag' raises an error.
% The CPD as a whole is marked clamped only when the mean, covariance and
% weights are all clamped.
%
% Called with no arguments, returns an empty gaussian_CPD object (used when
% loading from a file). Called with a gaussian_CPD as the first argument,
% returns that object unchanged.
%
% e.g., CPD = gaussian_CPD(bnet, i, 'mean', [0; 0], 'clamp_mean', 1)

if nargin==0
  % This occurs if we are trying to load an object from a file.
  CPD = init_fields;
  clamp = 0;
  CPD = class(CPD, 'gaussian_CPD', generic_CPD(clamp));
  return;
elseif isa(bnet, 'gaussian_CPD')
  % This might occur if we are copying an object.
  CPD = bnet;
  return;
end
CPD = init_fields;

CPD = class(CPD, 'gaussian_CPD', generic_CPD(0));

args = varargin;
ns = bnet.node_sizes;
ps = parents(bnet.dag, self);
dps = myintersect(ps, bnet.dnodes);
cps = myintersect(ps, bnet.cnodes);
fam_sz = ns([ps self]);

CPD.self = self;
CPD.sizes = fam_sz;

% Figure out which (if any) of the parents are discrete, and which cts, and how big they are
% dps = discrete parents, cps = cts parents
CPD.cps = find_equiv_posns(cps, ps); % cts parent index
CPD.dps = find_equiv_posns(dps, ps);
ss = fam_sz(end);           % size of this node (Y)
psz = fam_sz(1:end-1);      % sizes of the parents, in the order of ps
dpsz = prod(psz(CPD.dps));  % number of joint discrete-parent values (Q); 1 if none
cpsz = sum(psz(CPD.cps));   % total length of the stacked cts parents (X); 0 if none

% set default params
CPD.mean = randn(ss, dpsz);
CPD.cov = 100*repmat(eye(ss), [1 1 dpsz]);
CPD.weights = randn(ss, cpsz, dpsz);
CPD.cov_type = 'full';
CPD.tied_cov = 0;
CPD.clamped_mean = 0;
CPD.clamped_cov = 0;
CPD.clamped_weights = 0;
CPD.cov_prior_weight = 0.01;
CPD.cov_prior_entropic = 0;
nargs = length(args);
if nargs > 0
  CPD = set_fields(CPD, args{:});
end

% Make sure the matrices have 1 dimension per discrete parent.
% Bug fix due to Xuejing Sun 3/6/01
CPD.mean = myreshape(CPD.mean, [ss ns(dps)]);
CPD.cov = myreshape(CPD.cov, [ss ss ns(dps)]);
CPD.weights = myreshape(CPD.weights, [ss cpsz ns(dps)]);

% Precompute indices into block structured matrices
% to speed up CPD_to_lambda_msg and CPD_to_pi
cpsizes = CPD.sizes(CPD.cps);
CPD.cps_block_ndx = cell(1, length(cps));
for i=1:length(cps)
  CPD.cps_block_ndx{i} = block(i, cpsizes);
end

%%%%%%%%%%%
% Learning stuff

% expected sufficient statistics
% (one slice per discrete-parent configuration)
CPD.Wsum = zeros(dpsz,1);
CPD.WYsum = zeros(ss, dpsz);
CPD.WXsum = zeros(cpsz, dpsz);
CPD.WYYsum = zeros(ss, ss, dpsz);
CPD.WXXsum = zeros(cpsz, cpsz, dpsz);
CPD.WXYsum = zeros(cpsz, ss, dpsz);

% For BIC
CPD.nsamples = 0;
switch CPD.cov_type
 case 'full',
  % since symmetric
  %ncov_params = ss*(ss-1)/2;
  ncov_params = ss*(ss+1)/2;
 case 'diag',
  ncov_params = ss;
 otherwise
  % The option value is stored on the object; there is no local cov_type.
  error(['unrecognized cov_type ' CPD.cov_type]);
end
% params = weights + mean + cov
% (a tied covariance is counted once rather than once per discrete config)
if CPD.tied_cov
  CPD.nparams = ss*cpsz*dpsz + ss*dpsz + ncov_params;
else
  CPD.nparams = ss*cpsz*dpsz + ss*dpsz + dpsz*ncov_params;
end

% for speeding up maximize_params
% (nonzero iff something named rep_mult is on the path, per exist())
CPD.useC = exist('rep_mult');

% Only clamp the whole CPD when every parameter group is clamped.
clamped = CPD.clamped_mean & CPD.clamped_cov & CPD.clamped_weights;
CPD = set_clamped(CPD, clamped);
131 %%%%%%%%%%% | |
132 | |
function CPD = init_fields()
% INIT_FIELDS Return the gaussian_CPD field template, every field set to [].
%
% Loading an object from a file and building it from scratch must yield
% the fields in exactly the same order, because MATLAB requires this
% before class() will accept the struct.
% struct() preserves the order in which the field names are listed below,
% so this list is the single source of truth for that order.

CPD = struct( ...
  'self', [], ...
  'sizes', [], ...
  'cps', [], ...
  'dps', [], ...
  'mean', [], ...
  'cov', [], ...
  'weights', [], ...
  'clamped_mean', [], ...
  'clamped_cov', [], ...
  'clamped_weights', [], ...
  'cov_type', [], ...
  'tied_cov', [], ...
  'Wsum', [], ...
  'WYsum', [], ...
  'WXsum', [], ...
  'WYYsum', [], ...
  'WXXsum', [], ...
  'WXYsum', [], ...
  'nsamples', [], ...
  'nparams', [], ...
  'cov_prior_weight', [], ...
  'cov_prior_entropic', [], ...
  'useC', [], ...
  'cps_block_ndx', []);