diff toolboxes/FullBNT-1.0.7/bnt/learning/learn_params_dbn_em.m @ 0:e9a9cd732c1e tip
first hg version after svn
| author | wolffd |
| --- | --- |
| date | Tue, 10 Feb 2015 15:05:51 +0000 |
| parents | |
| children | |
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/toolboxes/FullBNT-1.0.7/bnt/learning/learn_params_dbn_em.m	Tue Feb 10 15:05:51 2015 +0000
@@ -0,0 +1,179 @@

function [bnet, LL, engine] = learn_params_dbn_em(engine, evidence, varargin)
% LEARN_PARAMS_DBN_EM Set the parameters in a DBN to their ML/MAP values using batch EM.
% [bnet, LLtrace, engine] = learn_params_dbn_em(engine, data, ...)
%
% data{l}{i,t} = value of node i in slice t of time-series l, or [] if hidden.
% Suppose you have L time series, each of length T, in an O*T*L array D,
% where O is the number of observed scalar nodes and N is the total number of nodes per slice.
% Then you can create data as follows, where onodes is the index of the observable nodes:
%   data = cell(1,L);
%   for l=1:L
%     data{l} = cell(N, T);
%     data{l}(onodes,:) = num2cell(D(:,:,l));
%   end
% Of course, it is possible for different sets of nodes to be observed in
% each slice/sequence, and for each sequence to have a different length.
%
% LLtrace is the learning curve: the vector of log-likelihood scores at each iteration.
%
% Optional arguments [default]
%
% max_iter - maximum number of EM iterations [100]
% thresh - threshold for stopping EM [1e-3]
%          We stop when |f(t) - f(t-1)| / avg < threshold,
%          where avg = (|f(t)| + |f(t-1)|)/2 and f is the log-likelihood.
% verbose - display loglik at each iteration [1]
% anneal - 1 means do deterministic annealing (only for entropic priors) [0]
% anneal_rate - geometric cooling rate [0.8]
% init_temp - initial annealing temperature [10]
% final_temp - final annealing temperature [1e-3]

max_iter = 100;
thresh = 1e-3;
anneal = 0;
anneal_rate = 0.8;
init_temp = 10;
final_temp = 1e-3;
verbose = 1;

for i=1:2:length(varargin)
  switch varargin{i}
    case 'max_iter', max_iter = varargin{i+1};
    case 'thresh', thresh = varargin{i+1};
    case 'anneal', anneal = varargin{i+1};
    case 'anneal_rate', anneal_rate = varargin{i+1};
    case 'init_temp', init_temp = varargin{i+1};
    case 'final_temp', final_temp = varargin{i+1};
    otherwise, error(['unrecognized argument ' varargin{i}])
  end
end

% Take one EM step at each temperature value; when temp reaches 0, run to convergence.
% When using an entropic prior, Z = 1-T, so
%   T=2 => Z=-1 (max entropy)
%   T=1 => Z=0  (max likelihood)
%   T=0 => Z=1  (min entropy / max structure)
num_iter = 1;
LL = [];
if anneal
  temperature = init_temp;
  while temperature > final_temp
    [engine, loglik, logpost] = EM_step(engine, evidence, temperature);
    if verbose
      fprintf('EM iteration %d, loglik = %8.4f, logpost = %8.4f, temp=%8.4f\n', ...
              num_iter, loglik, logpost, temperature);
    end
    num_iter = num_iter + 1;
    LL = [LL loglik];
    temperature = temperature * anneal_rate;
  end
  temperature = 0;
  previous_loglik = loglik;
  previous_logpost = logpost;
else
  temperature = 0;
  previous_loglik = -inf;
  previous_logpost = -inf;
end

converged = 0;
while ~converged && (num_iter <= max_iter)
  [engine, loglik, logpost] = EM_step(engine, evidence, temperature);
  if verbose
    %fprintf('EM iteration %d, loglik = %8.4f, logpost = %8.4f\n', ...
    %        num_iter, loglik, logpost);
    fprintf('EM iteration %d, loglik = %8.4f\n', num_iter, loglik);
  end
  num_iter = num_iter + 1;
  [converged, decreased] = em_converged(loglik, previous_loglik, thresh);
  %[converged, decreased] = em_converged(logpost, previous_logpost, thresh);
  previous_loglik = loglik;
  previous_logpost = logpost;
  LL = [LL loglik];
end
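% A worked instance of the stopping rule above (assuming em_converged
% implements the relative-change test described in the header comment):
% with f(t-1) = -100 and f(t) = -98,
%   delta = |f(t) - f(t-1)| = 2,  avg = (|f(t)| + |f(t-1)|)/2 = 99,
% so delta/avg ~= 0.0202 > thresh = 1e-3, and EM keeps iterating.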
bnet = bnet_from_engine(engine);

%%%%%%%%%

function [engine, loglik, logpost] = EM_step(engine, cases, temp)

bnet = bnet_from_engine(engine); % engine contains the old params that are used for the E step
ss = length(bnet.intra);
CPDs = bnet.CPD; % these are the new params that get maximized
num_CPDs = length(CPDs);

% log P(theta|D) = (log P(D|theta) + log P(theta)) - log P(D)
% where log P(D|theta) = sum_cases log P(case|theta)
% and log P(theta) = sum_CPDs log P(CPD) - only count once even if tied!
% logpost = log P(theta,D) (un-normalized)
% This should be negative, and increase at every step.

adjustable = zeros(1, num_CPDs);
logprior = zeros(1, num_CPDs);
for e=1:num_CPDs
  adjustable(e) = adjustable_CPD(CPDs{e});
end
adj = find(adjustable);

for e=adj(:)'
  logprior(e) = log_prior(CPDs{e});
  CPDs{e} = reset_ess(CPDs{e});
end

loglik = 0;
for l=1:length(cases)
  evidence = cases{l};
  if ~iscell(evidence)
    error('training data must be a cell array of cell arrays')
  end
  [engine, ll] = enter_evidence(engine, evidence);
  assert(~isnan(ll))
  loglik = loglik + ll;
  T = size(evidence, 2);

  % We unroll ns etc. because in update_ess, we refer to nodes by their unrolled number,
  % so that they extract evidence from the right place.
  % (The CPD should really store its own version of ns and cnodes...)
  ns = repmat(bnet.node_sizes_slice(:), [1 T]);
  cnodes = unroll_set(bnet.cnodes_slice, ss, T);

  %hidden_bitv = repmat(bnet.hidden_bitv(1:ss), [1 T]);
  hidden_bitv = zeros(ss, T);
  hidden_bitv(isemptycell(evidence)) = 1;
  % hidden_bitv(i) = 1 means node i is hidden.
  % We pass this in, rather than using isemptycell(evidence(dom)), because
  % isemptycell is very slow.

  % Slice 1: accumulate ESS for each adjustable slice-1 CPD.
  t = 1;
  for i=1:ss
    e = bnet.equiv_class(i,1);
    if adjustable(e)
      fmarg = marginal_family(engine, i, t);
      CPDs{e} = update_ess(CPDs{e}, fmarg, evidence, ns(:), cnodes(:), hidden_bitv(:));
    end
  end

  % Slices 2..T: accumulate ESS for each adjustable transition CPD.
  for i=1:ss
    e = bnet.equiv_class(i,2);
    if adjustable(e)
      for t=2:T
        fmarg = marginal_family(engine, i, t);
        CPDs{e} = update_ess(CPDs{e}, fmarg, evidence, ns(:), cnodes(:), hidden_bitv(:));
      end
    end
  end
end

logpost = loglik + sum(logprior(:));

for e=adj(:)'
  CPDs{e} = maximize_params(CPDs{e}, temp);
end

engine = update_engine(engine, CPDs);
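For readers new to BNT, the sketch below shows one way to drive this function end to end, following the usual BNT recipe for an HMM expressed as a two-node-per-slice DBN. The structure, node sizes, random training data, and the jtree smoother engine are illustrative assumptions, not part of this changeset.

% Minimal usage sketch (illustrative; not part of this changeset).
% Hidden node 1 and observed node 2 per slice, i.e. an HMM as a DBN.
intra = zeros(2); intra(1,2) = 1;   % H_t -> O_t (within-slice arc)
inter = zeros(2); inter(1,1) = 1;   % H_t -> H_{t+1} (between-slice arc)
Q = 2; O = 3;                       % num hidden states, num observation symbols
bnet = mk_dbn(intra, inter, [Q O], 'discrete', 1:2, 'observed', 2);
% Default equivalence classes: CPD 1 = P(H_1), CPD 2 = P(O|H) (tied across
% slices), CPD 3 = P(H_t|H_{t-1}); initialize all three randomly.
for e=1:3
  bnet.CPD{e} = tabular_CPD(bnet, e);
end

engine = smoother_engine(jtree_2TBN_inf_engine(bnet));

% Fake training data: L sequences of length T with only node 2 observed.
T = 10; L = 5;
data = cell(1, L);
for l=1:L
  data{l} = cell(2, T);                        % node 1 stays hidden ([])
  data{l}(2,:) = num2cell(ceil(O*rand(1,T)))); % random observations in 1..O
end

[bnet2, LLtrace] = learn_params_dbn_em(engine, data, 'max_iter', 20);
plot(LLtrace); xlabel('iteration'); ylabel('log-likelihood');

Since the jtree engine performs exact inference, LLtrace should be non-decreasing; a drop usually points at a bug in a CPD's update_ess/maximize_params pair.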