To check out this repository please hg clone the following URL, or open the URL using EasyMercurial or your preferred Mercurial client.

Statistics Download as Zip
| Branch: | Revision:

root / _FullBNT / BNT / CPDs / @gaussian_CPD / gaussian_CPD.m @ 8:b5b38998ef3b

History | View | Annotate | Download (4.76 KB)

1
function CPD = gaussian_CPD(bnet, self, varargin)
2
% GAUSSIAN_CPD Make a conditional linear Gaussian distrib.
3
%
4
% CPD = gaussian_CPD(bnet, node, ...) will create a CPD with random parameters,
5
% where node is the number of a node in this equivalence class.
6

    
7
% To define this CPD precisely, call the continuous (cts) parents (if any) X,
8
% the discrete parents (if any) Q, and this node Y. Then the distribution on Y is:
9
% - no parents: Y ~ N(mu, Sigma)
10
% - cts parents : Y|X=x ~ N(mu + W x, Sigma)
11
% - discrete parents: Y|Q=i ~ N(mu(i), Sigma(i))
12
% - cts and discrete parents: Y|X=x,Q=i ~ N(mu(i) + W(i) x, Sigma(i))
13
%
14
% The list below gives optional arguments [default value in brackets].
15
% (Let ns(i) be the size of node i, X = ns(X), Y = ns(Y) and Q = prod(ns(Q)).)
16
% Parameters will be reshaped to the right size if necessary.
17
%
18
% mean       - mu(:,i) is the mean given Q=i [ randn(Y,Q) ]
19
% cov        - Sigma(:,:,i) is the covariance given Q=i [ repmat(100*eye(Y,Y), [1 1 Q]) ]
20
% weights    - W(:,:,i) is the regression matrix given Q=i [ randn(Y,X,Q) ]
21
% cov_type   - if 'diag', Sigma(:,:,i) is diagonal [ 'full' ]
22
% tied_cov   - if 1, we constrain Sigma(:,:,i) to be the same for all i [0]
23
% clamp_mean - if 1, we do not adjust mu(:,i) during learning [0]
24
% clamp_cov  - if 1, we do not adjust Sigma(:,:,i) during learning [0]
25
% clamp_weights - if 1, we do not adjust W(:,:,i) during learning [0]
26
% cov_prior_weight - weight given to I prior for estimating Sigma [0.01]
27
% cov_prior_entropic - if 1, we also use an entropic prior for Sigma [0]
28
%
29
% e.g., CPD = gaussian_CPD(bnet, i, 'mean', [0; 0], 'clamp_mean', 1)
30

    
31
if nargin==0
32
  % This occurs if we are trying to load an object from a file.
33
  CPD = init_fields;
34
  clamp = 0;
35
  CPD = class(CPD, 'gaussian_CPD', generic_CPD(clamp));
36
  return;
37
elseif isa(bnet, 'gaussian_CPD')
38
  % This might occur if we are copying an object.
39
  CPD = bnet;
40
  return;
41
end
42
CPD = init_fields;
43
 
44
CPD = class(CPD, 'gaussian_CPD', generic_CPD(0));
45

    
46
args = varargin;
47
ns = bnet.node_sizes;
48
ps = parents(bnet.dag, self);
49
dps = myintersect(ps, bnet.dnodes);
50
cps = myintersect(ps, bnet.cnodes);
51
fam_sz = ns([ps self]);
52

    
53
CPD.self = self;
54
CPD.sizes = fam_sz;
55

    
56
% Figure out which (if any) of the parents are discrete, and which cts, and how big they are
57
% dps = discrete parents, cps = cts parents
58
CPD.cps = find_equiv_posns(cps, ps); % cts parent index
59
CPD.dps = find_equiv_posns(dps, ps);
60
ss = fam_sz(end);
61
psz = fam_sz(1:end-1);
62
dpsz = prod(psz(CPD.dps));
63
cpsz = sum(psz(CPD.cps));
64

    
65
% set default params
66
CPD.mean = randn(ss, dpsz);
67
CPD.cov = 100*repmat(eye(ss), [1 1 dpsz]);    
68
CPD.weights = randn(ss, cpsz, dpsz);
69
CPD.cov_type = 'full';
70
CPD.tied_cov = 0;
71
CPD.clamped_mean = 0;
72
CPD.clamped_cov = 0;
73
CPD.clamped_weights = 0;
74
CPD.cov_prior_weight = 0.01;
75
CPD.cov_prior_entropic = 0;
76
nargs = length(args);
77
if nargs > 0
78
  CPD = set_fields(CPD, args{:});
