Mercurial > hg > camir-aes2014
comparison toolboxes/FullBNT-1.0.7/netlab3.3/hmc.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:e9a9cd732c1e |
---|---|
function [samples, energies, diagn] = hmc(f, x, options, gradf, varargin)
%HMC	Hybrid Monte Carlo sampling.
%
%	Description
%	SAMPLES = HMC(F, X, OPTIONS, GRADF) uses a hybrid Monte Carlo
%	algorithm to sample from the distribution P ~ EXP(-F), where F is the
%	first argument to HMC. The Markov chain starts at the point X, and
%	the function GRADF is the gradient of the `energy' function F.
%	SAMPLES is an NSAMPLES x LENGTH(X) matrix, one retained sample per
%	row.
%
%	HMC(F, X, OPTIONS, GRADF, P1, P2, ...) allows additional arguments to
%	be passed to F() and GRADF().
%
%	[SAMPLES, ENERGIES, DIAGN] = HMC(F, X, OPTIONS, GRADF) also returns
%	the ENERGIES (i.e. negative log probabilities) corresponding to the
%	retained samples, and DIAGN, a structure containing diagnostic
%	information for each retained step of the chain. DIAGN has three
%	fields:
%
%	  POS	the candidate position at the end of each trajectory. All
%		candidate states, including rejected ones, are stored here.
%
%	  MOM	the candidate momentum at the end of each trajectory (after
%		the momentum negation that precedes the Metropolis test).
%
%	  ACC	the acceptance thresholds EXP(H_OLD - H_NEW).
%
%	S = HMC('STATE') returns a state structure that contains the state of
%	the two random number generators RAND and RANDN and the momentum of
%	the dynamic process. These are contained in fields randstate,
%	randnstate and mom respectively. The momentum state is only used for
%	a persistent momentum update.
%
%	HMC('STATE', S) resets the state to S. If S is numeric, it is passed
%	to RAND and RANDN and the stored momentum is cleared, so that it is
%	randomised on the next run. If S is a structure returned by
%	HMC('STATE'), the generators and momentum are reset to exactly that
%	state.
%
%	The optional parameters in the OPTIONS vector have the following
%	interpretations.
%
%	OPTIONS(1) is set to 1 to display the energy values and rejection
%	threshold at each step of the Markov chain. If the value is 2, then
%	the position vectors at each step are also displayed.
%
%	OPTIONS(5) is set to 1 if momentum persistence is used; default 0,
%	for complete replacement of momentum variables. The value is rounded
%	before it is compared with 1. A stored momentum whose length does not
%	match X is discarded and re-drawn at random.
%
%	OPTIONS(7) defines the trajectory length (i.e. the number of leap-
%	frog steps at each iteration). Minimum value 1.
%
%	OPTIONS(9) is set to 1 to check the user defined gradient function.
%
%	OPTIONS(14) is the number of samples retained from the Markov chain;
%	default 100 (used when OPTIONS(14) <= 0).
%
%	OPTIONS(15) is the number of samples omitted from the start of the
%	chain; default 0 (used when OPTIONS(15) is negative).
%
%	OPTIONS(17) defines the momentum used when a persistent update of
%	(leap-frog) momentum is used. It is clamped to the interval [0, 1];
%	a value of 1 means the momentum never receives fresh noise.
%
%	OPTIONS(18) is the step size used in leap-frogs; default 1/trajectory
%	length. Each trajectory multiplies it by a random factor
%	(1 + 0.1*RANDN).
%
%	See also
%	METROP
%

%	Copyright (c) Ian T Nabney (1996-2001)

% Global variable to store state of momentum variables: set by set_state
% and at the end of every run. Used to initialise the momentum when
% persistence is enabled.
global HMC_MOM
if nargin <= 2
  % Only the HMC('state') / HMC('state', S) forms use fewer arguments.
  if ~strcmp(f, 'state')
    error('Unknown argument to hmc');
  end
  switch nargin
    case 1
      samples = get_state(f);
      return;
    case 2
      set_state(f, x);
      return;
  end
end

verbose = options(1);
% Round the option value itself (not the result of the comparison), so
% that a value such as 1.2 still selects persistent momentum.
if round(options(5)) == 1
  persistence = 1;
  % Clamp alpha to [0, 1]. salpha scales the fresh noise so that a
  % unit-variance momentum keeps unit variance after the update.
  alpha = max(0, options(17));
  alpha = min(1, alpha);
  salpha = sqrt(1-alpha*alpha);
else
  persistence = 0;
end
L = max(1, options(7));		% At least one step in leap-frogging
if options(14) > 0
  nsamples = options(14);
else
  nsamples = 100;		% Default
end
if options(15) >= 0
  nomit = options(15);
else
  nomit = 0;
end
if options(18) > 0
  step_size = options(18);	% Step size.
else
  step_size = 1/L;		% Default
end
x = x(:)';			% Force x to be a row vector
nparams = length(x);

% Set up strings for evaluating potential function and its gradient.
f = fcnchk(f, length(varargin));
gradf = fcnchk(gradf, length(varargin));

% Check the gradient evaluation.
if (options(9))
  % Check gradients
  feval('gradchek', x, f, gradf, varargin{:});
end

samples = zeros(nsamples, nparams);	% Matrix of returned samples.
if nargout >= 2
  en_save = 1;
  energies = zeros(nsamples, 1);
else
  en_save = 0;
end
if nargout >= 3
  diagnostics = 1;
  diagn_pos = zeros(nsamples, nparams);
  diagn_mom = zeros(nsamples, nparams);
  diagn_acc = zeros(nsamples, 1);
else
  diagnostics = 0;
end

% Iterations with n <= 0 are burn-in: they run the chain but store nothing.
n = - nomit + 1;
Eold = feval(f, x, varargin{:});	% Evaluate starting energy.
nreject = 0;
% Re-use the stored momentum only when persistence is on and the stored
% vector matches the dimension of x; otherwise draw fresh momenta.
if (~persistence | isempty(HMC_MOM) | length(HMC_MOM(:)) ~= nparams)
  p = randn(1, nparams);	% Initialise momenta at random
else
  p = reshape(HMC_MOM, 1, nparams);	% Initialise momenta from stored state
end
lambda = 1;

% Main loop.
while n <= nsamples

  xold = x;		    % Store starting position.
  pold = p;		    % Store starting momenta
  Hold = Eold + 0.5*(p*p'); % Recalculate Hamiltonian as momenta have changed

  if ~persistence
    % Choose a direction at random
    if (rand < 0.5)
      lambda = -1;
    else
      lambda = 1;
    end
  end
  % Perturb step length.
  epsilon = lambda*step_size*(1.0 + 0.1*randn(1));

  % First half-step of leapfrog.
  p = p - 0.5*epsilon*feval(gradf, x, varargin{:});
  x = x + epsilon*p;

  % Full leapfrog steps.
  for m = 1 : L - 1
    p = p - epsilon*feval(gradf, x, varargin{:});
    x = x + epsilon*p;
  end

  % Final half-step of leapfrog.
  p = p - 0.5*epsilon*feval(gradf, x, varargin{:});

  % Now apply Metropolis algorithm.
  Enew = feval(f, x, varargin{:});	% Evaluate new energy.
  p = -p;				% Negate momentum
  Hnew = Enew + 0.5*p*p';		% Evaluate new Hamiltonian.
  a = exp(Hold - Hnew);			% Acceptance threshold.
  if (diagnostics & n > 0)
    diagn_pos(n,:) = x;
    diagn_mom(n,:) = p;
    diagn_acc(n,:) = a;
  end
  if (verbose > 1)
    fprintf(1, 'New position is\n');
    disp(x);
  end

  if a > rand(1)			% Accept the new state.
    Eold = Enew;			% Update energy
    if (verbose > 0)
      fprintf(1, 'Finished step %4d  Threshold: %g\n', n, a);
    end
  else					% Reject the new state.
    if n > 0
      nreject = nreject + 1;
    end
    x = xold;				% Reset position
    p = pold;				% Reset momenta
    if (verbose > 0)
      fprintf(1, '  Sample rejected %4d.  Threshold: %g\n', n, a);
    end
  end
  if n > 0
    samples(n,:) = x;			% Store sample.
    if en_save
      energies(n) = Eold;		% Store energy.
    end
  end

  % Set momenta for next iteration
  if persistence
    % Undo the Metropolis negation on acceptance; on rejection this
    % reverses the restored starting momentum.
    p = -p;
    % Adjust momenta by a small random amount.
    p = alpha.*p + salpha.*randn(1, nparams);
  else
    p = randn(1, nparams);	% Replace all momenta.
  end

  n = n + 1;
end

if (verbose > 0)
  fprintf(1, '\nFraction of samples rejected:  %g\n', ...
    nreject/(nsamples));
end
if diagnostics
  diagn.pos = diagn_pos;
  diagn.mom = diagn_mom;
  diagn.acc = diagn_acc;
end
% Store final momentum value in global so that it can be retrieved later
HMC_MOM = p;
return
250 | |
% Return complete state of sampler (including momentum)
function state = get_state(f)
%GET_STATE Snapshot of the sampler state used by HMC('STATE').
%	STATE = GET_STATE(F) returns a structure with fields RANDSTATE and
%	RANDNSTATE (the current RAND and RANDN generator states) and MOM (the
%	contents of the global HMC_MOM, which may be empty). The argument F
%	is not used.

global HMC_MOM
% Each value is wrapped in a cell so STRUCT stores it verbatim as a
% single field value (never expanding it into a struct array).
state = struct('randstate', {rand('state')}, ...
  'randnstate', {randn('state')}, ...
  'mom', {HMC_MOM});
return
259 | |
% Set complete state of sampler (including momentum) or just set randn
% and rand with integer argument.
function set_state(f, x)
%SET_STATE Restore the sampler state for HMC('STATE', S).
%	SET_STATE(F, X) with numeric X seeds both RAND and RANDN with X and
%	clears the global HMC_MOM. With a structure X having fields
%	RANDSTATE, RANDNSTATE and MOM, it restores both generators and sets
%	HMC_MOM to X.MOM. Any other X raises an error. The argument F is not
%	used.

global HMC_MOM

% Numeric seed: reseed both generators and forget the stored momentum.
if isnumeric(x)
  rand('state', x);
  randn('state', x);
  HMC_MOM = [];
  return
end

% Otherwise X must be a complete state structure.
if ~isstruct(x)
  error('Second argument to hmc must be number or state structure');
end
needed = {'randstate', 'randnstate', 'mom'};
if ~all(isfield(x, needed))
  error('Second argument to hmc must contain correct fields')
end

rand('state', x.randstate);
randn('state', x.randnstate);
HMC_MOM = x.mom;
return