To check out this repository please hg clone the following URL, or open the URL using EasyMercurial or your preferred Mercurial client.

Statistics Download as Zip
| Branch: | Revision:

root / _FullBNT / BNT / CPDs / @gaussian_CPD / Old / gaussian_CPD.m @ 8:b5b38998ef3b

History | View | Annotate | Download (5.9 KB)

1
function CPD = gaussian_CPD(varargin)
2
% GAUSSIAN_CPD Make a conditional linear Gaussian distrib.
3
%
4
% To define this CPD precisely, call the continuous (cts) parents (if any) X,
5
% the discrete parents (if any) Q, and this node Y. Then the distribution on Y is:
6
% - no parents: Y ~ N(mu, Sigma)
7
% - cts parents : Y|X=x ~ N(mu + W x, Sigma)
8
% - discrete parents: Y|Q=i ~ N(mu(i), Sigma(i))
9
% - cts and discrete parents: Y|X=x,Q=i ~ N(mu(i) + W(i) x, Sigma(i))
10
%
11
% CPD = gaussian_CPD(bnet, node, ...) will create a CPD with random parameters,
12
% where node is the number of a node in this equivalence class.
13
%
14
% The list below gives optional arguments [default value in brackets].
15
% (Let ns(i) be the size of node i, X = ns(X), Y = ns(Y) and Q = prod(ns(Q)).)
16
%
17
% mean       - mu(:,i) is the mean given Q=i [ randn(Y,Q) ]
18
% cov        - Sigma(:,:,i) is the covariance given Q=i [ repmat(eye(Y,Y), [1 1 Q]) ]
19
% weights    - W(:,:,i) is the regression matrix given Q=i [ randn(Y,X,Q) ]
20
% cov_type   - if 'diag', Sigma(:,:,i) is diagonal [ 'full' ]
21
% tied_cov   - if 1, we constrain Sigma(:,:,i) to be the same for all i [0]
22
% clamp_mean - if 1, we do not adjust mu(:,i) during learning [0]
23
% clamp_cov  - if 1, we do not adjust Sigma(:,:,i) during learning [0]
24
% clamp_weights - if 1, we do not adjust W(:,:,i) during learning [0]
25
% cov_prior_weight - weight given to I prior for estimating Sigma [0.01]
26
%
27
% e.g., CPD = gaussian_CPD(bnet, i, 'mean', [0; 0], 'clamp_mean', 'yes')
28
%
29
% For backwards compatibility with BNT2, you can also specify the parameters in the following order
30
%   CPD = gaussian_CPD(bnet, self, mu, Sigma, W, cov_type, tied_cov, clamp_mean, clamp_cov, clamp_weight)
31
%
32
% Sometimes it is useful to create an "isolated" CPD, without needing to pass in a bnet.
33
% In this case, you must specify the discrete and cts parents (dps, cps) and the family sizes, followed
34
% by the optional arguments above:
35
%   CPD = gaussian_CPD('self', i, 'dps', dps, 'cps', cps, 'sz', fam_size, ...)
36

    
37

    
38
if nargin==0
39
  % This occurs if we are trying to load an object from a file.
40
  CPD = init_fields;
41
  clamp = 0;
42
  CPD = class(CPD, 'gaussian_CPD', generic_CPD(clamp));
43
  return;
44
elseif isa(varargin{1}, 'gaussian_CPD')
45
  % This might occur if we are copying an object.
46
  CPD = varargin{1};
47
  return;
48
end
49
CPD = init_fields;
50
 
51
CPD = class(CPD, 'gaussian_CPD', generic_CPD(0));
52

    
53

    
54
% parse mandatory arguments
55
if ~isstr(varargin{1}) % pass in bnet
56
  bnet = varargin{1};
57
  self = varargin{2};
58
  args = varargin(3:end);
59
  ns = bnet.node_sizes;
60
  ps = parents(bnet.dag, self);
61
  dps = myintersect(ps, bnet.dnodes);
62
  cps = myintersect(ps, bnet.cnodes);
63
  fam_sz = ns([ps self]);
64
else
65
  disp('parsing new style')
66
  for i=1:2:length(varargin)
67
    switch varargin{i},
68
     case 'self', self = varargin{i+1}; 
69
     case 'dps',  dps = varargin{i+1};
70
     case 'cps',  cps = varargin{i+1};
71
     case 'sz',   fam_sz = varargin{i+1};
72
    end
73
  end
74
  ps = myunion(dps, cps);
75
  args = varargin;
