To check out this repository please hg clone the following URL, or open the URL using EasyMercurial or your preferred Mercurial client.

Statistics Download as Zip
| Branch: | Revision:

root / _FullBNT / BNT / CPDs / @gmux_CPD / Old / gmux_CPD.m @ 8:b5b38998ef3b

History | View | Annotate | Download (2.55 KB)

1
function CPD = gmux_CPD(bnet, self, varargin)
2
% GMUX_CPD Make a Gaussian multiplexer node
3
%
4
% CPD = gmux_CPD(bnet, node, ...) is used similarly to gaussian_CPD,
5
% except we assume there is exactly one discrete parent (call it M)
6
% which is used to select which cts parent to pass through to the output.
7
% i.e., we define P(Y=y|M=m, X1, ..., XK) = N(y | W*x(m) + mu, Sigma)
8
% where Y represents this node, and the Xi's are the cts parents.
9
% All the Xi must have the same size, and the num values for M must be K.
10
%
11
% Currently the params for this kind of CPD cannot be learned.
12
%
13
% Optional arguments [ default in brackets ]
14
%
15
% mean       - mu  [zeros(Y,1)]
16
% cov        - Sigma [eye(Y,Y)]
17
% weights    - W [ randn(Y,X) ]
18

    
19
if nargin==0
20
  % This occurs if we are trying to load an object from a file.
21
  CPD = init_fields;
22
  clamp = 0;
23
  CPD = class(CPD, 'gmux_CPD', generic_CPD(clamp));
24
  return;
25
elseif isa(bnet, 'gmux_CPD')
26
  % This might occur if we are copying an object.
27
  CPD = bnet;
28
  return;
29
end
30
CPD = init_fields;
31
 
32
CPD = class(CPD, 'gmux_CPD', generic_CPD(1));
33

    
34
ns = bnet.node_sizes;
35
ps = parents(bnet.dag, self);
36
dps = myintersect(ps, bnet.dnodes);
37
cps = myintersect(ps, bnet.cnodes);
38
fam_sz = ns([ps self]);
39

    
40
CPD.self = self;
41
CPD.sizes = fam_sz;
42

    
43
% Figure out which (if any) of the parents are discrete, and which cts, and how big they are
44
% dps = discrete parents, cps = cts parents
45
CPD.cps = find_equiv_posns(cps, ps); % cts parent index
46
CPD.dps = find_equiv_posns(dps, ps);
47
if length(CPD.dps) ~= 1
48
  error('gmux must have exactly 1 discrete parent')
49
end
50
ss = fam_sz(end);
51
cpsz = fam_sz(CPD.cps(1)); % in gaussian_CPD, cpsz = sum(fam_sz(CPD.cps))
52
if ~all(fam_sz(CPD.cps) == cpsz)
53
  error('all cts parents must have same size')
54
end
55
dpsz = fam_sz(CPD.dps);
56
if dpsz ~= length(cps)
57
  error(['the arity of the mux node is ' num2str(dpsz) ...
58
	 ' but there are ' num2str(length(cps)) ' cts parents']);
59
end
60

    
61
% set default params
62
CPD.mean = zeros(ss, 1);
63
CPD.cov = eye(ss);
64
CPD.weights = randn(ss, cpsz);
65

    
66
args = varargin;
67
nargs = length(args);
68
for i=1:2:nargs
69
  switch args{i},
70
   case 'mean',        CPD.mean = args{i+1}; 
71
   case 'cov',         CPD.cov = args{i+1}; 
72
   case 'weights',    CPD.weights = args{i+1}; 
73
   otherwise,  
74
    error(['invalid argument name ' args{i}]);
75
  end
76
end
77

    
78
%%%%%%%%%%%
79

    
80
function CPD = init_fields()
81
% This ensures we define the fields in the same order 
82
% no matter whether we load an object from a file,
83
% or create it from scratch. (Matlab requires this.)
84

    
85
CPD.self = [];
86
CPD.sizes = [];
87
CPD.cps = [];
88
CPD.dps = [];
89
CPD.mean = [];
90
CPD.cov = [];
91
CPD.weights = [];
92