comparison toolboxes/FullBNT-1.0.7/bnt/learning/learn_params_em.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:e9a9cd732c1e
function [bnet, LL, engine] = learn_params_em(engine, evidence, max_iter, thresh)
% LEARN_PARAMS_EM Set the parameters of each adjustable node to their ML/MAP values using batch EM.
% [bnet, LLtrace, engine] = learn_params_em(engine, data, max_iter, thresh)
%
% data{i,l} is the value of node i in case l, or [] if hidden.
% Suppose you have L training cases in an O*L array, D, where O is the num observed
% scalar nodes, and N is the total num nodes.
% Then you can create 'data' as follows, where onodes is the index of the observable nodes:
%   data = cell(N, L);
%   data(onodes,:) = num2cell(D);
% Of course it is possible for different sets of nodes to be observed in each case.
%
% We return the modified bnet and engine.
% To see the learned parameters for node i, use the construct
%   s = struct(bnet.CPD{i}); % violate object privacy
% LLtrace is the learning curve: the vector of log-likelihood scores at each iteration.
%
% max_iter specifies the maximum number of iterations. Default: 10.
%
% thresh specifies the threshold for stopping EM. Default: 1e-3.
% We stop when |f(t) - f(t-1)| / avg < threshold,
% where avg = (|f(t)| + |f(t-1)|)/2 and f is log lik.

if nargin < 3, max_iter = 10; end
if nargin < 4, thresh = 1e-3; end

verbose = 1;

previous_loglik = -inf;
converged = 0;
num_iter = 1;
LL = [];

% Use short-circuit && (not element-wise &) for the scalar loop guard.
while ~converged && (num_iter <= max_iter)
  [engine, loglik] = EM_step(engine, evidence);
  if verbose, fprintf('EM iteration %d, ll = %8.4f\n', num_iter, loglik); end
  num_iter = num_iter + 1;
  converged = em_converged(loglik, previous_loglik, thresh);
  previous_loglik = loglik;
  LL = [LL loglik]; % learning curve: log-likelihood after each EM sweep
end
if verbose, fprintf('\n'); end

% The engine now contains the maximized parameters; extract the updated bnet.
bnet = bnet_from_engine(engine);
46
47 %%%%%%%%%
48
function [engine, loglik] = EM_step(engine, cases)
% EM_STEP Perform a single E+M sweep over all training cases.
% [engine, loglik] = EM_step(engine, cases)
%
% The engine carries the old parameters, which drive inference (E step).
% A working copy of the CPDs accumulates expected sufficient statistics,
% is maximized (M step), and is then pushed back into the engine.

bnet = bnet_from_engine(engine); % engine contains the old params that are used for the E step
CPDs = bnet.CPD;                 % working copies: these are the new params that get maximized
num_CPDs = length(CPDs);
adjustable = zeros(1, num_CPDs);
for j = 1:num_CPDs
  adjustable(j) = adjustable_CPD(CPDs{j});
end
adj = find(adjustable);
nnodes = length(bnet.dag);

% Clear the expected sufficient statistics of every tunable CPD.
for j = adj(:)'
  CPDs{j} = reset_ess(CPDs{j});
end

% E step: accumulate ESS (and the total log-likelihood) over all cases.
loglik = 0;
ncases = size(cases, 2);
for c = 1:ncases
  evidence = cases(:, c);
  [engine, ll] = enter_evidence(engine, evidence);
  loglik = loglik + ll;
  hidden_bitv = zeros(1, nnodes);
  hidden_bitv(isemptycell(evidence)) = 1; % mark hidden nodes for this case
  for i = 1:nnodes
    j = bnet.equiv_class(i); % nodes in the same equivalence class share one CPD
    if adjustable(j)
      fmarg = marginal_family(engine, i);
      CPDs{j} = update_ess(CPDs{j}, fmarg, evidence, bnet.node_sizes, bnet.cnodes, hidden_bitv);
    end
  end
end

% M step: maximize each tunable CPD given its accumulated statistics.
for j = adj(:)'
  CPDs{j} = maximize_params(CPDs{j});
end

engine = update_engine(engine, CPDs);