camir-aes2014: toolboxes/FullBNT-1.0.7/bnt/learning/learn_params_dbn_em.m
changeset 0:e9a9cd732c1e (tip): first hg version after svn
author: wolffd
date: Tue, 10 Feb 2015 15:05:51 +0000
(file added in this changeset)
function [bnet, LL, engine] = learn_params_dbn_em(engine, evidence, varargin)
% LEARN_PARAMS_DBN_EM Set the parameters in a DBN to their ML/MAP values using batch EM.
% [bnet, LLtrace, engine] = learn_params_dbn_em(engine, data, ...)
%
% data{l}{i,t} = value of node i in slice t of time-series l, or [] if hidden.
% Suppose you have L time series, each of length T, stored in an O*T*L array D,
% where O is the number of observed scalar nodes and N is the total number of nodes per slice.
% Then you can create data as follows, where onodes is the index of the observable nodes:
%   data = cell(1,L);
%   for l=1:L
%     data{l} = cell(N, T);
%     data{l}(onodes,:) = num2cell(D(:,:,l));
%   end
% Of course it is possible for different sets of nodes to be observed in
% each slice/sequence, and for each sequence to have a different length.
%
% LLtrace is the learning curve: the vector of log-likelihood scores at each iteration.
%
% Optional arguments [default]
%
% max_iter - specifies the maximum number of iterations [100]
% thresh - specifies the threshold for stopping EM [1e-3]
%   We stop when |f(t) - f(t-1)| / avg < threshold,
%   where avg = (|f(t)| + |f(t-1)|)/2 and f is the log likelihood.
% verbose - display loglik at each iteration [1]
% anneal - 1 means do deterministic annealing (only for entropic priors) [0]
% anneal_rate - geometric cooling rate [0.8]
% init_temp - initial annealing temperature [10]
% final_temp - final annealing temperature [1e-3]
%
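% Example usage (a sketch; it assumes you have already built a DBN 'bnet'
% and wrapped it in a BNT inference engine such as a junction-tree smoother,
% and that 'data' has the cell-array format described above):
%   engine = smoother_engine(jtree_2TBN_inf_engine(bnet));
%   [bnet2, LLtrace] = learn_params_dbn_em(engine, data, 'max_iter', 20);
%   plot(LLtrace, 'o-')  % the learning curve should be non-decreasing
%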

max_iter = 100;
thresh = 1e-3;
anneal = 0;
anneal_rate = 0.8;
init_temp = 10;
final_temp = 1e-3;
verbose = 1;

for i=1:2:length(varargin)
  switch varargin{i}
    case 'max_iter', max_iter = varargin{i+1};
    case 'thresh', thresh = varargin{i+1};
    case 'verbose', verbose = varargin{i+1}; % was documented above but missing from this switch
    case 'anneal', anneal = varargin{i+1};
    case 'anneal_rate', anneal_rate = varargin{i+1};
    case 'init_temp', init_temp = varargin{i+1};
    case 'final_temp', final_temp = varargin{i+1};
    otherwise, error(['unrecognized argument ' varargin{i}])
  end
end

% Take 1 EM step at each temperature value, then when temp=0, run to convergence.
% When using an entropic prior, Z = 1-T, so
% T=2 => Z=-1 (max entropy)
% T=1 => Z=0 (max likelihood)
% T=0 => Z=1 (min entropy / max structure)
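%
% The cooling schedule is geometric: the k'th annealing step runs at
% temperature init_temp * anneal_rate^k. A small sketch of the temperatures
% visited with the defaults (10, 0.8, 1e-3), about 42 EM steps in all:
%   temps = init_temp * anneal_rate .^ (0:41);  % 10, 8, 6.4, ..., ~1.06e-3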
num_iter = 1;
LL = [];
if anneal
  temperature = init_temp;
  while temperature > final_temp
    [engine, loglik, logpost] = EM_step(engine, evidence, temperature);
    if verbose
      fprintf('EM iteration %d, loglik = %8.4f, logpost = %8.4f, temp=%8.4f\n', ...
              num_iter, loglik, logpost, temperature);
    end
    num_iter = num_iter + 1;
    LL = [LL loglik];
    temperature = temperature * anneal_rate;
  end
  temperature = 0;
  % (assumes init_temp > final_temp, so the loop above ran at least once
  % and loglik/logpost are defined here)
  previous_loglik = loglik;
  previous_logpost = logpost;
else
  temperature = 0;
  previous_loglik = -inf;
  previous_logpost = -inf;
end

converged = 0;
while ~converged && (num_iter <= max_iter)
  [engine, loglik, logpost] = EM_step(engine, evidence, temperature);
  if verbose
    %fprintf('EM iteration %d, loglik = %8.4f, logpost = %8.4f\n', ...
    %        num_iter, loglik, logpost);
    fprintf('EM iteration %d, loglik = %8.4f\n', num_iter, loglik);
  end
  num_iter = num_iter + 1;
  [converged, decreased] = em_converged(loglik, previous_loglik, thresh);
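  % em_converged (a BNT helper) implements the relative-change test from the
  % header comment, roughly:
  %   avg = (abs(loglik) + abs(previous_loglik) + eps)/2;
  %   converged = abs(loglik - previous_loglik)/avg < thresh;
  % 'decreased' flags the case where the log likelihood went down, which for
  % standard EM (monotonically non-decreasing) usually signals a bug.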
  %[converged, decreased] = em_converged(logpost, previous_logpost, thresh);
  previous_loglik = loglik;
  previous_logpost = logpost;
  LL = [LL loglik];
end

bnet = bnet_from_engine(engine);

%%%%%%%%%

function [engine, loglik, logpost] = EM_step(engine, cases, temp)

bnet = bnet_from_engine(engine); % engine contains the old params that are used for the E step
ss = length(bnet.intra); % slice size = number of nodes per slice
CPDs = bnet.CPD; % these are the new params that get maximized
num_CPDs = length(CPDs);

% log P(theta|D) = (log P(D|theta) + log P(theta)) - log P(D)
% where log P(D|theta) = sum_cases log P(case|theta)
% and log P(theta) = sum_CPDs log P(CPD) - only count once even if tied!
% logpost = log P(theta,D) (un-normalized)
% This should be negative, and increase at every step.
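% Since log P(D) does not depend on theta, maximizing the un-normalized
% logpost = loglik + logprior is equivalent to maximizing log P(theta|D),
% which is why the normalizing constant is never computed below.
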
adjustable = zeros(1,num_CPDs);
logprior = zeros(1, num_CPDs);
for e=1:num_CPDs
  adjustable(e) = adjustable_CPD(CPDs{e});
end
adj = find(adjustable);

for e=adj(:)'
  logprior(e) = log_prior(CPDs{e});
  CPDs{e} = reset_ess(CPDs{e});
end
loglik = 0;
for l=1:length(cases)
  evidence = cases{l};
  if ~iscell(evidence)
    error('training data must be a cell array of cell arrays')
  end
  [engine, ll] = enter_evidence(engine, evidence);
  assert(~isnan(ll))
  loglik = loglik + ll;
  T = size(evidence, 2);

  % We unroll ns etc. because update_ess refers to nodes by their unrolled
  % numbers, so that it extracts evidence from the right place.
  % (The CPD should really store its own version of ns and cnodes...)
  ns = repmat(bnet.node_sizes_slice(:), [1 T]);
  cnodes = unroll_set(bnet.cnodes_slice, ss, T);
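  % Unrolling maps node i in slice t to index i + (t-1)*ss in the length-T
  % network; e.g. with ss=3 nodes per slice, node 2 in slice 3 gets unrolled
  % index 2 + 2*3 = 8.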

  %hidden_bitv = repmat(bnet.hidden_bitv(1:ss), [1 T]);
  hidden_bitv = zeros(ss, T);
  hidden_bitv(isemptycell(evidence)) = 1;
  % hidden_bitv(i) = 1 means node i is hidden. We compute this once per
  % sequence and pass it in, rather than calling isemptycell(evidence(dom))
  % for every family inside update_ess, because isemptycell is very slow.

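  % E step bookkeeping: accumulate expected sufficient statistics (ESS).
  % bnet.equiv_class(i,1) maps node i in slice 1 to its (possibly tied) CPD;
  % column 2 does the same for nodes in slices 2..T, so a transition CPD that
  % is shared across time gets one update per family per time step.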
  t = 1;
  for i=1:ss
    e = bnet.equiv_class(i,1);
    if adjustable(e)
      fmarg = marginal_family(engine, i, t);
      CPDs{e} = update_ess(CPDs{e}, fmarg, evidence, ns(:), cnodes(:), hidden_bitv(:));
    end
  end

  for i=1:ss
    e = bnet.equiv_class(i,2);
    if adjustable(e)
      for t=2:T
        fmarg = marginal_family(engine, i, t);
        CPDs{e} = update_ess(CPDs{e}, fmarg, evidence, ns(:), cnodes(:), hidden_bitv(:));
      end
    end
  end
end

logpost = loglik + sum(logprior(:));

% M step: re-estimate each adjustable CPD from its accumulated ESS
for e=adj(:)'
  CPDs{e} = maximize_params(CPDs{e}, temp);
end

engine = update_engine(engine, CPDs);