function [bnet, LL, engine] = learn_params_dbn_em(engine, evidence, varargin)
% LEARN_PARAMS_DBN_EM Set the parameters in a DBN to their ML/MAP values using batch EM.
% [bnet, LLtrace, engine] = learn_params_dbn_em(engine, data, ...)
%
% data{l}{i,t} = value of node i in slice t of time-series l, or [] if hidden.
% Suppose you have L time series, each of length T, in an O*T*L array D,
% where O is the number of observed scalar nodes and N is the total number of nodes per slice.
% Then you can create data as follows, where onodes contains the indices of the observed nodes:
%    data = cell(1,L);
%    for l=1:L
%      data{l} = cell(N, T);
%      data{l}(onodes,:) = num2cell(D(:,:,l));
%    end
% Of course, it is possible for different sets of nodes to be observed in
% each slice/sequence, and for each sequence to have a different length.
%
% LLtrace is the learning curve: the vector of log-likelihood scores at each iteration.
%
% Optional arguments [default]
%
% max_iter    - maximum number of EM iterations [100]
% thresh      - threshold for stopping EM [1e-3]
%               We stop when |f(t) - f(t-1)| / avg < thresh,
%               where avg = (|f(t)| + |f(t-1)|)/2 and f is the log-likelihood.
% verbose     - display the loglik at each iteration [1]
% anneal      - 1 means do deterministic annealing (only for entropic priors) [0]
% anneal_rate - geometric cooling rate [0.8]
% init_temp   - initial annealing temperature [10]
% final_temp  - final annealing temperature [1e-3]
%
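% Example (a minimal sketch; it assumes bnet is an existing DBN and uses the
% junction-tree DBN engine, though any DBN inference engine should work):
%   engine = jtree_dbn_inf_engine(bnet);
%   [bnet2, LLtrace] = learn_params_dbn_em(engine, data, 'max_iter', 10);
%   plot(LLtrace)  % the learning curve should be non-decreasing
%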

max_iter = 100;
thresh = 1e-3;
anneal = 0;
anneal_rate = 0.8;
init_temp = 10;
final_temp = 1e-3;
verbose = 1;

for i=1:2:length(varargin)
  switch varargin{i}
   case 'max_iter', max_iter = varargin{i+1};
   case 'thresh', thresh = varargin{i+1};
   case 'verbose', verbose = varargin{i+1};
   case 'anneal', anneal = varargin{i+1};
   case 'anneal_rate', anneal_rate = varargin{i+1};
   case 'init_temp', init_temp = varargin{i+1};
   case 'final_temp', final_temp = varargin{i+1};
   otherwise, error(['unrecognized argument ' varargin{i}])
  end
end

% Take 1 EM step at each temperature value, then when temp=0, run to convergence.
% When using an entropic prior, Z = 1-T, so
% T=2 => Z=-1 (max entropy)
% T=1 => Z=0  (max likelihood)
% T=0 => Z=1  (min entropy / max structure)
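% With the defaults above (init_temp=10, anneal_rate=0.8, final_temp=1e-3),
% the schedule is 10, 8, 6.4, ..., giving about 42 annealed EM steps before
% the final run to convergence at temperature 0.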
num_iter = 1;
LL = [];
if anneal
  temperature = init_temp;
  while temperature > final_temp
    [engine, loglik, logpost] = EM_step(engine, evidence, temperature);
    if verbose
      fprintf('EM iteration %d, loglik = %8.4f, logpost = %8.4f, temp=%8.4f\n', ...
              num_iter, loglik, logpost, temperature);
    end
    num_iter = num_iter + 1;
    LL = [LL loglik];
    temperature = temperature * anneal_rate;
  end
  temperature = 0;
  previous_loglik = loglik;
  previous_logpost = logpost;
else
  temperature = 0;
  previous_loglik = -inf;
  previous_logpost = -inf;
end

converged = 0;
while ~converged && (num_iter <= max_iter)
  [engine, loglik, logpost] = EM_step(engine, evidence, temperature);
  if verbose
    %fprintf('EM iteration %d, loglik = %8.4f, logpost = %8.4f\n', ...
    %        num_iter, loglik, logpost);
    fprintf('EM iteration %d, loglik = %8.4f\n', num_iter, loglik);
  end
  num_iter = num_iter + 1;
  [converged, decreased] = em_converged(loglik, previous_loglik, thresh);
  %[converged, decreased] = em_converged(logpost, previous_logpost, thresh);
  previous_loglik = loglik;
  previous_logpost = logpost;
  LL = [LL loglik];
end

bnet = bnet_from_engine(engine);

%%%%%%%%%

function [engine, loglik, logpost] = EM_step(engine, cases, temp)

bnet = bnet_from_engine(engine); % the engine contains the old params that are used for the E step
ss = length(bnet.intra); % slice size, i.e., the number of nodes per slice
CPDs = bnet.CPD; % these are the new params that get maximized
num_CPDs = length(CPDs);

% log P(theta|D) = (log P(D|theta) + log P(theta)) - log P(D)
% where log P(D|theta) = sum_cases log P(case|theta)
% and log P(theta) = sum_CPDs log P(CPD) - only count once even if tied!
% logpost = log P(theta,D) (un-normalized)
% This should be negative, and increase at every step.
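% For illustration (made-up numbers): if sum_cases log P(case|theta) = -50.2
% and sum_CPDs log P(CPD) = -3.1, counting each tied CPD once, then
% logpost = -53.3, and successive EM steps should not decrease it.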

adjustable = zeros(1, num_CPDs);
logprior = zeros(1, num_CPDs);
for e=1:num_CPDs
  adjustable(e) = adjustable_CPD(CPDs{e});
end
adj = find(adjustable);

for e=adj(:)'
  logprior(e) = log_prior(CPDs{e});
  CPDs{e} = reset_ess(CPDs{e});
end

loglik = 0;
for l=1:length(cases)
  evidence = cases{l};
  if ~iscell(evidence)
    error('training data must be a cell array of cell arrays')
  end
  [engine, ll] = enter_evidence(engine, evidence);
  assert(~isnan(ll))
  loglik = loglik + ll;
  T = size(evidence, 2);

  % We unroll ns etc. because in update_ess, we refer to nodes by their unrolled number,
  % so that they extract evidence from the right place.
  % (The CPD should really store its own version of ns and cnodes...)
  ns = repmat(bnet.node_sizes_slice(:), [1 T]);
  cnodes = unroll_set(bnet.cnodes_slice, ss, T);

  %hidden_bitv = repmat(bnet.hidden_bitv(1:ss), [1 T]);
  hidden_bitv = zeros(ss, T);
  hidden_bitv(isemptycell(evidence)) = 1;
  % hidden_bitv(i) = 1 means node i is hidden.
  % We pass this in, rather than using isemptycell(evidence(dom)), because
  % isemptycell is very slow.
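  % e.g. if node 3 is unobserved in slice 2 (illustrative indices), then
  % evidence{3,2} = [] and hence hidden_bitv(3,2) = 1.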

  t = 1;
  for i=1:ss
    e = bnet.equiv_class(i,1);
    if adjustable(e)
      fmarg = marginal_family(engine, i, t);
      CPDs{e} = update_ess(CPDs{e}, fmarg, evidence, ns(:), cnodes(:), hidden_bitv(:));
    end
  end

  for i=1:ss
    e = bnet.equiv_class(i,2);
    if adjustable(e)
      for t=2:T
        fmarg = marginal_family(engine, i, t);
        CPDs{e} = update_ess(CPDs{e}, fmarg, evidence, ns(:), cnodes(:), hidden_bitv(:));
      end
    end
  end
end

logpost = loglik + sum(logprior(:));

for e=adj(:)'
  CPDs{e} = maximize_params(CPDs{e}, temp);
end

engine = update_engine(engine, CPDs);