annotate toolboxes/FullBNT-1.0.7/bnt/CPDs/@gmux_CPD/gmux_CPD.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
rev   line source
function CPD = gmux_CPD(bnet, self, varargin)
% GMUX_CPD Make a Gaussian multiplexer node
%
% CPD = gmux_CPD(bnet, node, ...) is used similarly to gaussian_CPD,
% except we assume there is exactly one discrete parent (call it M)
% which is used to select which cts parent to pass through to the output.
% i.e., we define P(Y=y|M=m, X1, ..., XK) = N(y | W(m)*x(m) + mu(m), Sigma(m))
% where Y represents this node, and the Xi's are the cts parents.
% All the Xi must have the same size, and the num values for M must be K.
%
% Currently the params for this kind of CPD cannot be learned
% (the object is built on generic_CPD with its clamp flag set to 1).
%
% Optional arguments, given as name/value pairs [ default in brackets ]
%
% mean    - mu(:,i) is the mean given M=i [ zeros(Y,K) ]
% cov     - Sigma(:,:,i) is the covariance given M=i [ repmat(1*eye(Y,Y), [1 1 K]) ]
% weights - W(:,:,i) is the regression matrix given M=i [ randn(Y,X,K) ]
%
% Here Y is the size of this node, X is the (common) size of each cts parent,
% and K is the number of values of M, which must equal the number of cts parents.
%
% gmux_CPD() with no arguments returns an empty object (used when loading
% from a file); gmux_CPD(obj) with obj already a gmux_CPD returns obj unchanged.

if nargin==0
  % This occurs if we are trying to load an object from a file.
  CPD = init_fields;
  clamp = 0;
  CPD = class(CPD, 'gmux_CPD', generic_CPD(clamp));
  return;
elseif isa(bnet, 'gmux_CPD')
  % This might occur if we are copying an object.
  CPD = bnet;
  return;
end
CPD = init_fields;

% The argument to generic_CPD is the clamp flag (cf. the nargin==0 branch);
% it is 1 here, consistent with the params of this CPD not being learnable.
CPD = class(CPD, 'gmux_CPD', generic_CPD(1));

ns = bnet.node_sizes;
ps = parents(bnet.dag, self);
% dps = discrete parents, cps = cts parents (node ids)
dps = myintersect(ps, bnet.dnodes);
cps = myintersect(ps, bnet.cnodes);
fam_sz = ns([ps self]);   % family sizes: parents first, this node last

CPD.self = self;
CPD.sizes = fam_sz;

% Positions of the discrete and cts parents within the parent list ps
CPD.cps = find_equiv_posns(cps, ps); % cts parent index
CPD.dps = find_equiv_posns(dps, ps);
if length(CPD.dps) ~= 1
  error('gmux must have exactly 1 discrete parent')
end
ss = fam_sz(end);          % size of this node (Y)
dpsz = fam_sz(CPD.dps);    % arity of the mux parent M (K)

% Check the arity BEFORE indexing CPD.cps(1): a node with no cts parents
% then gets this informative error instead of an index-out-of-bounds one.
if dpsz ~= length(cps)
  error(['the arity of the mux node is ' num2str(dpsz) ...
	 ' but there are ' num2str(length(cps)) ' cts parents']);
end
cpsz = fam_sz(CPD.cps(1)); % in gaussian_CPD, cpsz = sum(fam_sz(CPD.cps))
if ~all(fam_sz(CPD.cps) == cpsz)
  error('all cts parents must have same size')
end

% set default params: one mean / covariance / weight matrix per value of M
CPD.mean = zeros(ss, dpsz);
CPD.cov = 1*repmat(eye(ss), [1 1 dpsz]);
CPD.weights = randn(ss, cpsz, dpsz);

args = varargin;
nargs = length(args);
% Every option name must be followed by its value; an unpaired trailing
% name would otherwise fail on args{i+1} with an index error.
if mod(nargs, 2) ~= 0
  error('optional arguments must be given as name/value pairs');
end
for i=1:2:nargs
  switch args{i},
   case 'mean', CPD.mean = args{i+1};
   case 'cov', CPD.cov = args{i+1};
   case 'weights', CPD.weights = args{i+1};
   otherwise,
    error(['invalid argument name ' args{i}]);
  end
end
wolffd@0 80
wolffd@0 81 %%%%%%%%%%%
wolffd@0 82
function CPD = init_fields()
% INIT_FIELDS Return the empty field template of a gmux_CPD object.
% Every field is created in one fixed order, whether the object is being
% loaded from a file or built from scratch, because Matlab requires all
% instances of a class to share an identical field layout.

CPD = struct( ...
  'self',    [], ...
  'sizes',   [], ...
  'cps',     [], ...
  'dps',     [], ...
  'mean',    [], ...
  'cov',     [], ...
  'weights', []);
wolffd@0 95