79
end
80

    
81
% Make sure the matrices have 1 dimension per discrete parent.
82
% Bug fix due to Xuejing Sun 3/6/01
83
CPD.mean = myreshape(CPD.mean, [ss ns(dps)]);
84
CPD.cov = myreshape(CPD.cov, [ss ss ns(dps)]);
85
CPD.weights = myreshape(CPD.weights, [ss cpsz ns(dps)]);
86

    
87
% Precompute indices into block structured  matrices
88
% to speed up CPD_to_lambda_msg and CPD_to_pi
89
cpsizes = CPD.sizes(CPD.cps);
90
CPD.cps_block_ndx = cell(1, length(cps));
91
for i=1:length(cps)
92
  CPD.cps_block_ndx{i} = block(i, cpsizes);
93
end
94

    
95
%%%%%%%%%%% 
96
% Learning stuff
97

    
98
% expected sufficient statistics 
99
CPD.Wsum = zeros(dpsz,1);
100
CPD.WYsum = zeros(ss, dpsz);
101
CPD.WXsum = zeros(cpsz, dpsz);
102
CPD.WYYsum = zeros(ss, ss, dpsz);
103
CPD.WXXsum = zeros(cpsz, cpsz, dpsz);
104
CPD.WXYsum = zeros(cpsz, ss, dpsz);
105

    
106
% For BIC
107
CPD.nsamples = 0;
108
switch CPD.cov_type
109
 case 'full',
110
  % since symmetric 
111
    %ncov_params = ss*(ss-1)/2; 
112
    ncov_params = ss*(ss+1)/2; 
113
  case 'diag',
114
    ncov_params = ss;
115
  otherwise
116
    error(['unrecognized cov_type ' cov_type]);
117
end
118
% params = weights + mean + cov
119
if CPD.tied_cov
120
  CPD.nparams = ss*cpsz*dpsz + ss*dpsz + ncov_params;
121
else
122
  CPD.nparams = ss*cpsz*dpsz + ss*dpsz + dpsz*ncov_params;
123
end
124

    
125
% for speeding up maximize_params
126
CPD.useC = exist('rep_mult');
127

    
128
clamped = CPD.clamped_mean & CPD.clamped_cov & CPD.clamped_weights;
129
CPD = set_clamped(CPD, clamped);
130

    
131
%%%%%%%%%%%
132

    
133
function CPD = init_fields()
134
% This ensures we define the fields in the same order 
135
% no matter whether we load an object from a file,
136
% or create it from scratch. (Matlab requires this.)
137

    
138
CPD.self = [];
139
CPD.sizes = [];
140
CPD.cps = [];
141
CPD.dps = [];
142
CPD.mean = [];
143
CPD.cov = [];
144
CPD.weights = [];
145
CPD.clamped_mean = [];
146
CPD.clamped_cov = [];
147
CPD.clamped_weights = [];
148
CPD.cov_type = [];
149
CPD.tied_cov = [];
150
CPD.Wsum = [];
151
CPD.WYsum = [];
152
CPD.WXsum = [];
153
CPD.WYYsum = [];
154
CPD.WXXsum = [];
155
CPD.WXYsum = [];
156
CPD.nsamples = [];
157
CPD.nparams = [];            
158
CPD.cov_prior_weight = [];
159
CPD.cov_prior_entropic = [];
160
CPD.useC = [];
161
CPD.cps_block_ndx = [];