function [bnet, LL, engine] = learn_params_dbn_em(engine, evidence, varargin)
% LEARN_PARAMS_DBN_EM Set the parameters in a DBN to their ML/MAP values using batch EM.
% [bnet, LLtrace, engine] = learn_params_dbn_em(engine, data, ...)
%
% data{l}{i,t} = value of node i in slice t of time-series l, or [] if hidden.
%   Suppose you have L time series, each of length T, stored in an O*T*L array D,
%   where O is the number of observed scalar nodes and N is the total number of nodes per slice.
%   Then you can create data as follows, where onodes contains the indices of the observable nodes:
%      data = cell(1,L);
%      for l=1:L
%        data{l} = cell(N, T);
%        data{l}(onodes,:) = num2cell(D(:,:,l));
%      end
% Of course, it is possible for different sets of nodes to be observed in
% each slice/sequence, and for each sequence to have a different length.
%
% LLtrace is the learning curve: the vector of log-likelihood scores at each iteration.
%
% Optional arguments [default]
%
% max_iter - maximum number of EM iterations [100]
% thresh - threshold for stopping EM [1e-3]
%     We stop when |f(t) - f(t-1)| / avg < threshold,
%     where avg = (|f(t)| + |f(t-1)|)/2 and f is the log-likelihood.
% verbose - display loglik at each iteration [1]
% anneal - 1 means do deterministic annealing (only for entropic priors) [0]
% anneal_rate - geometric cooling rate [0.8]
% init_temp - initial annealing temperature [10]
% final_temp - final annealing temperature [1e-3]
%
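% Example (an editorial sketch, not part of the original help text: assumes a
% DBN 'bnet0' and a 'data' cell array built as described above; the junction-tree
% smoother is just one possible choice of inference engine in BNT):
%    engine = smoother_engine(jtree_2TBN_inf_engine(bnet0));
%    [bnet2, LLtrace] = learn_params_dbn_em(engine, data, 'max_iter', 20);
%    plot(LLtrace)  % the learning curve should be non-decreasing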
max_iter = 100;
thresh = 1e-3;
anneal = 0;
anneal_rate = 0.8;
init_temp = 10;
final_temp = 1e-3;
verbose = 1;

for i=1:2:length(varargin)
  switch varargin{i}
    case 'max_iter', max_iter = varargin{i+1};
    case 'thresh', thresh = varargin{i+1};
    case 'verbose', verbose = varargin{i+1};
    case 'anneal', anneal = varargin{i+1};
    case 'anneal_rate', anneal_rate = varargin{i+1};
    case 'init_temp', init_temp = varargin{i+1};
    case 'final_temp', final_temp = varargin{i+1};
    otherwise, error(['unrecognized argument ' varargin{i}])
  end
end

% Take one EM step at each temperature value; then, when temp=0, run to convergence.
% When using an entropic prior, Z = 1-T, so
% T=2 => Z=-1 (max entropy)
% T=1 => Z=0 (max likelihood)
% T=0 => Z=1 (min entropy / max structure)
num_iter = 1;
LL = [];
if anneal
  temperature = init_temp;
  while temperature > final_temp
    [engine, loglik, logpost] = EM_step(engine, evidence, temperature);
    if verbose
      fprintf('EM iteration %d, loglik = %8.4f, logpost = %8.4f, temp=%8.4f\n', ...
              num_iter, loglik, logpost, temperature);
    end
    num_iter = num_iter + 1;
    LL = [LL loglik];
    temperature = temperature * anneal_rate;
  end
  temperature = 0;
  previous_loglik = loglik;
  previous_logpost = logpost;
else
  temperature = 0;
  previous_loglik = -inf;
  previous_logpost = -inf;
end

converged = 0;
while ~converged && (num_iter <= max_iter)
  [engine, loglik, logpost] = EM_step(engine, evidence, temperature);
  if verbose
    %fprintf('EM iteration %d, loglik = %8.4f, logpost = %8.4f\n', ...
    %        num_iter, loglik, logpost);
    fprintf('EM iteration %d, loglik = %8.4f\n', num_iter, loglik);
  end
  num_iter = num_iter + 1;
  [converged, decreased] = em_converged(loglik, previous_loglik, thresh);
  %[converged, decreased] = em_converged(logpost, previous_logpost, thresh);
  previous_loglik = loglik;
  previous_logpost = logpost;
  LL = [LL loglik];
end

bnet = bnet_from_engine(engine);
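% For reference, the stopping rule quoted in the header amounts to the following
% test (a sketch; BNT's em_converged may differ in details, e.g. how it guards
% against division by zero and how it reports a decrease in likelihood):
%   delta = abs(loglik - previous_loglik);
%   avg = (abs(loglik) + abs(previous_loglik) + eps)/2;
%   converged = (delta / avg) < thresh;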
%%%%%%%%%

function [engine, loglik, logpost] = EM_step(engine, cases, temp)

bnet = bnet_from_engine(engine); % engine contains the old params that are used for the E step
ss = length(bnet.intra);
CPDs = bnet.CPD; % these are the new params that get maximized
num_CPDs = length(CPDs);

% log P(theta|D) = (log P(D|theta) + log P(theta)) - log P(D)
% where log P(D|theta) = sum_cases log P(case|theta)
% and log P(theta) = sum_CPDs log P(CPD) - only count once even if tied!
% logpost = log P(theta,D) (un-normalized)
% This should be negative, and increase at every step.

adjustable = zeros(1, num_CPDs);
logprior = zeros(1, num_CPDs);
for e=1:num_CPDs
  adjustable(e) = adjustable_CPD(CPDs{e});
end
adj = find(adjustable);

for e=adj(:)'
  logprior(e) = log_prior(CPDs{e});
  CPDs{e} = reset_ess(CPDs{e});
end

loglik = 0;
for l=1:length(cases)
  evidence = cases{l};
  if ~iscell(evidence)
    error('training data must be a cell array of cell arrays')
  end
  [engine, ll] = enter_evidence(engine, evidence);
  assert(~isnan(ll))
  loglik = loglik + ll;
  T = size(evidence, 2);

  % We unroll ns etc. because in update_ess, we refer to nodes by their unrolled number,
  % so that they extract evidence from the right place.
  % (The CPD should really store its own version of ns and cnodes...)
  ns = repmat(bnet.node_sizes_slice(:), [1 T]);
  cnodes = unroll_set(bnet.cnodes_slice, ss, T);

  %hidden_bitv = repmat(bnet.hidden_bitv(1:ss), [1 T]);
  hidden_bitv = zeros(ss, T);
  hidden_bitv(isemptycell(evidence)) = 1;
  % hidden_bitv(i) = 1 means node i is hidden.
  % We pass this in, rather than using isemptycell(evidence(dom)), because
  % isemptycell is very slow.

  t = 1;
  for i=1:ss
    e = bnet.equiv_class(i,1);
    if adjustable(e)
      fmarg = marginal_family(engine, i, t);
      CPDs{e} = update_ess(CPDs{e}, fmarg, evidence, ns(:), cnodes(:), hidden_bitv(:));
    end
  end

  for i=1:ss
    e = bnet.equiv_class(i,2);
    if adjustable(e)
      for t=2:T
        fmarg = marginal_family(engine, i, t);
        CPDs{e} = update_ess(CPDs{e}, fmarg, evidence, ns(:), cnodes(:), hidden_bitv(:));
      end
    end
  end
end

logpost = loglik + sum(logprior(:));

for e=adj(:)'
  CPDs{e} = maximize_params(CPDs{e}, temp);
end

engine = update_engine(engine, CPDs);
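% Editorial note: the E/M split in EM_step relies on BNT's per-CPD interface,
% exactly as called above:
%   adjustable_CPD(CPD)         - can this CPD's parameters be learned?
%   log_prior(CPD)              - log P(CPD), counted once per CPD even if tied
%   reset_ess(CPD)              - zero the expected sufficient statistics (ESS)
%   update_ess(CPD, fmarg, ...) - accumulate ESS from one family marginal
%   maximize_params(CPD, temp)  - M step; temp only matters for entropic priors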