Mercurial > hg > camir-aes2014
comparison toolboxes/FullBNT-1.0.7/bnt/CPDs/@gmux_CPD/gmux_CPD.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:e9a9cd732c1e |
---|---|
function CPD = gmux_CPD(bnet, self, varargin)
% GMUX_CPD Make a Gaussian multiplexer node
%
% CPD = gmux_CPD(bnet, node, ...) is used similarly to gaussian_CPD,
% except we assume there is exactly one discrete parent (call it M)
% which is used to select which cts parent to pass through to the output.
% i.e., we define P(Y=y|M=m, X1, ..., XK) = N(y | W(m)*x(m) + mu(m), Sigma(m))
% where Y represents this node, and the Xi's are the cts parents.
% All the Xi must have the same size, and the num values for M must be K.
%
% Currently the params for this kind of CPD cannot be learned.
%
% Optional arguments [ default in brackets ]
%
% mean    - mu(:,i) is the mean given M=i [ zeros(Y,K) ]
% cov     - Sigma(:,:,i) is the covariance given M=i [ repmat(1*eye(Y,Y), [1 1 K]) ]
% weights - W(:,:,i) is the regression matrix given M=i [ randn(Y,X,K) ]
%
% Special calling forms:
%   gmux_CPD()      returns an object with empty fields (used when loading
%                   an object from a file).
%   gmux_CPD(obj)   where obj is already a gmux_CPD returns obj unchanged.
%
% An error is raised if the node does not have exactly one discrete parent,
% has no cts parents, has cts parents of differing sizes, if the arity of
% the discrete parent differs from the number of cts parents, or if the
% optional arguments are not well-formed name/value pairs.

if nargin==0
  % This occurs if we are trying to load an object from a file.
  CPD = init_fields;
  clamp = 0;
  CPD = class(CPD, 'gmux_CPD', generic_CPD(clamp));
  return;
elseif isa(bnet, 'gmux_CPD')
  % This might occur if we are copying an object.
  CPD = bnet;
  return;
end
CPD = init_fields;

% The parent object is built with a clamp flag of 1 (the header above notes
% that the params of this CPD cannot be learned).
CPD = class(CPD, 'gmux_CPD', generic_CPD(1));

ns = bnet.node_sizes;
ps = parents(bnet.dag, self);
dps = myintersect(ps, bnet.dnodes);
cps = myintersect(ps, bnet.cnodes);
fam_sz = ns([ps self]);

CPD.self = self;
CPD.sizes = fam_sz;

% Figure out which (if any) of the parents are discrete, and which cts, and how big they are
% dps = discrete parents, cps = cts parents
CPD.cps = find_equiv_posns(cps, ps); % cts parent index
CPD.dps = find_equiv_posns(dps, ps);
if length(CPD.dps) ~= 1
  error('gmux must have exactly 1 discrete parent')
end
% Guard before indexing CPD.cps(1) below: without any cts parent the
% indexing would fail with an uninformative out-of-range error.
if isempty(CPD.cps)
  error('gmux must have at least 1 cts parent')
end
ss = fam_sz(end); % size of this node (last entry of the family sizes)
cpsz = fam_sz(CPD.cps(1)); % in gaussian_CPD, cpsz = sum(fam_sz(CPD.cps))
if ~all(fam_sz(CPD.cps) == cpsz)
  error('all cts parents must have same size')
end
dpsz = fam_sz(CPD.dps);
if dpsz ~= length(cps)
  error(['the arity of the mux node is ' num2str(dpsz) ...
         ' but there are ' num2str(length(cps)) ' cts parents']);
end

% set default params
%CPD.mean = zeros(ss, 1);
%CPD.cov = eye(ss);
%CPD.weights = randn(ss, cpsz);
CPD.mean = zeros(ss, dpsz);
CPD.cov = 1*repmat(eye(ss), [1 1 dpsz]);
CPD.weights = randn(ss, cpsz, dpsz);

% Override the defaults with any name/value pairs supplied by the caller.
args = varargin;
nargs = length(args);
if mod(nargs, 2) ~= 0
  % A trailing name without a value would otherwise fail on args{i+1}.
  error('optional arguments must be given as name/value pairs')
end
for i=1:2:nargs
  if ~ischar(args{i})
    % A non-string name cannot be matched or concatenated into the message below.
    error('optional argument names must be strings')
  end
  switch args{i},
   case 'mean',    CPD.mean = args{i+1};
   case 'cov',     CPD.cov = args{i+1};
   case 'weights', CPD.weights = args{i+1};
   otherwise,
    error(['invalid argument name ' args{i}]);
  end
end
80 | |
81 %%%%%%%%%%% | |
82 | |
function CPD = init_fields()
% INIT_FIELDS Create the bare field structure of a gmux_CPD object.
%
% All fields are created empty and always in the same order, whether the
% object is being loaded from a file or built from scratch. (Matlab
% requires a fixed field order for objects of a given class.)

CPD = struct('self',    [], ...
             'sizes',   [], ...
             'cps',     [], ...
             'dps',     [], ...
             'mean',    [], ...
             'cov',     [], ...
             'weights', []);
95 |