To check out this repository please hg clone the following URL, or open the URL using EasyMercurial or your preferred Mercurial client.
root / _FullBNT / BNT / CPDs / @gmux_CPD / gmux_CPD.m @ 8:b5b38998ef3b
History | View | Annotate | Download (2.79 KB)
| 1 |
function CPD = gmux_CPD(bnet, self, varargin) |
|---|---|
| 2 |
% GMUX_CPD Make a Gaussian multiplexer node |
| 3 |
% |
| 4 |
% CPD = gmux_CPD(bnet, node, ...) is used similarly to gaussian_CPD, |
| 5 |
% except we assume there is exactly one discrete parent (call it M) |
| 6 |
% which is used to select which cts parent to pass through to the output. |
| 7 |
% i.e., we define P(Y=y|M=m, X1, ..., XK) = N(y | W(m)*x(m) + mu(m), Sigma(m)) |
| 8 |
% where Y represents this node, and the Xi's are the cts parents. |
| 9 |
% All the Xi must have the same size, and the num values for M must be K. |
| 10 |
% |
| 11 |
% Currently the params for this kind of CPD cannot be learned. |
| 12 |
% |
| 13 |
% Optional arguments [ default in brackets ] |
| 14 |
% |
| 15 |
% mean - mu(:,i) is the mean given M=i [ zeros(Y,K) ] |
| 16 |
% cov - Sigma(:,:,i) is the covariance given M=i [ repmat(1*eye(Y,Y), [1 1 K]) ] |
| 17 |
% weights - W(:,:,i) is the regression matrix given M=i [ randn(Y,X,K) ] |
| 18 |
|
| 19 |
if nargin==0 |
| 20 |
% This occurs if we are trying to load an object from a file. |
| 21 |
CPD = init_fields; |
| 22 |
clamp = 0; |
| 23 |
CPD = class(CPD, 'gmux_CPD', generic_CPD(clamp)); |
| 24 |
return; |
| 25 |
elseif isa(bnet, 'gmux_CPD') |
| 26 |
% This might occur if we are copying an object. |
| 27 |
CPD = bnet; |
| 28 |
return; |
| 29 |
end |
| 30 |
CPD = init_fields; |
| 31 |
|
| 32 |
CPD = class(CPD, 'gmux_CPD', generic_CPD(1)); |
| 33 |
|
| 34 |
ns = bnet.node_sizes; |
| 35 |
ps = parents(bnet.dag, self); |
| 36 |
dps = myintersect(ps, bnet.dnodes); |
| 37 |
cps = myintersect(ps, bnet.cnodes); |
| 38 |
fam_sz = ns([ps self]); |
| 39 |
|
| 40 |
CPD.self = self; |
| 41 |
CPD.sizes = fam_sz; |
| 42 |
|
| 43 |
% Figure out which (if any) of the parents are discrete, and which cts, and how big they are |
| 44 |
% dps = discrete parents, cps = cts parents |
| 45 |
CPD.cps = find_equiv_posns(cps, ps); % cts parent index |
| 46 |
CPD.dps = find_equiv_posns(dps, ps); |
| 47 |
if length(CPD.dps) ~= 1 |
| 48 |
error('gmux must have exactly 1 discrete parent')
|
| 49 |
end |
| 50 |
ss = fam_sz(end); |
| 51 |
cpsz = fam_sz(CPD.cps(1)); % in gaussian_CPD, cpsz = sum(fam_sz(CPD.cps)) |
| 52 |
if ~all(fam_sz(CPD.cps) == cpsz) |
| 53 |
error('all cts parents must have same size')
|
| 54 |
end |
| 55 |
dpsz = fam_sz(CPD.dps); |
| 56 |
if dpsz ~= length(cps) |
| 57 |
error(['the arity of the mux node is ' num2str(dpsz) ... |
| 58 |
' but there are ' num2str(length(cps)) ' cts parents']); |
| 59 |
end |
| 60 |
|
| 61 |
% set default params |
| 62 |
%CPD.mean = zeros(ss, 1); |
| 63 |
%CPD.cov = eye(ss); |
| 64 |
%CPD.weights = randn(ss, cpsz); |
| 65 |
CPD.mean = zeros(ss, dpsz); |
| 66 |
CPD.cov = 1*repmat(eye(ss), [1 1 dpsz]); |
| 67 |
CPD.weights = randn(ss, cpsz, dpsz); |
| 68 |
|
| 69 |
args = varargin; |
| 70 |
nargs = length(args); |
| 71 |
for i=1:2:nargs |
| 72 |
switch args{i},
|
| 73 |
case 'mean', CPD.mean = args{i+1};
|
| 74 |
case 'cov', CPD.cov = args{i+1};
|
| 75 |
case 'weights', CPD.weights = args{i+1};
|
| 76 |
otherwise, |
| 77 |
error(['invalid argument name ' args{i}]);
|
| 78 |
end |
| 79 |
end |
| 80 |
|
| 81 |
%%%%%%%%%%% |
| 82 |
|
| 83 |
function CPD = init_fields() |
| 84 |
% This ensures we define the fields in the same order |
| 85 |
% no matter whether we load an object from a file, |
| 86 |
% or create it from scratch. (Matlab requires this.) |
| 87 |
|
| 88 |
CPD.self = []; |
| 89 |
CPD.sizes = []; |
| 90 |
CPD.cps = []; |
| 91 |
CPD.dps = []; |
| 92 |
CPD.mean = []; |
| 93 |
CPD.cov = []; |
| 94 |
CPD.weights = []; |
| 95 |
|