To check out this repository please hg clone the following URL, or open the URL using EasyMercurial or your preferred Mercurial client.
root / _FullBNT / BNT / CPDs / @gaussian_CPD / Old / gaussian_CPD.m @ 8:b5b38998ef3b
History | View | Annotate | Download (5.9 KB)
| 1 |
function CPD = gaussian_CPD(varargin) |
|---|---|
| 2 |
% GAUSSIAN_CPD Make a conditional linear Gaussian distrib. |
| 3 |
% |
| 4 |
% To define this CPD precisely, call the continuous (cts) parents (if any) X, |
| 5 |
% the discrete parents (if any) Q, and this node Y. Then the distribution on Y is: |
| 6 |
% - no parents: Y ~ N(mu, Sigma) |
| 7 |
% - cts parents : Y|X=x ~ N(mu + W x, Sigma) |
| 8 |
% - discrete parents: Y|Q=i ~ N(mu(i), Sigma(i)) |
| 9 |
% - cts and discrete parents: Y|X=x,Q=i ~ N(mu(i) + W(i) x, Sigma(i)) |
| 10 |
% |
| 11 |
% CPD = gaussian_CPD(bnet, node, ...) will create a CPD with random parameters, |
| 12 |
% where node is the number of a node in this equivalence class. |
| 13 |
% |
| 14 |
% The list below gives optional arguments [default value in brackets]. |
| 15 |
% (Let ns(i) be the size of node i, X = ns(X), Y = ns(Y) and Q = prod(ns(Q)).) |
| 16 |
% |
| 17 |
% mean - mu(:,i) is the mean given Q=i [ randn(Y,Q) ] |
| 18 |
% cov - Sigma(:,:,i) is the covariance given Q=i [ repmat(eye(Y,Y), [1 1 Q]) ] |
| 19 |
% weights - W(:,:,i) is the regression matrix given Q=i [ randn(Y,X,Q) ] |
| 20 |
% cov_type - if 'diag', Sigma(:,:,i) is diagonal [ 'full' ] |
| 21 |
% tied_cov - if 1, we constrain Sigma(:,:,i) to be the same for all i [0] |
| 22 |
% clamp_mean - if 1, we do not adjust mu(:,i) during learning [0] |
| 23 |
% clamp_cov - if 1, we do not adjust Sigma(:,:,i) during learning [0] |
| 24 |
% clamp_weights - if 1, we do not adjust W(:,:,i) during learning [0] |
| 25 |
% cov_prior_weight - weight given to I prior for estimating Sigma [0.01] |
| 26 |
% |
| 27 |
% e.g., CPD = gaussian_CPD(bnet, i, 'mean', [0; 0], 'clamp_mean', 'yes') |
| 28 |
% |
| 29 |
% For backwards compatibility with BNT2, you can also specify the parameters in the following order |
| 30 |
% CPD = gaussian_CPD(bnet, self, mu, Sigma, W, cov_type, tied_cov, clamp_mean, clamp_cov, clamp_weight) |
| 31 |
% |
| 32 |
% Sometimes it is useful to create an "isolated" CPD, without needing to pass in a bnet. |
| 33 |
% In this case, you must specify the discrete and cts parents (dps, cps) and the family sizes, followed |
| 34 |
% by the optional arguments above: |
| 35 |
% CPD = gaussian_CPD('self', i, 'dps', dps, 'cps', cps, 'sz', fam_size, ...)
|
| 36 |
|
| 37 |
|
| 38 |
if nargin==0 |
| 39 |
% This occurs if we are trying to load an object from a file. |
| 40 |
CPD = init_fields; |
| 41 |
clamp = 0; |
| 42 |
CPD = class(CPD, 'gaussian_CPD', generic_CPD(clamp)); |
| 43 |
return; |
| 44 |
elseif isa(varargin{1}, 'gaussian_CPD')
|
| 45 |
% This might occur if we are copying an object. |
| 46 |
CPD = varargin{1};
|
| 47 |
return; |
| 48 |
end |
| 49 |
CPD = init_fields; |
| 50 |
|
| 51 |
CPD = class(CPD, 'gaussian_CPD', generic_CPD(0)); |
| 52 |
|
| 53 |
|
| 54 |
% parse mandatory arguments |
| 55 |
if ~isstr(varargin{1}) % pass in bnet
|
| 56 |
bnet = varargin{1};
|
| 57 |
self = varargin{2};
|
| 58 |
args = varargin(3:end); |
| 59 |
ns = bnet.node_sizes; |
| 60 |
ps = parents(bnet.dag, self); |
| 61 |
dps = myintersect(ps, bnet.dnodes); |
| 62 |
cps = myintersect(ps, bnet.cnodes); |
| 63 |
fam_sz = ns([ps self]); |
| 64 |
else |
| 65 |
disp('parsing new style')
|
| 66 |
for i=1:2:length(varargin) |
| 67 |
switch varargin{i},
|
| 68 |
case 'self', self = varargin{i+1};
|
| 69 |
case 'dps', dps = varargin{i+1};
|
| 70 |
case 'cps', cps = varargin{i+1};
|
| 71 |
case 'sz', fam_sz = varargin{i+1};
|
| 72 |
end |
| 73 |
end |
| 74 |
ps = myunion(dps, cps); |
| 75 |
args = varargin; |
| 76 |
end |
| 77 |
|
| 78 |
CPD.self = self; |
| 79 |
CPD.sizes = fam_sz; |
| 80 |
|
| 81 |
% Figure out which (if any) of the parents are discrete, and which cts, and how big they are |
| 82 |
% dps = discrete parents, cps = cts parents |
| 83 |
CPD.cps = find_equiv_posns(cps, ps); % cts parent index |
| 84 |
CPD.dps = find_equiv_posns(dps, ps); |
| 85 |
ss = fam_sz(end); |
| 86 |
psz = fam_sz(1:end-1); |
| 87 |
dpsz = prod(psz(CPD.dps)); |
| 88 |
cpsz = sum(psz(CPD.cps)); |
| 89 |
|
| 90 |
% set default params |
| 91 |
CPD.mean = randn(ss, dpsz); |
| 92 |
CPD.cov = 100*repmat(eye(ss), [1 1 dpsz]); |
| 93 |
CPD.weights = randn(ss, cpsz, dpsz); |
| 94 |
CPD.cov_type = 'full'; |
| 95 |
CPD.tied_cov = 0; |
| 96 |
CPD.clamped_mean = 0; |
| 97 |
CPD.clamped_cov = 0; |
| 98 |
CPD.clamped_weights = 0; |
| 99 |
CPD.cov_prior_weight = 0.01; |
| 100 |
|
| 101 |
nargs = length(args); |
| 102 |
if nargs > 0 |
| 103 |
if ~isstr(args{1})
|
| 104 |
% gaussian_CPD(bnet, self, mu, Sigma, W, cov_type, tied_cov, clamp_mean, clamp_cov, clamp_weights) |
| 105 |
if nargs >= 1 & ~isempty(args{1}), CPD.mean = args{1}; end
|
| 106 |
if nargs >= 2 & ~isempty(args{2}), CPD.cov = args{2}; end
|
| 107 |
if nargs >= 3 & ~isempty(args{3}), CPD.weights = args{3}; end
|
| 108 |
if nargs >= 4 & ~isempty(args{4}), CPD.cov_type = args{4}; end
|
| 109 |
if nargs >= 5 & ~isempty(args{5}) & strcmp(args{5}, 'tied'), CPD.tied_cov = 1; end
|
| 110 |
if nargs >= 6 & ~isempty(args{6}), CPD.clamped_mean = 1; end
|
| 111 |
if nargs >= 7 & ~isempty(args{7}), CPD.clamped_cov = 1; end
|
| 112 |
if nargs >= 8 & ~isempty(args{8}), CPD.clamped_weights = 1; end
|
| 113 |
else |
| 114 |
CPD = set_fields(CPD, args{:});
|
| 115 |
end |
| 116 |
end |
| 117 |
|
| 118 |
% Make sure the matrices have 1 dimension per discrete parent. |
| 119 |
% Bug fix due to Xuejing Sun 3/6/01 |
| 120 |
CPD.mean = myreshape(CPD.mean, [ss ns(dps)]); |
| 121 |
CPD.cov = myreshape(CPD.cov, [ss ss ns(dps)]); |
| 122 |
CPD.weights = myreshape(CPD.weights, [ss cpsz ns(dps)]); |
| 123 |
|
| 124 |
CPD.init_cov = CPD.cov; % we reset to this if things go wrong during learning |
| 125 |
|
| 126 |
% expected sufficient statistics |
| 127 |
CPD.Wsum = zeros(dpsz,1); |
| 128 |
CPD.WYsum = zeros(ss, dpsz); |
| 129 |
CPD.WXsum = zeros(cpsz, dpsz); |
| 130 |
CPD.WYYsum = zeros(ss, ss, dpsz); |
| 131 |
CPD.WXXsum = zeros(cpsz, cpsz, dpsz); |
| 132 |
CPD.WXYsum = zeros(cpsz, ss, dpsz); |
| 133 |
|
| 134 |
% For BIC |
| 135 |
CPD.nsamples = 0; |
| 136 |
switch CPD.cov_type |
| 137 |
case 'full', |
| 138 |
ncov_params = ss*(ss-1)/2; % since symmetric (and positive definite) |
| 139 |
case 'diag', |
| 140 |
ncov_params = ss; |
| 141 |
otherwise |
| 142 |
error(['unrecognized cov_type ' cov_type]); |
| 143 |
end |
| 144 |
% params = weights + mean + cov |
| 145 |
if CPD.tied_cov |
| 146 |
CPD.nparams = ss*cpsz*dpsz + ss*dpsz + ncov_params; |
| 147 |
else |
| 148 |
CPD.nparams = ss*cpsz*dpsz + ss*dpsz + dpsz*ncov_params; |
| 149 |
end |
| 150 |
|
| 151 |
|
| 152 |
|
| 153 |
clamped = CPD.clamped_mean & CPD.clamped_cov & CPD.clamped_weights; |
| 154 |
CPD = set_clamped(CPD, clamped); |
| 155 |
|
| 156 |
%%%%%%%%%%% |
| 157 |
|
| 158 |
function CPD = init_fields() |
| 159 |
% This ensures we define the fields in the same order |
| 160 |
% no matter whether we load an object from a file, |
| 161 |
% or create it from scratch. (Matlab requires this.) |
| 162 |
|
| 163 |
CPD.self = []; |
| 164 |
CPD.sizes = []; |
| 165 |
CPD.cps = []; |
| 166 |
CPD.dps = []; |
| 167 |
CPD.mean = []; |
| 168 |
CPD.cov = []; |
| 169 |
CPD.weights = []; |
| 170 |
CPD.clamped_mean = []; |
| 171 |
CPD.clamped_cov = []; |
| 172 |
CPD.clamped_weights = []; |
| 173 |
CPD.init_cov = []; |
| 174 |
CPD.cov_type = []; |
| 175 |
CPD.tied_cov = []; |
| 176 |
CPD.Wsum = []; |
| 177 |
CPD.WYsum = []; |
| 178 |
CPD.WXsum = []; |
| 179 |
CPD.WYYsum = []; |
| 180 |
CPD.WXXsum = []; |
| 181 |
CPD.WXYsum = []; |
| 182 |
CPD.nsamples = []; |
| 183 |
CPD.nparams = []; |
| 184 |
CPD.cov_prior_weight = []; |