function [bnet, LL, engine] = learn_params_dbn_em(engine, evidence, varargin)
% LEARN_PARAMS_DBN_EM Set the parameters in a DBN to their ML/MAP values using batch EM.
% [bnet, LLtrace, engine] = learn_params_dbn_em(engine, data, ...)
%
% data{l}{i,t} = value of node i in slice t of time-series l, or [] if hidden.
% Suppose you have L time series, each of length T, in an O*T*L array D,
% where O is the number of observed scalar nodes and N is the total number of nodes per slice.
% Then you can create data as follows, where onodes is the index of the observable nodes:
%  data = cell(1,L);
%  for l=1:L
%    data{l} = cell(N, T);
%    data{l}(onodes,:) = num2cell(D(:,:,l));
%  end
% Of course it is possible for different sets of nodes to be observed in
% each slice/sequence, and for each sequence to have a different length.
%
% LLtrace is the learning curve: the vector of log-likelihood scores at each iteration.
%
% Optional arguments [default]
%
% max_iter - maximum number of EM iterations [100]
% thresh - convergence threshold for stopping EM [1e-3]
%    We stop when |f(t) - f(t-1)| / avg < thresh,
%    where avg = (|f(t)| + |f(t-1)|)/2 and f is the log-likelihood.
% verbose - display loglik at each iteration [1]
% anneal - 1 means do deterministic annealing (only for entropic priors) [0]
% anneal_rate - geometric cooling rate [0.8]
% init_temp - initial annealing temperature [10]
% final_temp - final annealing temperature [1e-3]
%
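% Example (a minimal sketch, assuming bnet is a DBN whose adjustable CPDs
% have been given initial values, and data is built as above; the engine
% constructors are BNT's standard 2TBN smoother pattern):
%   engine = smoother_engine(jtree_2TBN_inf_engine(bnet));
%   [bnet2, LLtrace] = learn_params_dbn_em(engine, data, 'max_iter', 10);
%   plot(LLtrace, 'x-')
%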

max_iter = 100;
thresh = 1e-3;
anneal = 0;
anneal_rate = 0.8;
init_temp = 10;
final_temp = 1e-3;
verbose = 1;

for i=1:2:length(varargin)
  switch varargin{i}
   case 'max_iter', max_iter = varargin{i+1};
   case 'thresh', thresh = varargin{i+1};
   case 'verbose', verbose = varargin{i+1};
   case 'anneal', anneal = varargin{i+1};
   case 'anneal_rate', anneal_rate = varargin{i+1};
   case 'init_temp', init_temp = varargin{i+1};
   case 'final_temp', final_temp = varargin{i+1};
   otherwise, error(['unrecognized argument ' varargin{i}])
  end
end

% Take 1 EM step at each temperature value; then, when temp=0, run to convergence.
% When using an entropic prior, Z = 1-T, so
% T=2 => Z=-1 (max entropy)
% T=1 => Z=0  (max likelihood)
% T=0 => Z=1  (min entropy / max structure)
num_iter = 1;
LL = [];
if anneal
  temperature = init_temp;
  while temperature > final_temp
    [engine, loglik, logpost] = EM_step(engine, evidence, temperature);
    if verbose
      fprintf('EM iteration %d, loglik = %8.4f, logpost = %8.4f, temp=%8.4f\n', ...
	      num_iter, loglik, logpost, temperature);
    end
    num_iter = num_iter + 1;
    LL = [LL loglik];
    temperature = temperature * anneal_rate;
  end
  temperature = 0;
  previous_loglik = loglik;
  previous_logpost = logpost;
else
  temperature = 0;
  previous_loglik = -inf;
  previous_logpost = -inf;
end

converged = 0;
while ~converged && (num_iter <= max_iter)
  [engine, loglik, logpost] = EM_step(engine, evidence, temperature);
  if verbose
    %fprintf('EM iteration %d, loglik = %8.4f, logpost = %8.4f\n', ...
    %        num_iter, loglik, logpost);
    fprintf('EM iteration %d, loglik = %8.4f\n', num_iter, loglik);
  end
  num_iter = num_iter + 1;
  [converged, decreased] = em_converged(loglik, previous_loglik, thresh);
  %[converged, decreased] = em_converged(logpost, previous_logpost, thresh);
  previous_loglik = loglik;
  previous_logpost = logpost;
  LL = [LL loglik];
end

bnet = bnet_from_engine(engine);

%%%%%%%%%

function [engine, loglik, logpost] = EM_step(engine, cases, temp)

bnet = bnet_from_engine(engine); % engine contains the old params that are used for the E step
ss = length(bnet.intra); % num nodes per slice
CPDs = bnet.CPD; % these are the new params that get maximized
num_CPDs = length(CPDs);

% log P(theta|D) = (log P(D|theta) + log P(theta)) - log P(D)
% where log P(D|theta) = sum_cases log P(case|theta)
% and log P(theta) = sum_CPDs log P(CPD) - only count once even if tied!
% logpost = log P(theta,D) (un-normalized).
% This should be negative, and increase at every step.

adjustable = zeros(1, num_CPDs);
logprior = zeros(1, num_CPDs);
for e=1:num_CPDs
  adjustable(e) = adjustable_CPD(CPDs{e});
end
adj = find(adjustable);

for e=adj(:)'
  logprior(e) = log_prior(CPDs{e});
  CPDs{e} = reset_ess(CPDs{e});
end

loglik = 0;
for l=1:length(cases)
  evidence = cases{l};
  if ~iscell(evidence)
    error('training data must be a cell array of cell arrays')
  end
  [engine, ll] = enter_evidence(engine, evidence);
  assert(~isnan(ll))
  loglik = loglik + ll;
  T = size(evidence, 2);

  % We unroll ns etc. because in update_ess, we refer to nodes by their unrolled number,
  % so that they extract evidence from the right place.
  % (The CPD should really store its own version of ns and cnodes...)
  ns = repmat(bnet.node_sizes_slice(:), [1 T]);
  cnodes = unroll_set(bnet.cnodes_slice, ss, T);

  %hidden_bitv = repmat(bnet.hidden_bitv(1:ss), [1 T]);
  hidden_bitv = zeros(ss, T);
  hidden_bitv(isemptycell(evidence)) = 1;
  % hidden_bitv(i) = 1 means node i is hidden.
  % We pass this in, rather than using isemptycell(evidence(dom)), because
  % isemptycell is very slow.
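
  % Accumulate expected sufficient statistics in two phases, mirroring the
  % 2-slice (2TBN) parameter tying: equiv_class(:,1) indexes the CPDs for
  % slice 1, and equiv_class(:,2) the transition CPDs shared by slices 2..T,
  % so every family marginal from slices 2..T is pooled into the same CPD.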
  t = 1;
  for i=1:ss
    e = bnet.equiv_class(i,1);
    if adjustable(e)
      fmarg = marginal_family(engine, i, t);
      CPDs{e} = update_ess(CPDs{e}, fmarg, evidence, ns(:), cnodes(:), hidden_bitv(:));
    end
  end

  for i=1:ss
    e = bnet.equiv_class(i,2);
    if adjustable(e)
      for t=2:T
        fmarg = marginal_family(engine, i, t);
        CPDs{e} = update_ess(CPDs{e}, fmarg, evidence, ns(:), cnodes(:), hidden_bitv(:));
      end
    end
  end
end

logpost = loglik + sum(logprior(:));

% M step: re-estimate each adjustable CPD from its expected sufficient statistics.
for e=adj(:)'
  CPDs{e} = maximize_params(CPDs{e}, temp);
end

engine = update_engine(engine, CPDs);
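
%%%%%%%%%

% For reference, a minimal sketch of the relative-change stopping rule quoted
% in the header comment. The main loop above uses BNT's own em_converged;
% this hypothetical local function only illustrates the criterion.

function converged = convergence_sketch(loglik, previous_loglik, thresh)

delta = abs(loglik - previous_loglik);
avg = (abs(loglik) + abs(previous_loglik) + eps)/2; % eps guards against 0/0
converged = (delta / avg) < thresh;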