Mercurial > hg > camir-aes2014
comparison toolboxes/FullBNT-1.0.7/bnt/learning/learn_params_em.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:e9a9cd732c1e |
---|---|
function [bnet, LL, engine] = learn_params_em(engine, evidence, max_iter, thresh)
% LEARN_PARAMS_EM Set the parameters of each adjustable node to their ML/MAP values using batch EM.
% [bnet, LLtrace, engine] = learn_params_em(engine, data, max_iter, thresh)
%
% data{i,l} is the value of node i in case l, or [] if hidden.
% Suppose you have L training cases in an O*L array, D, where O is the num observed
% scalar nodes, and N is the total num nodes.
% Then you can create 'data' as follows, where onodes is the index of the observable nodes:
%   data = cell(N, L);
%   data(onodes,:) = num2cell(D);
% Of course it is possible for different sets of nodes to be observed in each case.
%
% We return the modified bnet and engine.
% To see the learned parameters for node i, use the construct
%   s = struct(bnet.CPD{i}); % violate object privacy
% LLtrace is the learning curve: the vector of log-likelihood scores at each iteration.
%
% max_iter specifies the maximum number of iterations. Default: 10.
%
% thresh specifies the threshold for stopping EM. Default: 1e-3.
% We stop when |f(t) - f(t-1)| / avg < threshold,
% where avg = (|f(t)| + |f(t-1)|)/2 and f is log lik.

if nargin < 3, max_iter = 10; end
if nargin < 4, thresh = 1e-3; end

verbose = 1;

previous_loglik = -inf;
converged = 0;
num_iter = 1;
LL = zeros(1, max_iter); % preallocate; trimmed to the actual iteration count below

% Use short-circuit && since both operands are scalars (element-wise & was a
% latent lint warning in the original; behavior for scalars is identical).
while ~converged && (num_iter <= max_iter)
  [engine, loglik] = EM_step(engine, evidence);
  if verbose, fprintf('EM iteration %d, ll = %8.4f\n', num_iter, loglik); end
  converged = em_converged(loglik, previous_loglik, thresh);
  previous_loglik = loglik;
  LL(num_iter) = loglik;
  num_iter = num_iter + 1;
end
LL = LL(1:num_iter-1); % keep only the iterations actually run
if verbose, fprintf('\n'); end

bnet = bnet_from_engine(engine);
47 %%%%%%%%% | |
48 | |
function [engine, loglik] = EM_step(engine, cases)
% EM_STEP One full EM iteration: accumulate expected sufficient statistics
% over all cases (E step) using the engine's current parameters, then
% maximize each adjustable CPD (M step) and push the new params into the engine.

bnet = bnet_from_engine(engine); % engine contains the old params that are used for the E step
CPDs = bnet.CPD;                 % these are the new params that get maximized
num_CPDs = length(CPDs);
nnodes = length(bnet.dag);

% Flag which CPDs are learnable.
adjustable = zeros(1, num_CPDs);
for j=1:num_CPDs
  adjustable(j) = adjustable_CPD(CPDs{j});
end
adj = find(adjustable);

% Zero out the expected sufficient statistics of every adjustable CPD.
for j=adj(:)'
  CPDs{j} = reset_ess(CPDs{j});
end

% E step: run inference on each case and accumulate ESS and log-likelihood.
loglik = 0;
for l=1:size(cases, 2)
  evidence = cases(:,l);
  [engine, ll] = enter_evidence(engine, evidence);
  loglik = loglik + ll;
  hidden_bitv = zeros(1, nnodes);
  hidden_bitv(isemptycell(evidence)) = 1; % mark nodes with no observed value
  for i=1:nnodes
    j = bnet.equiv_class(i); % tied parameters share one CPD per equivalence class
    if adjustable(j)
      fmarg = marginal_family(engine, i);
      CPDs{j} = update_ess(CPDs{j}, fmarg, evidence, bnet.node_sizes, bnet.cnodes, hidden_bitv);
    end
  end
end

% M step: maximize each adjustable CPD given its accumulated ESS.
for j=adj(:)'
  CPDs{j} = maximize_params(CPDs{j});
end

engine = update_engine(engine, CPDs);