annotate toolboxes/FullBNT-1.0.7/bnt/CPDs/@gaussian_CPD/gaussian_CPD.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
rev   line source
function CPD = gaussian_CPD(bnet, self, varargin)
% GAUSSIAN_CPD Make a conditional linear Gaussian distrib.
%
% CPD = gaussian_CPD(bnet, node, ...) will create a CPD with random parameters,
% where node is the number of a node in this equivalence class.
%
% To define this CPD precisely, call the continuous (cts) parents (if any) X,
% the discrete parents (if any) Q, and this node Y. Then the distribution on Y is:
% - no parents: Y ~ N(mu, Sigma)
% - cts parents : Y|X=x ~ N(mu + W x, Sigma)
% - discrete parents: Y|Q=i ~ N(mu(i), Sigma(i))
% - cts and discrete parents: Y|X=x,Q=i ~ N(mu(i) + W(i) x, Sigma(i))
%
% The list below gives optional arguments [default value in brackets].
% (Let ns(i) be the size of node i, X = ns(X), Y = ns(Y) and Q = prod(ns(Q)).)
% Parameters will be reshaped to the right size if necessary.
%
% mean       - mu(:,i) is the mean given Q=i [ randn(Y,Q) ]
% cov        - Sigma(:,:,i) is the covariance given Q=i [ repmat(100*eye(Y,Y), [1 1 Q]) ]
% weights    - W(:,:,i) is the regression matrix given Q=i [ randn(Y,X,Q) ]
% cov_type   - if 'diag', Sigma(:,:,i) is diagonal [ 'full' ]
% tied_cov   - if 1, we constrain Sigma(:,:,i) to be the same for all i [0]
% clamp_mean - if 1, we do not adjust mu(:,i) during learning [0]
% clamp_cov  - if 1, we do not adjust Sigma(:,:,i) during learning [0]
% clamp_weights - if 1, we do not adjust W(:,:,i) during learning [0]
% cov_prior_weight - weight given to I prior for estimating Sigma [0.01]
% cov_prior_entropic - if 1, we also use an entropic prior for Sigma [0]
%
% e.g., CPD = gaussian_CPD(bnet, i, 'mean', [0; 0], 'clamp_mean', 1)
%
% An unrecognized cov_type (anything other than 'full' or 'diag') raises an error.

if nargin==0
  % This occurs if we are trying to load an object from a file.
  CPD = init_fields;
  clamp = 0;
  CPD = class(CPD, 'gaussian_CPD', generic_CPD(clamp));
  return;
elseif isa(bnet, 'gaussian_CPD')
  % This might occur if we are copying an object.
  CPD = bnet;
  return;
end

% Create the object first (all fields empty, in canonical order), then fill it in.
CPD = init_fields;
CPD = class(CPD, 'gaussian_CPD', generic_CPD(0));

ns = bnet.node_sizes;
ps = parents(bnet.dag, self);
dps = myintersect(ps, bnet.dnodes);
cps = myintersect(ps, bnet.cnodes);
fam_sz = ns([ps self]);  % sizes of the parents followed by the size of this node

CPD.self = self;
CPD.sizes = fam_sz;

% Figure out which (if any) of the parents are discrete, and which cts, and how big they are
% dps = discrete parents, cps = cts parents
CPD.cps = find_equiv_posns(cps, ps); % cts parent index
CPD.dps = find_equiv_posns(dps, ps);
ss = fam_sz(end);            % size of this node (last entry of the family)
psz = fam_sz(1:end-1);       % sizes of the parents only
dpsz = prod(psz(CPD.dps));   % product of the discrete parents' sizes
cpsz = sum(psz(CPD.cps));    % sum of the cts parents' sizes

% set default params
CPD.mean = randn(ss, dpsz);
CPD.cov = 100*repmat(eye(ss), [1 1 dpsz]);
CPD.weights = randn(ss, cpsz, dpsz);
CPD.cov_type = 'full';
CPD.tied_cov = 0;
CPD.clamped_mean = 0;
CPD.clamped_cov = 0;
CPD.clamped_weights = 0;
CPD.cov_prior_weight = 0.01;
CPD.cov_prior_entropic = 0;

% Override the defaults with any name/value pairs supplied by the caller.
if ~isempty(varargin)
  CPD = set_fields(CPD, varargin{:});
end

% Make sure the matrices have 1 dimension per discrete parent.
% Bug fix due to Xuejing Sun 3/6/01
CPD.mean = myreshape(CPD.mean, [ss ns(dps)]);
CPD.cov = myreshape(CPD.cov, [ss ss ns(dps)]);
CPD.weights = myreshape(CPD.weights, [ss cpsz ns(dps)]);

% Precompute indices into block structured matrices
% to speed up CPD_to_lambda_msg and CPD_to_pi
cpsizes = CPD.sizes(CPD.cps);
CPD.cps_block_ndx = cell(1, length(cps));
for i=1:length(cps)
  CPD.cps_block_ndx{i} = block(i, cpsizes);
end

%%%%%%%%%%%
% Learning stuff

% expected sufficient statistics
CPD.Wsum = zeros(dpsz,1);
CPD.WYsum = zeros(ss, dpsz);
CPD.WXsum = zeros(cpsz, dpsz);
CPD.WYYsum = zeros(ss, ss, dpsz);
CPD.WXXsum = zeros(cpsz, cpsz, dpsz);
CPD.WXYsum = zeros(cpsz, ss, dpsz);

% For BIC
CPD.nsamples = 0;
switch CPD.cov_type
  case 'full'
    % since symmetric: count the diagonal plus one triangle
    ncov_params = ss*(ss+1)/2;
  case 'diag'
    ncov_params = ss;
  otherwise
    % Report the offending value stored on the object; there is no local
    % variable named cov_type in this function.
    error(['unrecognized cov_type ' CPD.cov_type]);
end
% params = weights + mean + cov
if CPD.tied_cov
  % a tied covariance is counted once rather than once per discrete configuration
  CPD.nparams = ss*cpsz*dpsz + ss*dpsz + ncov_params;
else
  CPD.nparams = ss*cpsz*dpsz + ss*dpsz + dpsz*ncov_params;
end

% for speeding up maximize_params
CPD.useC = exist('rep_mult');

% The CPD as a whole is clamped only if mean, cov and weights are all clamped.
clamped = CPD.clamped_mean & CPD.clamped_cov & CPD.clamped_weights;
CPD = set_clamped(CPD, clamped);
wolffd@0 130
wolffd@0 131 %%%%%%%%%%%
wolffd@0 132
function CPD = init_fields()
% INIT_FIELDS Build the empty field skeleton of a gaussian_CPD object.
%
% Every field is created empty, in one fixed order, so that an object has
% the same field layout whether it is loaded from a file or created from
% scratch. (Matlab requires this.)

CPD = struct( ...
    'self',               [], ...
    'sizes',              [], ...
    'cps',                [], ...
    'dps',                [], ...
    'mean',               [], ...
    'cov',                [], ...
    'weights',            [], ...
    'clamped_mean',       [], ...
    'clamped_cov',        [], ...
    'clamped_weights',    [], ...
    'cov_type',           [], ...
    'tied_cov',           [], ...
    'Wsum',               [], ...
    'WYsum',              [], ...
    'WXsum',              [], ...
    'WYYsum',             [], ...
    'WXXsum',             [], ...
    'WXYsum',             [], ...
    'nsamples',           [], ...
    'nparams',            [], ...
    'cov_prior_weight',   [], ...
    'cov_prior_entropic', [], ...
    'useC',               [], ...
    'cps_block_ndx',      []);