Mercurial > hg > camir-aes2014
comparison toolboxes/FullBNT-1.0.7/bnt/CPDs/@gaussian_CPD/maximize_params.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:e9a9cd732c1e |
---|---|
function CPD = maximize_params(CPD, temp)
% MAXIMIZE_PARAMS Set the params of a CPD to their ML values (Gaussian)
% CPD = maximize_params(CPD, temperature)
%
% Re-estimates CPD.mean, CPD.cov and CPD.weights by calling clg_Mstep on
% the statistics stored in the CPD (Wsum, WYsum, WYYsum, WXsum, WXXsum,
% WXYsum). For each of mean/cov/weights whose clamped_* flag is set, the
% current value is passed to clg_Mstep as the corresponding 'clamped_*'
% option; otherwise [] is passed for that option.
%
% Inputs:
%   CPD          - gaussian_CPD object
%   temp         - temperature; currently ignored (kept for interface
%                  compatibility with other CPD types)
% Output:
%   CPD          - the CPD with updated mean, cov and weights. If
%                  adjustable_CPD(CPD) is false, it is returned unchanged.

if ~adjustable_CPD(CPD), return; end

% Clamped parameters are forwarded as-is; non-clamped ones are passed as []
% (NOTE(review): presumably "estimate freely" inside clg_Mstep - confirm).
cl_mean = [];
if CPD.clamped_mean, cl_mean = CPD.mean; end

cl_cov = [];
if CPD.clamped_cov, cl_cov = CPD.cov; end

cl_weights = [];
if CPD.clamped_weights, cl_weights = CPD.weights; end

% Current weight layout: ss x cpsz x dpsz (ss = self size). With three
% output arguments, size() folds any trailing dimensions (one per discrete
% parent) into dpsz, so dpsz is the number of joint discrete-parent
% configurations.
[ss cpsz dpsz] = size(CPD.weights);
if cpsz > CPD.nsamples
  fprintf('gaussian_CPD/maximize_params: warning: input dimension (%d) > nsamples (%d)\n', ...
	  cpsz, CPD.nsamples);
end

% One copy of cov_prior_weight * I (ss x ss) per discrete-parent configuration.
prior = repmat(CPD.cov_prior_weight*eye(ss,ss), [1 1 dpsz]);

[CPD.mean, CPD.cov, CPD.weights] = ...
    clg_Mstep(CPD.Wsum, CPD.WYsum, CPD.WYYsum, [], CPD.WXsum, CPD.WXXsum, CPD.WXYsum, ...
	      'cov_type', CPD.cov_type, 'clamped_mean', cl_mean, ...
	      'clamped_cov', cl_cov, 'clamped_weights', cl_weights, ...
	      'tied_cov', CPD.tied_cov, ...
	      'cov_prior', prior);

% Bug fix 11 May 2003 KPM:
% clg_Mstep collapses all discrete parents into one mega-node, but
% convert_to_CPT needs access to each parent separately, so restore one
% trailing dimension per discrete parent using the family sizes.
sz = CPD.sizes;
ss = sz(end);   % self size (last entry of CPD.sizes)

% Bug fix KPM 20 May 2003: total continuous-parent dimension. sum([]) is 0,
% so this also covers the case of no continuous parents.
cpsz = sum(sz(CPD.cps));
dpsz = sz(CPD.dps);
CPD.mean = myreshape(CPD.mean, [ss dpsz]);
CPD.cov = myreshape(CPD.cov, [ss ss dpsz]);
CPD.weights = myreshape(CPD.weights, [ss cpsz dpsz]);