annotate _FullBNT/BNT/learning/learn_params_dbn_em.m @ 9:4ea6619cb3f5 tip

removed log files
author matthiasm
date Fri, 11 Apr 2014 15:55:11 +0100
parents b5b38998ef3b
children
rev   line source
function [bnet, LL, engine] = learn_params_dbn_em(engine, evidence, varargin)
% LEARN_PARAMS_DBN_EM Set the parameters in a DBN to their ML/MAP values using batch EM.
% [bnet, LLtrace, engine] = learn_params_dbn_em(engine, data, ...)
%
% data{l}{i,t} = value of node i in slice t of time-series l, or [] if hidden.
% Suppose you have L time series, each of length T, in an O*T*L array D,
% where O is the num of observed scalar nodes, and N is the total num nodes per slice.
% Then you can create data as follows, where onodes is the index of the observable nodes:
%   data = cell(1,L);
%   for l=1:L
%     data{l} = cell(N, T);
%     data{l}(onodes,:) = num2cell(D(:,:,l));
%   end
% Of course it is possible for different sets of nodes to be observed in
% each slice/sequence, and for each sequence to be a different length.
%
% LLtrace is the learning curve: the vector of log-likelihood scores at each iteration
% (each score is the one returned by that iteration's E step).
%
% Optional arguments (name/value pairs) [default]
%
% max_iter    - specifies the maximum number of iterations [100]
% thresh      - specifies the threshold for stopping EM [1e-3]
%               We stop when |f(t) - f(t-1)| / avg < threshold,
%               where avg = (|f(t)| + |f(t-1)|)/2 and f is log lik.
% verbose     - display loglik at each iteration [1]
% anneal      - 1 means do deterministic annealing (only for entropic priors) [0]
% anneal_rate - geometric cooling rate [0.8]
% init_temp   - initial annealing temperature [10]
% final_temp  - final annealing temperature [1e-3]
%
% The annealing phase is not bounded by max_iter, but its iterations are
% counted toward max_iter when limiting the subsequent temperature-0 phase.

max_iter = 100;
thresh = 1e-3;
anneal = 0;
anneal_rate = 0.8;
init_temp = 10;
final_temp = 1e-3;
verbose = 1;

% Optional arguments must come in name/value pairs; fail clearly otherwise
% instead of with an out-of-range cell index.
if mod(length(varargin), 2) ~= 0
  error('optional arguments must be given as name/value pairs')
end
for i=1:2:length(varargin)
  switch varargin{i}
    case 'max_iter',    max_iter = varargin{i+1};
    case 'thresh',      thresh = varargin{i+1};
    case 'verbose',     verbose = varargin{i+1};
    case 'anneal',      anneal = varargin{i+1};
    case 'anneal_rate', anneal_rate = varargin{i+1};
    case 'init_temp',   init_temp = varargin{i+1};
    case 'final_temp',  final_temp = varargin{i+1};
    otherwise, error(['unrecognized argument ' varargin{i}])
  end
end

% take 1 EM step at each temperature value, then when temp=0, run to convergence
% When using an entropic prior, Z = 1-T, so
% T=2 => Z=-1 (max entropy)
% T=1 => Z=0 (max likelihood)
% T=0 => Z=1 (min entropy / max structure)
num_iter = 1;
LL = [];
if anneal
  temperature = init_temp;
  % If init_temp <= final_temp the loop below does not execute; these
  % defaults keep previous_loglik/previous_logpost defined in that case.
  loglik = -inf;
  logpost = -inf;
  while temperature > final_temp
    [engine, loglik, logpost] = EM_step(engine, evidence, temperature);
    if verbose
      fprintf('EM iteration %d, loglik = %8.4f, logpost = %8.4f, temp=%8.4f\n', ...
              num_iter, loglik, logpost, temperature);
    end
    num_iter = num_iter + 1;
    LL = [LL loglik];
    temperature = temperature * anneal_rate;  % geometric cooling
  end
  temperature = 0;
  previous_loglik = loglik;
  previous_logpost = logpost;
else
  temperature = 0;
  previous_loglik = -inf;
  previous_logpost = -inf;
end

converged = 0;
while ~converged && (num_iter <= max_iter)
  [engine, loglik, logpost] = EM_step(engine, evidence, temperature);
  if verbose
    fprintf('EM iteration %d, loglik = %8.4f\n', num_iter, loglik);
  end
  num_iter = num_iter + 1;
  % Convergence is judged on the log-likelihood. To judge it on the
  % un-normalized log posterior instead, use:
  %   converged = em_converged(logpost, previous_logpost, thresh);
  converged = em_converged(loglik, previous_loglik, thresh);
  previous_loglik = loglik;
  previous_logpost = logpost;
  LL = [LL loglik];
end

bnet = bnet_from_engine(engine);
matthiasm@8 97
matthiasm@8 98 %%%%%%%%%
matthiasm@8 99
function [engine, loglik, logpost] = EM_step(engine, cases, temp)
% EM_STEP Run one batch EM iteration over all sequences in CASES.
%
% E step: the parameters currently stored in ENGINE are used for inference
% on each sequence, and expected sufficient statistics are accumulated
% into copies of the adjustable CPDs (whose ESS are first reset).
% M step: maximize_params(CPD, TEMP) is applied to every adjustable CPD,
% and the new CPDs are installed in ENGINE via update_engine.
%
% Outputs:
%   loglik  - sum over sequences of the log-likelihood returned by
%             enter_evidence (i.e. computed under the pre-M-step parameters)
%   logpost - loglik + sum of log_prior over the adjustable CPDs, i.e. the
%             un-normalized log P(theta, D). Each CPD's prior is counted once,
%             even when it is tied to several nodes.

old_bnet = bnet_from_engine(engine);   % old params, used for the E step
slice_size = length(old_bnet.intra);
new_CPDs = old_bnet.CPD;               % params that get re-estimated
nCPD = length(new_CPDs);

% Flag the learnable CPDs.
is_adj = zeros(1, nCPD);
for k = 1:nCPD
  is_adj(k) = adjustable_CPD(new_CPDs{k});
end
adj_idx = find(is_adj);

% Record each learnable CPD's log prior and clear its accumulated ESS.
log_priors = zeros(1, nCPD);
for k = adj_idx(:)'
  log_priors(k) = log_prior(new_CPDs{k});
  new_CPDs{k} = reset_ess(new_CPDs{k});
end

loglik = 0;
for seq = 1:length(cases)
  ev = cases{seq};
  if ~iscell(ev)
    error('training data must be a cell array of cell arrays')
  end
  [engine, seq_ll] = enter_evidence(engine, ev);
  assert(~isnan(seq_ll))
  loglik = loglik + seq_ll;
  nslices = size(ev, 2);

  % update_ess refers to nodes by their unrolled index (so that it reads
  % evidence from the right place), hence node sizes and the continuous-node
  % set are unrolled over all nslices slices of this sequence.
  ns_all = repmat(old_bnet.node_sizes_slice(:), [1 nslices]);
  ns_col = ns_all(:);
  cn_all = unroll_set(old_bnet.cnodes_slice, slice_size, nslices);
  cn_col = cn_all(:);

  % 1 marks an empty (hidden) evidence cell. Precomputed once per sequence
  % because isemptycell is slow.
  hid = zeros(slice_size, nslices);
  hid(isemptycell(ev)) = 1;
  hid_col = hid(:);

  % First slice: CPDs indexed by column 1 of equiv_class.
  for node = 1:slice_size
    k = old_bnet.equiv_class(node, 1);
    if is_adj(k)
      fam = marginal_family(engine, node, 1);
      new_CPDs{k} = update_ess(new_CPDs{k}, fam, ev, ns_col, cn_col, hid_col);
    end
  end

  % Slices 2..nslices: CPDs indexed by column 2 of equiv_class.
  for node = 1:slice_size
    k = old_bnet.equiv_class(node, 2);
    if ~is_adj(k)
      continue;
    end
    for t = 2:nslices
      fam = marginal_family(engine, node, t);
      new_CPDs{k} = update_ess(new_CPDs{k}, fam, ev, ns_col, cn_col, hid_col);
    end
  end
end

logpost = loglik + sum(log_priors(:));

% M step.
for k = adj_idx(:)'
  new_CPDs{k} = maximize_params(new_CPDs{k}, temp);
end

engine = update_engine(engine, new_CPDs);
matthiasm@8 176
matthiasm@8 177
matthiasm@8 178
matthiasm@8 179