function [bnet, LL, engine] = learn_params_em(engine, evidence, max_iter, thresh)
% LEARN_PARAMS_EM Set the parameters of each adjustable node to their ML/MAP values using batch EM.
% [bnet, LLtrace, engine] = learn_params_em(engine, data, max_iter, thresh)
%
% data{i,l} is the value of node i in case l, or [] if hidden.
% Suppose you have L training cases stored in an O*L array, D, where O is the number of
% observed scalar nodes and N is the total number of nodes.
% Then you can create 'data' as follows, where onodes contains the indices of the observed nodes:
%   data = cell(N, L);
%   data(onodes,:) = num2cell(D);
% Of course, it is possible for different sets of nodes to be observed in each case.
%
% We return the modified bnet and engine.
% To see the learned parameters for node i, use the construct
%   s = struct(bnet.CPD{i}); % violate object privacy
% LLtrace is the learning curve: the vector of log-likelihood scores at each iteration.
%
% max_iter specifies the maximum number of iterations. Default: 10.
%
% thresh specifies the threshold for stopping EM. Default: 1e-3.
% We stop when |f(t) - f(t-1)| / avg < thresh,
% where avg = (|f(t)| + |f(t-1)|)/2 and f is the log-likelihood.

if nargin < 3, max_iter = 10; end
if nargin < 4, thresh = 1e-3; end

verbose = 1;

loglik = 0;
previous_loglik = -inf;
converged = 0;
num_iter = 1;
LL = [];

while ~converged && (num_iter <= max_iter)
  [engine, loglik] = EM_step(engine, evidence);
  if verbose, fprintf('EM iteration %d, ll = %8.4f\n', num_iter, loglik); end
  num_iter = num_iter + 1;
  converged = em_converged(loglik, previous_loglik, thresh);
  previous_loglik = loglik;
  LL = [LL loglik]; % record the learning curve
end
if verbose, fprintf('\n'); end
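% A minimal sketch of the relative-change test performed by em_converged
% above, mirroring the formula in the help text (an assumption about its
% internals: BNT's actual em_converged may add safeguards, e.g. an eps term
% to avoid division by zero and a check for likelihood decreases):
%
%   delta = abs(loglik - previous_loglik);
%   avg = (abs(loglik) + abs(previous_loglik)) / 2;
%   converged = (delta / avg) < thresh;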
bnet = bnet_from_engine(engine);

%%%%%%%%%

function [engine, loglik] = EM_step(engine, cases)

bnet = bnet_from_engine(engine); % the engine contains the old params, which are used for the E step
CPDs = bnet.CPD; % these are the new params that get maximized
num_CPDs = length(CPDs);
adjustable = zeros(1, num_CPDs);
for e=1:num_CPDs
  adjustable(e) = adjustable_CPD(CPDs{e});
end
adj = find(adjustable);
n = length(bnet.dag);

% Zero out the expected sufficient statistics (ESS) of every adjustable CPD.
for e=adj(:)'
  CPDs{e} = reset_ess(CPDs{e});
end

% E step: for each case, run inference and accumulate the ESS
% of every family that contains an adjustable node.
loglik = 0;
ncases = size(cases, 2);
for l=1:ncases
  evidence = cases(:,l);
  [engine, ll] = enter_evidence(engine, evidence);
  loglik = loglik + ll;
  hidden_bitv = zeros(1,n);
  hidden_bitv(isemptycell(evidence)) = 1;
  for i=1:n
    e = bnet.equiv_class(i);
    if adjustable(e)
      fmarg = marginal_family(engine, i);
      CPDs{e} = update_ess(CPDs{e}, fmarg, evidence, bnet.node_sizes, bnet.cnodes, hidden_bitv);
    end
  end
end

% M step: set each adjustable CPD to its ML/MAP estimate given the accumulated ESS.
for e=adj(:)'
  CPDs{e} = maximize_params(CPDs{e});
end

engine = update_engine(engine, CPDs);
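%%%%%%%%%

% Example usage (a minimal sketch; the toy 2-node network and the variable
% names below are illustrative assumptions, while mk_bnet, tabular_CPD,
% jtree_inf_engine and sample_bnet are standard BNT constructors):
%
%   N = 2; dag = zeros(N,N); dag(1,2) = 1;
%   ns = [2 2]; % both nodes binary
%   bnet = mk_bnet(dag, ns);
%   for i=1:N, bnet.CPD{i} = tabular_CPD(bnet, i); end % random CPTs
%   ncases = 20;
%   data = cell(N, ncases);
%   for l=1:ncases
%     ev = sample_bnet(bnet);
%     data(2,l) = ev(2); % observe node 2 only; node 1 stays hidden ([])
%   end
%   engine = jtree_inf_engine(bnet);
%   [bnet2, LLtrace] = learn_params_em(engine, data, 10);
%   plot(LLtrace, 'x-') % the log-likelihood should be non-decreasing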