wolffd@0
|
function CPD = gmux_CPD(bnet, self, varargin)
% GMUX_CPD Make a Gaussian multiplexer node
%
% CPD = gmux_CPD(bnet, node, ...) is used similarly to gaussian_CPD,
% except we assume there is exactly one discrete parent (call it M)
% which is used to select which cts parent to pass through to the output.
% i.e., we define P(Y=y|M=m, X1, ..., XK) = N(y | W*x(m) + mu, Sigma)
% where Y represents this node, and the Xi's are the cts parents.
% All the Xi must have the same size, and the num values for M must be K.
%
% Currently the params for this kind of CPD cannot be learned
% (the underlying generic_CPD is created with clamped = 1).
%
% Optional arguments [ default in brackets ], given as name/value pairs:
%
% mean    - mu     [zeros(Y,1)]
% cov     - Sigma  [eye(Y,Y)]
% weights - W      [randn(Y,X)], where X is the size of a single cts parent
%
% An error is raised if M is not the only discrete parent, if the number of
% cts parents differs from the arity of M, if the cts parents differ in size,
% or if the optional arguments are not well-formed name/value pairs.

if nargin==0
  % This occurs if we are trying to load an object from a file.
  CPD = init_fields;
  clamp = 0;
  CPD = class(CPD, 'gmux_CPD', generic_CPD(clamp));
  return;
elseif isa(bnet, 'gmux_CPD')
  % This might occur if we are copying an object.
  CPD = bnet;
  return;
end
CPD = init_fields;

% clamped = 1: the parameters of this CPD are not learned.
CPD = class(CPD, 'gmux_CPD', generic_CPD(1));

ns = bnet.node_sizes;
ps = parents(bnet.dag, self);
dps = myintersect(ps, bnet.dnodes);
cps = myintersect(ps, bnet.cnodes);
fam_sz = ns([ps self]);

CPD.self = self;
CPD.sizes = fam_sz;

% Figure out which of the parents are discrete and which cts, stored as
% positions within the parent list ps.
% dps = discrete parents, cps = cts parents
CPD.cps = find_equiv_posns(cps, ps); % cts parent index
CPD.dps = find_equiv_posns(dps, ps);
if length(CPD.dps) ~= 1
  error('gmux must have exactly 1 discrete parent')
end

% Check the arity first, so that a node with no cts parents gets this
% message instead of an index-out-of-bounds error at CPD.cps(1) below.
dpsz = fam_sz(CPD.dps);
if dpsz ~= length(cps)
  error(['the arity of the mux node is ' num2str(dpsz) ...
	 ' but there are ' num2str(length(cps)) ' cts parents']);
end

ss = fam_sz(end);
cpsz = fam_sz(CPD.cps(1)); % in gaussian_CPD, cpsz = sum(fam_sz(CPD.cps))
if ~all(fam_sz(CPD.cps) == cpsz)
  error('all cts parents must have same size')
end

% set default params
% (randn is always called, even if 'weights' is supplied below.)
CPD.mean = zeros(ss, 1);
CPD.cov = eye(ss);
CPD.weights = randn(ss, cpsz);

args = varargin;
nargs = length(args);
% Without this check an odd count fails with an opaque index error at args{i+1}.
if mod(nargs, 2) ~= 0
  error('optional arguments must be given as name/value pairs');
end
for i=1:2:nargs
  if ~ischar(args{i})
    error('optional argument names must be strings');
  end
  switch args{i},
   case 'mean',    CPD.mean = args{i+1};
   case 'cov',     CPD.cov = args{i+1};
   case 'weights', CPD.weights = args{i+1};
   otherwise,
    error(['invalid argument name ' args{i}]);
  end
end
|
wolffd@0
|
77
|
wolffd@0
|
78 %%%%%%%%%%%
|
wolffd@0
|
79
|
wolffd@0
|
function CPD = init_fields()
% INIT_FIELDS Return the field template for a gmux_CPD object.
% Every field starts out empty ([]). The fields are always created in the
% same order -- self, sizes, cps, dps, mean, cov, weights -- whether the
% object is loaded from a file or built from scratch. Matlab requires the
% field layout of a class to be identical in both cases.

CPD = struct('self',    [], ...
             'sizes',   [], ...
             'cps',     [], ...
             'dps',     [], ...
             'mean',    [], ...
             'cov',     [], ...
             'weights', []);
|
wolffd@0
|
92
|