To check out this repository please hg clone the following URL, or open the URL using EasyMercurial or your preferred Mercurial client.

Statistics Download as Zip
| Branch: | Revision:

root / _FullBNT / BNT / CPDs / @gaussian_CPD / maximize_params.m @ 8:b5b38998ef3b

History | View | Annotate | Download (1.64 KB)

1
function CPD = maximize_params(CPD, temp)
2
% MAXIMIZE_PARAMS Set the params of a CPD to their ML values (Gaussian)
3
% CPD = maximize_params(CPD, temperature)
4
%
5
% Temperature is currently ignored.
6

    
7
if ~adjustable_CPD(CPD), return; end
8

    
9

    
10
if CPD.clamped_mean
11
  cl_mean = CPD.mean;
12
else
13
  cl_mean = [];
14
end
15

    
16
if CPD.clamped_cov
17
  cl_cov = CPD.cov;
18
else
19
  cl_cov = [];
20
end
21

    
22
if CPD.clamped_weights
23
  cl_weights = CPD.weights;
24
else
25
  cl_weights = [];
26
end
27

    
28
[ssz psz Q] = size(CPD.weights);
29

    
30
[ss cpsz dpsz] = size(CPD.weights); % ss = self size = ssz
31
if cpsz > CPD.nsamples
32
  fprintf('gaussian_CPD/maximize_params: warning: input dimension (%d) > nsamples (%d)\n', ...
33
	  cpsz, CPD.nsamples);
34
end
35

    
36
prior =  repmat(CPD.cov_prior_weight*eye(ssz,ssz), [1 1 Q]);
37

    
38

    
39
[CPD.mean, CPD.cov, CPD.weights] = ...
40
    clg_Mstep(CPD.Wsum, CPD.WYsum, CPD.WYYsum, [], CPD.WXsum, CPD.WXXsum, CPD.WXYsum, ...
41
	      'cov_type', CPD.cov_type, 'clamped_mean', cl_mean, ...
42
	      'clamped_cov', cl_cov, 'clamped_weights', cl_weights, ...
43
	      'tied_cov', CPD.tied_cov, ...
44
	      'cov_prior', prior);
45

    
46
if 0
47
CPD.mean = reshape(CPD.mean, [ss dpsz]);
48
CPD.cov = reshape(CPD.cov, [ss ss dpsz]);
49
CPD.weights = reshape(CPD.weights, [ss cpsz dpsz]);
50
end
51

    
52
% Bug fix 11 May 2003 KPM
53
% clg_Mstep collapses all discrete parents into one mega-node
54
% but convert_to_CPT needs access to each parent separately
55
sz = CPD.sizes;
56
ss = sz(end);
57

    
58
% Bug fix KPM 20 May 2003: 
59
cpsz = sum(sz(CPD.cps));
60
%if isempty(CPD.cps)
61
%  cpsz = 0;
62
%else
63
%  cpsz = sz(CPD.cps);
64
%end
65
dpsz = sz(CPD.dps);
66
CPD.mean = myreshape(CPD.mean, [ss dpsz]);
67
CPD.cov = myreshape(CPD.cov, [ss ss dpsz]);
68
CPD.weights = myreshape(CPD.weights, [ss cpsz dpsz]);