annotate _FullBNT/BNT/learning/learn_params_em2.m @ 9:4ea6619cb3f5 tip

removed log files
author matthiasm
date Fri, 11 Apr 2014 15:55:11 +0100
parents b5b38998ef3b
children
rev   line source
function [bnet, LL, engine] = learn_params_em(engine, evidence, max_iter, thresh)
% LEARN_PARAMS_EM Set the parameters of each adjustable node to their ML/MAP values using batch EM.
% [bnet, LL, engine] = learn_params_em(engine, evidence, max_iter, thresh)
%
% NOTE: this code lives in learn_params_em2.m, so MATLAB resolves calls by the file
% name (learn_params_em2), not by the name declared on the function line.
% Unlike the plain learn_params_em, the M step here (see the EM_step subfunction in
% this file) post-processes CPD 1 after maximization: its mean is clamped to be
% non-negative and it is rebuilt with a diagonal covariance that is clamped.
%
% evidence{i,l} is the value of node i in case l, or [] if hidden.
% Suppose you have L training cases in an O*L array, D, where O is the num observed
% scalar nodes, and N is the total num nodes.
% Then you can create 'evidence' as follows, where onodes is the index of the observable nodes:
% evidence = cell(N, L);
% evidence(onodes,:) = num2cell(D);
% Of course it is possible for different sets of nodes to be observed in each case.
%
% Inputs:
%   engine   - inference engine whose bnet holds the initial parameters.
%   evidence - N x L cell array of training cases (see above).
%   max_iter - maximum number of EM iterations. Default: 10 (also used when [] is passed).
%   thresh   - threshold for stopping EM. Default: 1e-3 (also used when [] is passed).
%              We stop when |f(t) - f(t-1)| / avg < thresh,
%              where avg = (|f(t)| + |f(t-1)|)/2 and f is log lik (test done by em_converged).
%
% Outputs:
%   bnet   - the bnet taken from the final engine, i.e. with the learned parameters.
%            To see the learned parameters for node i, use the construct
%            s = struct(bnet.CPD{i}); % violate object privacy
%   LL     - learning curve: row vector of the log-likelihood after each iteration.
%   engine - the engine updated with the learned parameters.

% Treat missing OR empty arguments as "use the default", so callers can skip
% max_iter while still passing thresh, e.g. learn_params_em2(engine, ev, [], 1e-4).
if nargin < 3 || isempty(max_iter), max_iter = 10; end
if nargin < 4 || isempty(thresh), thresh = 1e-3; end

verbose = 0;

% -inf guarantees the first iteration is never reported as converged.
previous_loglik = -inf;
converged = 0;
num_iter = 1;
LL = [];

while ~converged && (num_iter <= max_iter)
  [engine, loglik] = EM_step(engine, evidence);
  if verbose, fprintf('EM iteration %d, ll = %8.4f\n', num_iter, loglik); end
  num_iter = num_iter + 1;
  converged = em_converged(loglik, previous_loglik, thresh);
  previous_loglik = loglik;
  LL = [LL loglik]; %#ok<AGROW> number of iterations is not known in advance
end
if verbose, fprintf('\n'); end

bnet = bnet_from_engine(engine);

%%%%%%%%%

function [engine, loglik] = EM_step(engine, cases)
% EM_STEP Run one batch EM iteration over all cases and install the new parameters.
% [engine, loglik] = EM_step(engine, cases)
%
% cases{i,l} is the value of node i in case l ([] if hidden).
% E step: for each case, the evidence is entered into the engine (which still holds
% the old parameters). For every node whose equivalence class is adjustable, the
% node's family marginal is then accumulated into that CPD's expected sufficient
% statistics.
% M step: every adjustable CPD is maximized. Then, if CPD 1 is an adjustable
% gaussian_CPD, it is rebuilt with its mean clamped to be non-negative and with a
% 'diag' covariance that is clamped. Presumably the clamp keeps that covariance fixed
% in later M steps; TODO confirm against gaussian_CPD/maximize_params.
% loglik is the sum over cases of the log-likelihood returned by enter_evidence.
% The returned engine has been updated with the new CPDs.

bnet = bnet_from_engine(engine); % engine contains the old params that are used for the E step
CPDs = bnet.CPD;                 % these are the new params that get maximized
num_CPDs = length(CPDs);
adjustable = zeros(1,num_CPDs);
for e=1:num_CPDs
  adjustable(e) = adjustable_CPD(CPDs{e});
end
adj = find(adjustable);
n = length(bnet.dag);

% Start the sufficient statistics of every adjustable CPD from zero for this iteration.
for e=adj(:)'
  CPDs{e} = reset_ess(CPDs{e});
end

% E step
loglik = 0;
ncases = size(cases, 2);
for l=1:ncases
  evidence = cases(:,l);
  [engine, ll] = enter_evidence(engine, evidence);
  loglik = loglik + ll;
  % Mark the nodes that are unobserved in this case.
  hidden_bitv = zeros(1,n);
  hidden_bitv(isemptycell(evidence))=1;
  for i=1:n
    e = bnet.equiv_class(i);
    if adjustable(e)
      fmarg = marginal_family(engine, i);
      CPDs{e} = update_ess(CPDs{e}, fmarg, evidence, bnet.node_sizes, bnet.cnodes, hidden_bitv);
    end
  end
end

% M step
for e=adj(:)'
  CPDs{e} = maximize_params(CPDs{e});
end

% Model-specific constraint on CPD 1: non-negative mean and a clamped diagonal
% covariance. The guard ensures two things. A non-adjustable (user-clamped) CPD 1 is
% left untouched rather than replaced by a CPD whose mean is learned again. And a
% non-Gaussian CPD 1 is not rebuilt, because the 'mean'/'cov' fields read here are
% Gaussian-specific.
% NOTE(review): CPDs{1} is equivalence class 1, but gaussian_CPD(bnet,1,...) builds the
% CPD for node 1; this assumes equiv_class(1) == 1. TODO confirm for tied models.
if num_CPDs >= 1 && adjustable(1) && isa(CPDs{1}, 'gaussian_CPD')
  one_mean = get_field(CPDs{1},'mean');
  one_cov = get_field(CPDs{1},'cov');
  CPDs{1} = gaussian_CPD(bnet,1,'mean',max(one_mean,0),'cov',one_cov,'cov_type', 'diag', ...
                         'clamp_cov',1);
end
engine = update_engine(engine, CPDs);
