comparison toolboxes/FullBNT-1.0.7/bnt/learning/learn_params_dbn_em.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:e9a9cd732c1e
function [bnet, LL, engine] = learn_params_dbn_em(engine, evidence, varargin)
% LEARN_PARAMS_DBN_EM Set the parameters in a DBN to their ML/MAP values using batch EM.
% [bnet, LLtrace, engine] = learn_params_dbn_em(engine, data, ...)
%
% data{l}{i,t} = value of node i in slice t of time-series l, or [] if hidden.
% Suppose you have L time series, each of length T, in an O*T*L array D,
% where O is the num of observed scalar nodes, and N is the total num nodes per slice.
% Then you can create data as follows, where onodes is the index of the observable nodes:
%  data = cell(1,L);
%  for l=1:L
%    data{l} = cell(N, T);
%    data{l}(onodes,:) = num2cell(D(:,:,l));
%  end
% Of course it is possible for different sets of nodes to be observed in
% each slice/sequence, and for each sequence to be a different length.
%
% LLtrace is the learning curve: the vector of log-likelihood scores at each iteration.
%
% Optional arguments [default]
%
% max_iter - specifies the maximum number of iterations [100]
% thresh - specifies the threshold for stopping EM [1e-3]
%    We stop when |f(t) - f(t-1)| / avg < threshold,
%    where avg = (|f(t)| + |f(t-1)|)/2 and f is log lik.
% verbose - display loglik at each iteration [1]
% anneal - 1 means do deterministic annealing (only for entropic priors) [0]
% anneal_rate - geometric cooling rate [0.8]
% init_temp - initial annealing temperature [10]
% final_temp - final annealing temperature [1e-3]

max_iter = 100;
thresh = 1e-3;
anneal = 0;
anneal_rate = 0.8;
init_temp = 10;
final_temp = 1e-3;
verbose = 1;

% Parse name/value option pairs.
for i=1:2:length(varargin)
  switch varargin{i}
   case 'max_iter',    max_iter = varargin{i+1};
   case 'thresh',      thresh = varargin{i+1};
   case 'verbose',     verbose = varargin{i+1}; % BUGFIX: documented above but previously not parsed
   case 'anneal',      anneal = varargin{i+1};
   case 'anneal_rate', anneal_rate = varargin{i+1};
   case 'init_temp',   init_temp = varargin{i+1};
   case 'final_temp',  final_temp = varargin{i+1};
   otherwise, error(['unrecognized argument ' varargin{i}]) % BUGFIX: missing space before arg name
  end
end

% take 1 EM step at each temperature value, then when temp=0, run to convergence
% When using an entropic prior, Z = 1-T, so
% T=2 => Z=-1 (max entropy)
% T=1 => Z=0 (max likelihood)
% T=0 => Z=1 (min entropy / max structure)
num_iter = 1;
LL = [];
if anneal
  % BUGFIX: define loglik/logpost up front so that the assignments after the
  % cooling loop do not fail if the loop body never runs (init_temp <= final_temp).
  loglik = -inf;
  logpost = -inf;
  temperature = init_temp;
  while temperature > final_temp
    [engine, loglik, logpost] = EM_step(engine, evidence, temperature);
    if verbose
      fprintf('EM iteration %d, loglik = %8.4f, logpost = %8.4f, temp=%8.4f\n', ...
	      num_iter, loglik, logpost, temperature);
    end
    num_iter = num_iter + 1;
    LL = [LL loglik];
    temperature = temperature * anneal_rate;
  end
  temperature = 0;
  previous_loglik = loglik;
  previous_logpost = logpost;
else
  temperature = 0;
  previous_loglik = -inf;
  previous_logpost = -inf;
end

