toolboxes/FullBNT-1.0.7/bnt/learning/learn_params_dbn_em.m
function [bnet, LL, engine] = learn_params_dbn_em(engine, evidence, varargin)
% LEARN_PARAMS_DBN_EM Set the parameters in a DBN to their ML/MAP values using batch EM.
% [bnet, LLtrace, engine] = learn_params_dbn_em(engine, data, ...)
%
% data{l}{i,t} = value of node i in slice t of time series l, or [] if hidden.
% Suppose you have L time series, each of length T, stored in an O*T*L array D,
% where O is the number of observed scalar nodes and N is the total number of nodes per slice.
% Then you can create data as follows, where onodes is the index of the observable nodes:
%   data = cell(1,L);
%   for l=1:L
%     data{l} = cell(N, T);
%     data{l}(onodes,:) = num2cell(D(:,:,l));
%   end
% Of course it is possible for different sets of nodes to be observed in
% each slice/sequence, and for each sequence to be a different length.
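% For example (a hypothetical sketch, not from the original help text), a second,
% shorter sequence with its own observed set onodes2 and length T2 could be added as
%   data{2} = cell(N, T2);
%   data{2}(onodes2,:) = num2cell(D2);
% where D2 is a length(onodes2)*T2 array of observations for that sequence.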
%
% LLtrace is the learning curve: the vector of log-likelihood scores at each iteration.
%
% Optional arguments [default]
%
% max_iter - maximum number of EM iterations [100]
% thresh - threshold for stopping EM [1e-3]
%    We stop when |f(t) - f(t-1)| / avg < thresh,
%    where avg = (|f(t)| + |f(t-1)|)/2 and f is the log-likelihood.
% verbose - display loglik at each iteration [1]
% anneal - 1 means do deterministic annealing (only for entropic priors) [0]
% anneal_rate - geometric cooling rate [0.8]
% init_temp - initial annealing temperature [10]
% final_temp - final annealing temperature [1e-3]
%
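% Example (a minimal sketch, not from the original help text; it assumes bnet is a
% DBN built with mk_dbn, and uses BNT's jtree_dbn_inf_engine, though any DBN
% inference engine should work):
%   engine = jtree_dbn_inf_engine(bnet);
%   [bnet2, LLtrace] = learn_params_dbn_em(engine, data, 'max_iter', 10);
%   plot(LLtrace)   % the learning curve should be non-decreasing
%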

max_iter = 100;
thresh = 1e-3;
anneal = 0;
anneal_rate = 0.8;
init_temp = 10;
final_temp = 1e-3;
verbose = 1;

for i=1:2:length(varargin)
  switch varargin{i}
    case 'max_iter', max_iter = varargin{i+1};
    case 'thresh', thresh = varargin{i+1};
    case 'verbose', verbose = varargin{i+1};
    case 'anneal', anneal = varargin{i+1};
    case 'anneal_rate', anneal_rate = varargin{i+1};
    case 'init_temp', init_temp = varargin{i+1};
    case 'final_temp', final_temp = varargin{i+1};
    otherwise, error(['unrecognized argument ' varargin{i}])
  end
end

% Take 1 EM step at each temperature value, then, when temp=0, run to convergence.
% When using an entropic prior, Z = 1-T, so
%   T=2 => Z=-1 (max entropy)
%   T=1 => Z=0  (max likelihood)
%   T=0 => Z=1  (min entropy / max structure)
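% With the defaults (init_temp=10, anneal_rate=0.8, final_temp=1e-3), the k-th
% annealed step runs at temperature 10*0.8^k, and 10*0.8^k > 1e-3 holds for
% k = 0..41, so about 42 annealed EM steps are taken before the temp=0 phase.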
num_iter = 1;
LL = [];
if anneal
  temperature = init_temp;
  while temperature > final_temp
    [engine, loglik, logpost] = EM_step(engine, evidence, temperature);
    if verbose
      fprintf('EM iteration %d, loglik = %8.4f, logpost = %8.4f, temp=%8.4f\n', ...
              num_iter, loglik, logpost, temperature);
    end
    num_iter = num_iter + 1;
    LL = [LL loglik];
    temperature = temperature * anneal_rate;
  end
  temperature = 0;
  previous_loglik = loglik;
  previous_logpost = logpost;
else
  temperature = 0;
  previous_loglik = -inf;
  previous_logpost = -inf;
end

converged = 0;
while ~converged && (num_iter <= max_iter)
  [engine, loglik, logpost] = EM_step(engine, evidence, temperature);
  if verbose
    %fprintf('EM iteration %d, loglik = %8.4f, logpost = %8.4f\n', ...
    %        num_iter, loglik, logpost);
    fprintf('EM iteration %d, loglik = %8.4f\n', num_iter, loglik);
  end
  num_iter = num_iter + 1;
  [converged, decreased] = em_converged(loglik, previous_loglik, thresh);
  %[converged, decreased] = em_converged(logpost, previous_logpost, thresh);
  previous_loglik = loglik;
  previous_logpost = logpost;
  LL = [LL loglik];
end

bnet = bnet_from_engine(engine);
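% To inspect the learned parameters, one can use the usual BNT idiom of breaking
% the CPD object abstraction (a hedged sketch, not part of the original file;
% assumes node 1 has a tabular CPD):
%   s = struct(bnet.CPD{1});   % expose the object's fields
%   disp(s.CPT)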

%%%%%%%%%

function [engine, loglik, logpost] = EM_step(engine, cases, temp)

bnet = bnet_from_engine(engine); % engine contains the old params that are used for the E step
ss = length(bnet.intra); % slice size, i.e., the number of nodes per slice
CPDs = bnet.CPD; % these are the new params that get maximized
num_CPDs = length(CPDs);

% log P(theta|D) = (log P(D|theta) + log P(theta)) - log P(D)
% where log P(D|theta) = sum_cases log P(case|theta)
% and log P(theta) = sum_CPDs log P(CPD) - only count once even if tied!
% logpost = log P(theta,D) (un-normalized),
% which should be negative, and should increase at every step.

adjustable = zeros(1, num_CPDs);
logprior = zeros(1, num_CPDs);
for e=1:num_CPDs
  adjustable(e) = adjustable_CPD(CPDs{e});
end
adj = find(adjustable);

for e=adj(:)'
  logprior(e) = log_prior(CPDs{e});
  CPDs{e} = reset_ess(CPDs{e});
end

loglik = 0;
for l=1:length(cases)
  evidence = cases{l};
  if ~iscell(evidence)
    error('training data must be a cell array of cell arrays')
  end
  [engine, ll] = enter_evidence(engine, evidence);
  assert(~isnan(ll))
  loglik = loglik + ll;
  T = size(evidence, 2);

  % We unroll ns etc. because in update_ess we refer to nodes by their unrolled number,
  % so that they extract evidence from the right place.
  % (The CPD should really store its own version of ns and cnodes...)
  ns = repmat(bnet.node_sizes_slice(:), [1 T]);
  cnodes = unroll_set(bnet.cnodes_slice, ss, T);

  %hidden_bitv = repmat(bnet.hidden_bitv(1:ss), [1 T]);
  hidden_bitv = zeros(ss, T);
  hidden_bitv(isemptycell(evidence)) = 1;
  % hidden_bitv(i) = 1 means node i is hidden.
  % We pass this in, rather than using isemptycell(evidence(dom)), because
  % isemptycell is very slow.

  % Slice 1 CPDs collect their statistics from the family marginals at t=1.
  t = 1;
  for i=1:ss
    e = bnet.equiv_class(i,1);
    if adjustable(e)
      fmarg = marginal_family(engine, i, t);
      CPDs{e} = update_ess(CPDs{e}, fmarg, evidence, ns(:), cnodes(:), hidden_bitv(:));
    end
  end

  % Slice 2 (transition) CPDs accumulate statistics over t=2..T.
  for i=1:ss
    e = bnet.equiv_class(i,2);
    if adjustable(e)
      for t=2:T
        fmarg = marginal_family(engine, i, t);
        CPDs{e} = update_ess(CPDs{e}, fmarg, evidence, ns(:), cnodes(:), hidden_bitv(:));
      end
    end
  end
end

logpost = loglik + sum(logprior(:));

% M step: maximize each adjustable CPD given its expected sufficient statistics.
for e=adj(:)'
  CPDs{e} = maximize_params(CPDs{e}, temp);
end

engine = update_engine(engine, CPDs);