function [bnet, LL, engine] = learn_params_dbn_em(engine, evidence, varargin)
% LEARN_PARAMS_DBN_EM Set the parameters in a DBN to their ML/MAP values using batch EM.
% [bnet, LLtrace, engine] = learn_params_dbn_em(engine, data, ...)
%
% data{l}{i,t} = value of node i in slice t of time-series l, or [] if hidden.
% Suppose you have L time series, each of length T, in an O*T*L array D,
% where O is the num of observed scalar nodes, and N is the total num of nodes per slice.
% Then you can create data as follows, where onodes is the index of the observable nodes:
%    data = cell(1,L);
%    for l=1:L
%      data{l} = cell(N, T);
%      data{l}(onodes,:) = num2cell(D(:,:,l));
%    end
% Of course it is possible for different sets of nodes to be observed in
% each slice/sequence, and for each sequence to be a different length.
%
% LLtrace is the learning curve: the vector of log-likelihood scores at each iteration.
%
% Optional arguments [default]
%
% max_iter - maximum number of EM iterations [100]
% thresh - threshold for stopping EM [1e-3]
%    We stop when |f(t) - f(t-1)| / avg < thresh,
%    where avg = (|f(t)| + |f(t-1)|)/2 and f is the log-likelihood.
% verbose - display loglik at each iteration [1]
% anneal - 1 means do deterministic annealing (only for entropic priors) [0]
% anneal_rate - geometric cooling rate [0.8]
% init_temp - initial annealing temperature [10]
% final_temp - final annealing temperature [1e-3]
%
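% Example call (a sketch, not taken from this file; 'data' is built as above and
% 'engine' is an existing DBN inference engine for the bnet whose parameters we learn):
%    [bnet2, LLtrace] = learn_params_dbn_em(engine, data, 'max_iter', 20, 'thresh', 1e-4);
%    plot(LLtrace)  % learning curve; should be non-decreasing
%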

max_iter = 100;
thresh = 1e-3;
anneal = 0;
anneal_rate = 0.8;
init_temp = 10;
final_temp = 1e-3;
verbose = 1;

for i=1:2:length(varargin)
  switch varargin{i}
   case 'max_iter', max_iter = varargin{i+1};
   case 'thresh', thresh = varargin{i+1};
   case 'verbose', verbose = varargin{i+1};  % documented above, so handle it here
   case 'anneal', anneal = varargin{i+1};
   case 'anneal_rate', anneal_rate = varargin{i+1};
   case 'init_temp', init_temp = varargin{i+1};
   case 'final_temp', final_temp = varargin{i+1};
   otherwise, error(['unrecognized argument ' varargin{i}])
  end
end

% take 1 EM step at each temperature value, then when temp=0, run to convergence
% When using an entropic prior, Z = 1-T, so
% T=2 => Z=-1 (max entropy)
% T=1 => Z=0 (max likelihood)
% T=0 => Z=1 (min entropy / max structure)
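% (With geometric cooling, the annealed phase takes about
% ceil(log(final_temp/init_temp)/log(anneal_rate)) EM steps,
% roughly 42 with the defaults of 10 down to 1e-3 at rate 0.8.)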
num_iter = 1;
LL = [];
if anneal
  temperature = init_temp;
  while temperature > final_temp
    [engine, loglik, logpost] = EM_step(engine, evidence, temperature);
    if verbose
      fprintf('EM iteration %d, loglik = %8.4f, logpost = %8.4f, temp=%8.4f\n', ...
              num_iter, loglik, logpost, temperature);
    end
    num_iter = num_iter + 1;
    LL = [LL loglik];
    temperature = temperature * anneal_rate;
  end
  temperature = 0;
  previous_loglik = loglik;
  previous_logpost = logpost;
else
  temperature = 0;
  previous_loglik = -inf;
  previous_logpost = -inf;
end

converged = 0;
while ~converged & (num_iter <= max_iter)
  [engine, loglik, logpost] = EM_step(engine, evidence, temperature);
  if verbose
    %fprintf('EM iteration %d, loglik = %8.4f, logpost = %8.4f\n', ...
    %        num_iter, loglik, logpost);
    fprintf('EM iteration %d, loglik = %8.4f\n', num_iter, loglik);
  end
  num_iter = num_iter + 1;
  [converged, decreased] = em_converged(loglik, previous_loglik, thresh);
  %[converged, decreased] = em_converged(logpost, previous_logpost, thresh);
  previous_loglik = loglik;
  previous_logpost = logpost;
  LL = [LL loglik];
end

bnet = bnet_from_engine(engine);

%%%%%%%%%

function [engine, loglik, logpost] = EM_step(engine, cases, temp)

bnet = bnet_from_engine(engine); % engine contains the old params that are used for the E step
ss = length(bnet.intra);
CPDs = bnet.CPD; % these are the new params that get maximized
num_CPDs = length(CPDs);

% log P(theta|D) = (log P(D|theta) + log P(theta)) - log(P(D))
% where log P(D|theta) = sum_cases log P(case|theta)
% and log P(theta) = sum_CPDs log P(CPD) - only count once even if tied!
% logpost = log P(theta,D) (un-normalized)
% This should be negative, and increase at every step.

adjustable = zeros(1,num_CPDs);
logprior = zeros(1, num_CPDs);
for e=1:num_CPDs
  adjustable(e) = adjustable_CPD(CPDs{e});
end
adj = find(adjustable);

for e=adj(:)'
  logprior(e) = log_prior(CPDs{e});
  CPDs{e} = reset_ess(CPDs{e});
end

loglik = 0;
for l=1:length(cases)
  evidence = cases{l};
  if ~iscell(evidence)
    error('training data must be a cell array of cell arrays')
  end
  [engine, ll] = enter_evidence(engine, evidence);
  assert(~isnan(ll))
  loglik = loglik + ll;
  T = size(evidence, 2);

  % We unroll ns etc. because update_ess refers to nodes by their unrolled number,
  % so that it extracts evidence from the right place.
  % (The CPD should really store its own version of ns and cnodes...)
  ns = repmat(bnet.node_sizes_slice(:), [1 T]);
  cnodes = unroll_set(bnet.cnodes_slice, ss, T);
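  % (Assumption for illustration: node i of slice t is numbered (t-1)*ss + i in the
  % unrolled net, i.e. the column-major linear index into the ss-by-T evidence array.)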

  %hidden_bitv = repmat(bnet.hidden_bitv(1:ss), [1 T]);
  hidden_bitv = zeros(ss, T);
  hidden_bitv(isemptycell(evidence))=1;
  % hidden_bitv(i) = 1 means node i is hidden.
  % We pass this in, rather than using isemptycell(evidence(dom)), because
  % isemptycell is very slow.

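  % equiv_class(i,1) indexes the CPD for node i in slice 1; equiv_class(i,2) indexes the
  % (typically tied) CPD used in slices 2..T, hence the two separate loops below.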
  t = 1;
  for i=1:ss
    e = bnet.equiv_class(i,1);
    if adjustable(e)
      fmarg = marginal_family(engine, i, t);
      CPDs{e} = update_ess(CPDs{e}, fmarg, evidence, ns(:), cnodes(:), hidden_bitv(:));
    end
  end

  for i=1:ss
    e = bnet.equiv_class(i,2);
    if adjustable(e)
      for t=2:T
        fmarg = marginal_family(engine, i, t);
        CPDs{e} = update_ess(CPDs{e}, fmarg, evidence, ns(:), cnodes(:), hidden_bitv(:));
      end
    end
  end
end

logpost = loglik + sum(logprior(:));

for e=adj(:)'
  CPDs{e} = maximize_params(CPDs{e}, temp);
end

engine = update_engine(engine, CPDs);