converged = 0;
while ~converged && (num_iter <= max_iter) % use short-circuit && for scalar logic
  [engine, loglik, logpost] = EM_step(engine, evidence, temperature);
  if verbose
    %fprintf('EM iteration %d, loglik = %8.4f, logpost = %8.4f\n', ...
    %	     num_iter, loglik, logpost);
    fprintf('EM iteration %d, loglik = %8.4f\n', num_iter, loglik);
  end
  num_iter = num_iter + 1;
  % Convergence is judged on the loglik trace; see em_converged for the criterion.
  [converged, decreased] = em_converged(loglik, previous_loglik, thresh);
  %[converged, decreased] = em_converged(logpost, previous_logpost, thresh);
  previous_loglik = loglik;
  previous_logpost = logpost;
  LL = [LL loglik];
end

bnet = bnet_from_engine(engine);
97
98 %%%%%%%%%
99
function [engine, loglik, logpost] = EM_step(engine, cases, temp)
% EM_STEP One EM iteration: E step (accumulate expected sufficient statistics
% over every training sequence) followed by M step (maximize each adjustable CPD).

bnet = bnet_from_engine(engine); % the engine holds the old params used for the E step
ss = length(bnet.intra);         % number of nodes per slice
CPDs = bnet.CPD;                 % the new params that get maximized
num_CPDs = length(CPDs);

% log P(theta|D) = (log P(D|theta) + log P(theta)) - log(P(D))
% where log P(D|theta) = sum_cases log P(case|theta)
% and log P(theta) = sum_CPDs log P(CPD) - only count once even if tied!
% logpost = log P(theta,D) (un-normalized)
% This should be negative, and increase at every step.

% Find which CPDs can be learned at all.
adjustable = zeros(1, num_CPDs);
for c=1:num_CPDs
  adjustable(c) = adjustable_CPD(CPDs{c});
end
adj = find(adjustable);

% Record each adjustable CPD's prior and clear its sufficient statistics.
logprior = zeros(1, num_CPDs);
for c=adj(:)'
  logprior(c) = log_prior(CPDs{c});
  CPDs{c} = reset_ess(CPDs{c});
end

loglik = 0;
for seq=1:length(cases)
  ev = cases{seq};
  if ~iscell(ev)
    error('training data must be a cell array of cell arrays')
  end
  [engine, seq_ll] = enter_evidence(engine, ev);
  assert(~isnan(seq_ll))
  loglik = loglik + seq_ll;
  T = size(ev, 2);

  % We unroll ns etc because in update_ess, we refer to nodes by their unrolled number
  % so that they extract evidence from the right place.
  % (The CPD should really store its own version of ns and cnodes...)
  ns = repmat(bnet.node_sizes_slice(:), [1 T]);
  cnodes = unroll_set(bnet.cnodes_slice, ss, T);

  %hidden_bitv = repmat(bnet.hidden_bitv(1:ss), [1 T]);
  % hidden_bitv(i) = 1 means node i is hidden.
  % We pass this in, rather than using isemptycell(evidence(dom)), because
  % isemptycell is very slow.
  hidden_bitv = zeros(ss, T);
  hidden_bitv(isemptycell(ev)) = 1;

  % Slice 1: nodes use the first column of the equivalence-class map.
  t = 1;
  for i=1:ss
    c = bnet.equiv_class(i,1);
    if adjustable(c)
      fmarg = marginal_family(engine, i, t);
      CPDs{c} = update_ess(CPDs{c}, fmarg, ev, ns(:), cnodes(:), hidden_bitv(:));
    end
  end

  % Slices 2..T: nodes use the second column of the equivalence-class map.
  for i=1:ss
    c = bnet.equiv_class(i,2);
    if adjustable(c)
      for t=2:T
	fmarg = marginal_family(engine, i, t);
	CPDs{c} = update_ess(CPDs{c}, fmarg, ev, ns(:), cnodes(:), hidden_bitv(:));
      end
    end
  end
end

logpost = loglik + sum(logprior(:));

% M step: re-estimate every adjustable CPD from its accumulated statistics.
for c=adj(:)'
  CPDs{c} = maximize_params(CPDs{c}, temp);
end

engine = update_engine(engine, CPDs);
177
178
179