76
end
77

    
78
CPD.self = self;
79
CPD.sizes = fam_sz;
80

    
81
% Figure out which (if any) of the parents are discrete, and which cts, and how big they are
82
% dps = discrete parents, cps = cts parents
83
CPD.cps = find_equiv_posns(cps, ps); % cts parent index
84
CPD.dps = find_equiv_posns(dps, ps);
85
ss = fam_sz(end);
86
psz = fam_sz(1:end-1);
87
dpsz = prod(psz(CPD.dps));
88
cpsz = sum(psz(CPD.cps));
89

    
90
% set default params
91
CPD.mean = randn(ss, dpsz);
92
CPD.cov = 100*repmat(eye(ss), [1 1 dpsz]);    
93
CPD.weights = randn(ss, cpsz, dpsz);
94
CPD.cov_type = 'full';
95
CPD.tied_cov = 0;
96
CPD.clamped_mean = 0;
97
CPD.clamped_cov = 0;
98
CPD.clamped_weights = 0;
99
CPD.cov_prior_weight = 0.01;
100

    
101
nargs = length(args);
102
if nargs > 0
103
  if ~isstr(args{1})
104
    % gaussian_CPD(bnet, self, mu, Sigma, W, cov_type, tied_cov, clamp_mean, clamp_cov, clamp_weights)
105
    if nargs >= 1 & ~isempty(args{1}), CPD.mean = args{1}; end
106
    if nargs >= 2 & ~isempty(args{2}), CPD.cov = args{2}; end
107
    if nargs >= 3 & ~isempty(args{3}), CPD.weights = args{3}; end
108
    if nargs >= 4 & ~isempty(args{4}), CPD.cov_type = args{4}; end
109
    if nargs >= 5 & ~isempty(args{5}) & strcmp(args{5}, 'tied'), CPD.tied_cov = 1; end
110
    if nargs >= 6 & ~isempty(args{6}), CPD.clamped_mean = 1; end
111
    if nargs >= 7 & ~isempty(args{7}), CPD.clamped_cov = 1; end
112
    if nargs >= 8 & ~isempty(args{8}), CPD.clamped_weights = 1; end
113
  else
114
    CPD = set_fields(CPD, args{:});
115
  end
116
end
117

    
118
% Make sure the matrices have 1 dimension per discrete parent.
119
% Bug fix due to Xuejing Sun 3/6/01
120
CPD.mean = myreshape(CPD.mean, [ss ns(dps)]);
121
CPD.cov = myreshape(CPD.cov, [ss ss ns(dps)]);
122
CPD.weights = myreshape(CPD.weights, [ss cpsz ns(dps)]);
123
  
124
CPD.init_cov = CPD.cov;  % we reset to this if things go wrong during learning
125

    
126
% expected sufficient statistics 
127
CPD.Wsum = zeros(dpsz,1);
128
CPD.WYsum = zeros(ss, dpsz);
129
CPD.WXsum = zeros(cpsz, dpsz);
130
CPD.WYYsum = zeros(ss, ss, dpsz);
131
CPD.WXXsum = zeros(cpsz, cpsz, dpsz);
132
CPD.WXYsum = zeros(cpsz, ss, dpsz);
133

    
134
% For BIC
135
CPD.nsamples = 0;
136
switch CPD.cov_type
137
  case 'full',
138
    ncov_params = ss*(ss-1)/2; % since symmetric (and positive definite)
139
  case 'diag',
140
    ncov_params = ss;
141
  otherwise
142
    error(['unrecognized cov_type ' cov_type]);
143
end
144
% params = weights + mean + cov
145
if CPD.tied_cov
146
  CPD.nparams = ss*cpsz*dpsz + ss*dpsz + ncov_params;
147
else
148
  CPD.nparams = ss*cpsz*dpsz + ss*dpsz + dpsz*ncov_params;
149
end
150

    
151

    
152

    
153
clamped = CPD.clamped_mean & CPD.clamped_cov & CPD.clamped_weights;
154
CPD = set_clamped(CPD, clamped);
155

    
156
%%%%%%%%%%%
157

    
158
function CPD = init_fields()
159
% This ensures we define the fields in the same order 
160
% no matter whether we load an object from a file,
161
% or create it from scratch. (Matlab requires this.)
162

    
163
CPD.self = [];
164
CPD.sizes = [];
165
CPD.cps = [];
166
CPD.dps = [];
167
CPD.mean = [];
168
CPD.cov = [];
169
CPD.weights = [];
170
CPD.clamped_mean = [];
171
CPD.clamped_cov = [];
172
CPD.clamped_weights = [];
173
CPD.init_cov = [];
174
CPD.cov_type = [];
175
CPD.tied_cov = [];
176
CPD.Wsum = [];
177
CPD.WYsum = [];
178
CPD.WXsum = [];
179
CPD.WYYsum = [];
180
CPD.WXXsum = [];
181
CPD.WXYsum = [];
182
CPD.nsamples = [];
183
CPD.nparams = [];            
184
CPD.cov_prior_weight = [];