To check out this repository please hg clone the following URL, or open the URL using EasyMercurial or your preferred Mercurial client.
root / _FullBNT / BNT / CPDs / @gaussian_CPD / maximize_params.m @ 8:b5b38998ef3b
History | View | Annotate | Download (1.64 KB)
| 1 |
function CPD = maximize_params(CPD, temp) |
|---|---|
| 2 |
% MAXIMIZE_PARAMS Set the params of a CPD to their ML values (Gaussian) |
| 3 |
% CPD = maximize_params(CPD, temperature) |
| 4 |
% |
| 5 |
% Temperature is currently ignored. |
| 6 |
|
| 7 |
if ~adjustable_CPD(CPD), return; end |
| 8 |
|
| 9 |
|
| 10 |
if CPD.clamped_mean |
| 11 |
cl_mean = CPD.mean; |
| 12 |
else |
| 13 |
cl_mean = []; |
| 14 |
end |
| 15 |
|
| 16 |
if CPD.clamped_cov |
| 17 |
cl_cov = CPD.cov; |
| 18 |
else |
| 19 |
cl_cov = []; |
| 20 |
end |
| 21 |
|
| 22 |
if CPD.clamped_weights |
| 23 |
cl_weights = CPD.weights; |
| 24 |
else |
| 25 |
cl_weights = []; |
| 26 |
end |
| 27 |
|
| 28 |
[ssz psz Q] = size(CPD.weights); |
| 29 |
|
| 30 |
[ss cpsz dpsz] = size(CPD.weights); % ss = self size = ssz |
| 31 |
if cpsz > CPD.nsamples |
| 32 |
fprintf('gaussian_CPD/maximize_params: warning: input dimension (%d) > nsamples (%d)\n', ...
|
| 33 |
cpsz, CPD.nsamples); |
| 34 |
end |
| 35 |
|
| 36 |
prior = repmat(CPD.cov_prior_weight*eye(ssz,ssz), [1 1 Q]); |
| 37 |
|
| 38 |
|
| 39 |
[CPD.mean, CPD.cov, CPD.weights] = ... |
| 40 |
clg_Mstep(CPD.Wsum, CPD.WYsum, CPD.WYYsum, [], CPD.WXsum, CPD.WXXsum, CPD.WXYsum, ... |
| 41 |
'cov_type', CPD.cov_type, 'clamped_mean', cl_mean, ... |
| 42 |
'clamped_cov', cl_cov, 'clamped_weights', cl_weights, ... |
| 43 |
'tied_cov', CPD.tied_cov, ... |
| 44 |
'cov_prior', prior); |
| 45 |
|
| 46 |
if 0 |
| 47 |
CPD.mean = reshape(CPD.mean, [ss dpsz]); |
| 48 |
CPD.cov = reshape(CPD.cov, [ss ss dpsz]); |
| 49 |
CPD.weights = reshape(CPD.weights, [ss cpsz dpsz]); |
| 50 |
end |
| 51 |
|
| 52 |
% Bug fix 11 May 2003 KPM |
| 53 |
% clg_Mstep collapses all discrete parents into one mega-node |
| 54 |
% but convert_to_CPT needs access to each parent separately |
| 55 |
sz = CPD.sizes; |
| 56 |
ss = sz(end); |
| 57 |
|
| 58 |
% Bug fix KPM 20 May 2003: |
| 59 |
cpsz = sum(sz(CPD.cps)); |
| 60 |
%if isempty(CPD.cps) |
| 61 |
% cpsz = 0; |
| 62 |
%else |
| 63 |
% cpsz = sz(CPD.cps); |
| 64 |
%end |
| 65 |
dpsz = sz(CPD.dps); |
| 66 |
CPD.mean = myreshape(CPD.mean, [ss dpsz]); |
| 67 |
CPD.cov = myreshape(CPD.cov, [ss ss dpsz]); |
| 68 |
CPD.weights = myreshape(CPD.weights, [ss cpsz dpsz]); |