function [bnet, LL, engine] = learn_params_dbn_em(engine, evidence, varargin)
% LEARN_PARAMS_DBN_EM Set the parameters in a DBN to their ML/MAP values using batch EM.
% [bnet, LLtrace, engine] = learn_params_dbn_em(engine, data, ...)
%
% data{l}{i,t} = value of node i in slice t of time-series l, or [] if hidden.
% Suppose you have L time series, each of length T, stored in an O*T*L array D,
% where O is the number of observed scalar nodes and N is the total number of nodes per slice.
% Then you can create data as follows, where onodes is the index of the observable nodes:
%  data = cell(1,L);
%  for l=1:L
%    data{l} = cell(N, T);
%    data{l}(onodes,:) = num2cell(D(:,:,l));
%  end
% Of course it is possible for different sets of nodes to be observed in
% each slice/sequence, and for each sequence to have a different length.
%
% LLtrace is the learning curve: the vector of log-likelihood scores at each iteration.
%
% Optional arguments [default]
%
% max_iter - maximum number of EM iterations [100]
% thresh - convergence threshold for stopping EM [1e-3]
%   We stop when |f(t) - f(t-1)| / avg < thresh,
%   where avg = (|f(t)| + |f(t-1)|)/2 and f is the log-likelihood.
% verbose - display loglik at each iteration [1]
% anneal - 1 means do deterministic annealing (only for entropic priors) [0]
% anneal_rate - geometric cooling rate [0.8]
% init_temp - initial annealing temperature [10]
% final_temp - final annealing temperature [1e-3]
%

max_iter = 100;
thresh = 1e-3;
anneal = 0;
anneal_rate = 0.8;
init_temp = 10;
final_temp = 1e-3;
verbose = 1;

for i=1:2:length(varargin)
  switch varargin{i}
   case 'max_iter', max_iter = varargin{i+1};
   case 'thresh', thresh = varargin{i+1};
   case 'verbose', verbose = varargin{i+1};
   case 'anneal', anneal = varargin{i+1};
   case 'anneal_rate', anneal_rate = varargin{i+1};
   case 'init_temp', init_temp = varargin{i+1};
   case 'final_temp', final_temp = varargin{i+1};
   otherwise, error(['unrecognized argument ' varargin{i}])
  end
end

% Take 1 EM step at each temperature value; then, when temp=0, run to convergence.
% When using an entropic prior, Z = 1-T, so
% T=2 => Z=-1 (max entropy)
% T=1 => Z=0  (max likelihood)
% T=0 => Z=1  (min entropy / max structure)
num_iter = 1;
LL = [];
if anneal
  temperature = init_temp;
  while temperature > final_temp
    [engine, loglik, logpost] = EM_step(engine, evidence, temperature);
    if verbose
      fprintf('EM iteration %d, loglik = %8.4f, logpost = %8.4f, temp=%8.4f\n', ...
	      num_iter, loglik, logpost, temperature);
    end
    num_iter = num_iter + 1;
    LL = [LL loglik];
    temperature = temperature * anneal_rate;
  end
  temperature = 0;
  previous_loglik = loglik;
  previous_logpost = logpost;
else
  temperature = 0;
  previous_loglik = -inf;
  previous_logpost = -inf;
end

converged = 0;
while ~converged && (num_iter <= max_iter)
  [engine, loglik, logpost] = EM_step(engine, evidence, temperature);
  if verbose
    %fprintf('EM iteration %d, loglik = %8.4f, logpost = %8.4f\n', num_iter, loglik, logpost);
    fprintf('EM iteration %d, loglik = %8.4f\n', num_iter, loglik);
  end
  num_iter = num_iter + 1;
  [converged, decreased] = em_converged(loglik, previous_loglik, thresh);
  %[converged, decreased] = em_converged(logpost, previous_logpost, thresh);
  previous_loglik = loglik;
  previous_logpost = logpost;
  LL = [LL loglik];
end

bnet = bnet_from_engine(engine);

%%%%%%%%%

function [engine, loglik, logpost] = EM_step(engine, cases, temp)

bnet = bnet_from_engine(engine); % the engine contains the old params, used for the E step
ss = length(bnet.intra); % slice size (number of nodes per slice)
CPDs = bnet.CPD; % these are the new params that get maximized
num_CPDs = length(CPDs);

% log P(theta|D) = (log P(D|theta) + log P(theta)) - log P(D)
% where log P(D|theta) = sum_cases log P(case|theta)
% and log P(theta) = sum_CPDs log P(CPD) - only count once even if tied!
% logpost = log P(theta,D) (un-normalized).
% This should be negative, and increase at every step.
adjustable = zeros(1, num_CPDs);
logprior = zeros(1, num_CPDs);
for e=1:num_CPDs
  adjustable(e) = adjustable_CPD(CPDs{e});
end
adj = find(adjustable);

for e=adj(:)'
  logprior(e) = log_prior(CPDs{e});
  CPDs{e} = reset_ess(CPDs{e});
end

loglik = 0;
for l=1:length(cases)
  evidence = cases{l};
  if ~iscell(evidence)
    error('training data must be a cell array of cell arrays')
  end
  [engine, ll] = enter_evidence(engine, evidence);
  assert(~isnan(ll))
  loglik = loglik + ll;
  T = size(evidence, 2);

  % We unroll ns etc. because in update_ess, we refer to nodes by their unrolled number,
  % so that they extract evidence from the right place.
  % (The CPD should really store its own version of ns and cnodes...)
  ns = repmat(bnet.node_sizes_slice(:), [1 T]);
  cnodes = unroll_set(bnet.cnodes_slice, ss, T);

  hidden_bitv = zeros(ss, T);
  hidden_bitv(isemptycell(evidence)) = 1;
  % hidden_bitv(i) = 1 means node i is hidden.
  % We pass this in, rather than using isemptycell(evidence(dom)), because
  % isemptycell is very slow.
  % Accumulate the ESS for the slice-1 CPDs (equivalence class column 1).
  t = 1;
  for i=1:ss
    e = bnet.equiv_class(i,1);
    if adjustable(e)
      fmarg = marginal_family(engine, i, t);
      CPDs{e} = update_ess(CPDs{e}, fmarg, evidence, ns(:), cnodes(:), hidden_bitv(:));
    end
  end

  % Accumulate the ESS for the transition CPDs (column 2), for every t > 1.
  for i=1:ss
    e = bnet.equiv_class(i,2);
    if adjustable(e)
      for t=2:T
        fmarg = marginal_family(engine, i, t);
        CPDs{e} = update_ess(CPDs{e}, fmarg, evidence, ns(:), cnodes(:), hidden_bitv(:));
      end
    end
  end
end

logpost = loglik + sum(logprior(:));

% M step: maximize each adjustable CPD given its accumulated sufficient statistics.
for e=adj(:)'
  CPDs{e} = maximize_params(CPDs{e}, temp);
end

engine = update_engine(engine, CPDs);
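
%%%%%%%%%

% Example usage (a sketch only, kept as comments since this is a function file;
% mk_dbn, tabular_CPD, jtree_2TBN_inf_engine and smoother_engine are standard
% BNT functions, but check the argument conventions in your version before running):
%
%   intra = [0 1; 0 0]; inter = [1 0; 0 0];      % an HMM: hidden node 1 -> observed node 2
%   ns = [2 3];                                  % node sizes: 2 hidden states, 3 output symbols
%   bnet = mk_dbn(intra, inter, ns, 'discrete', [1 2], 'observed', 2);
%   for i=1:4, bnet.CPD{i} = tabular_CPD(bnet, i); end  % random initial CPTs
%   engine = smoother_engine(jtree_2TBN_inf_engine(bnet));
%   T = 10;
%   data = cell(1,1);
%   data{1} = cell(2, T);                        % node 1 left hidden ([])
%   data{1}(2,:) = num2cell(ceil(3*rand(1,T))); % dummy observations in {1,2,3}
%   [bnet2, LLtrace] = learn_params_dbn_em(engine, data, 'max_iter', 10);
%   plot(LLtrace)                                % learning curve should be non-decreasing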