function [bnet, LL, engine] = learn_params_em(engine, evidence, max_iter, thresh)
% LEARN_PARAMS_EM Set the parameters of each adjustable node to their ML/MAP values using batch EM.
% [bnet, LL, engine] = learn_params_em(engine, data, max_iter, thresh)
%
% data{i,l} is the value of node i in case l, or [] if that node is hidden.
% Suppose you have L training cases in an O*L array, D, where O is the number of observed
% scalar nodes and N is the total number of nodes.
% Then you can create 'data' as follows, where onodes contains the indices of the observed nodes:
%   data = cell(N, L);
%   data(onodes,:) = num2cell(D);
% Of course, it is possible for a different set of nodes to be observed in each case.
%
% We return the modified bnet and engine.
% To see the learned parameters for node i, use the construct
%   s = struct(bnet.CPD{i}); % violates object privacy
% LL is the learning curve: the vector of log-likelihood scores at each iteration.
%
% max_iter specifies the maximum number of iterations. Default: 10.
%
% thresh specifies the threshold for stopping EM. Default: 1e-3.
% We stop when |f(t) - f(t-1)| / avg < thresh,
% where avg = (|f(t)| + |f(t-1)|)/2 and f is the log-likelihood.
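%
% Example (a minimal sketch; the engine type, inputs, and output names here are
% illustrative assumptions, not requirements of this function):
%   engine = jtree_inf_engine(bnet);
%   [bnet2, LLtrace] = learn_params_em(engine, data, 10, 1e-3);
%   plot(LLtrace) % the log-likelihood should not decrease from one iteration to the next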

if nargin < 3, max_iter = 10; end
if nargin < 4, thresh = 1e-3; end

verbose = 1;

loglik = 0;
previous_loglik = -inf;
converged = 0;
num_iter = 1;
LL = [];

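% Main EM loop: each call to EM_step performs one full E step (over all
% training cases) and one M step; we stop when em_converged reports that the
% relative change in log-likelihood is below thresh, or after max_iter steps.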
while ~converged && (num_iter <= max_iter)
  [engine, loglik] = EM_step(engine, evidence);
  if verbose, fprintf('EM iteration %d, ll = %8.4f\n', num_iter, loglik); end
  num_iter = num_iter + 1;
  converged = em_converged(loglik, previous_loglik, thresh);
  previous_loglik = loglik;
  LL = [LL loglik];
end
if verbose, fprintf('\n'); end

bnet = bnet_from_engine(engine); % return the bnet with the learned CPDs installed

%%%%%%%%%

function [engine, loglik] = EM_step(engine, cases)
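% EM_STEP Do one iteration of EM: reset the expected sufficient statistics (ESS),
% accumulate them over all training cases (E step), then maximize the parameters (M step).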

bnet = bnet_from_engine(engine); % engine contains the old params that are used for the E step
CPDs = bnet.CPD; % these are the new params that get maximized

% Find the CPDs whose parameters can be learned.
num_CPDs = length(CPDs);
adjustable = zeros(1,num_CPDs);
for e=1:num_CPDs
  adjustable(e) = adjustable_CPD(CPDs{e});
end
adj = find(adjustable);
n = length(bnet.dag); % total number of nodes

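% Zero the accumulated ESS of every adjustable CPD before the E step.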
for e=adj(:)'
  CPDs{e} = reset_ess(CPDs{e});
end

loglik = 0;
ncases = size(cases, 2);
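% E step: enter each case as evidence, run inference, and let each adjustable
% CPD absorb the expected sufficient statistics from its family marginal.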
for l=1:ncases
  evidence = cases(:,l);
  [engine, ll] = enter_evidence(engine, evidence);
  loglik = loglik + ll;
  hidden_bitv = zeros(1,n);
  hidden_bitv(isemptycell(evidence)) = 1; % mark the nodes that are hidden in this case
  for i=1:n
    e = bnet.equiv_class(i); % nodes in the same equivalence class share one CPD
    if adjustable(e)
      fmarg = marginal_family(engine, i);
      CPDs{e} = update_ess(CPDs{e}, fmarg, evidence, bnet.node_sizes, bnet.cnodes, hidden_bitv);
    end
  end
end

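% M step: set each adjustable CPD to its ML/MAP estimate given the accumulated ESS.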
for e=adj(:)'
  CPDs{e} = maximize_params(CPDs{e});
end

engine = update_engine(engine, CPDs); % install the new params so the next E step uses them