comparison toolboxes/FullBNT-1.0.7/netlab3.3/hmc.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:e9a9cd732c1e
1 function [samples, energies, diagn] = hmc(f, x, options, gradf, varargin)
2 %HMC Hybrid Monte Carlo sampling.
3 %
4 % Description
5 % SAMPLES = HMC(F, X, OPTIONS, GRADF) uses a hybrid Monte Carlo
6 % algorithm to sample from the distribution P ~ EXP(-F), where F is the
7 % first argument to HMC. The Markov chain starts at the point X, and
8 % the function GRADF is the gradient of the `energy' function F.
9 %
10 % HMC(F, X, OPTIONS, GRADF, P1, P2, ...) allows additional arguments to
11 % be passed to F() and GRADF().
12 %
13 % [SAMPLES, ENERGIES, DIAGN] = HMC(F, X, OPTIONS, GRADF) also returns a
14 % log of the energy values (i.e. negative log probabilities) for the
15 % samples in ENERGIES and DIAGN, a structure containing diagnostic
16 % information (position, momentum and acceptance threshold) for each
17 % step of the chain in DIAGN.POS, DIAGN.MOM and DIAGN.ACC respectively.
18 % All candidate states (including rejected ones) are stored in
19 % DIAGN.POS.
20 %
21 % [SAMPLES, ENERGIES, DIAGN] = HMC(F, X, OPTIONS, GRADF) also returns
22 % the ENERGIES (i.e. negative log probabilities) corresponding to the
23 % samples. The DIAGN structure contains three fields:
24 %
25 % POS the position vectors of the dynamic process.
26 %
27 % MOM the momentum vectors of the dynamic process.
28 %
29 % ACC the acceptance thresholds.
30 %
31 % S = HMC('STATE') returns a state structure that contains the state of
32 % the two random number generators RAND and RANDN and the momentum of
33 % the dynamic process. These are contained in fields randstate,
34 % randnstate and mom respectively. The momentum state is only used for
35 % a persistent momentum update.
36 %
37 % HMC('STATE', S) resets the state to S. If S is an integer, then it
38 % is passed to RAND and RANDN and the momentum variable is randomised.
39 % If S is a structure returned by HMC('STATE') then it resets the
40 % generator to exactly the same state.
41 %
42 % The optional parameters in the OPTIONS vector have the following
43 % interpretations.
44 %
45 % OPTIONS(1) is set to 1 to display the energy values and rejection
46 % threshold at each step of the Markov chain. If the value is 2, then
47 % the position vectors at each step are also displayed.
48 %
49 % OPTIONS(5) is set to 1 if momentum persistence is used; default 0,
50 % for complete replacement of momentum variables.
51 %
52 % OPTIONS(7) defines the trajectory length (i.e. the number of leap-
53 % frog steps at each iteration). Minimum value 1.
54 %
55 % OPTIONS(9) is set to 1 to check the user defined gradient function.
56 %
57 % OPTIONS(14) is the number of samples retained from the Markov chain;
58 % default 100.
59 %
60 % OPTIONS(15) is the number of samples omitted from the start of the
61 % chain; default 0.
62 %
63 % OPTIONS(17) defines the momentum used when a persistent update of
64 % (leap-frog) momentum is used. This is bounded to the interval [0,
65 % 1).
66 %
67 % OPTIONS(18) is the step size used in leap-frogs; default 1/trajectory
68 % length.
69 %
70 % See also
71 % METROP
72 %
73
74 % Copyright (c) Ian T Nabney (1996-2001)
75
76 % Global variable to store state of momentum variables: set by set_state
77 % Used to initialise variable if set
78 global HMC_MOM
79 if nargin <= 2
80 if ~strcmp(f, 'state')
81 error('Unknown argument to hmc');
82 end
83 switch nargin
84 case 1
85 samples = get_state(f);
86 return;
87 case 2
88 set_state(f, x);
89 return;
90 end
91 end
92
93 display = options(1);
94 if (round(options(5) == 1))
95 persistence = 1;
96 % Set alpha to lie in [0, 1)
97 alpha = max(0, options(17));
98 alpha = min(1, alpha);
99 salpha = sqrt(1-alpha*alpha);
100 else
101 persistence = 0;
102 end
103 L = max(1, options(7)); % At least one step in leap-frogging
104 if options(14) > 0
105 nsamples = options(14);
106 else
107 nsamples = 100; % Default
108 end
109 if options(15) >= 0
110 nomit = options(15);
111 else
112 nomit = 0;
113 end
114 if options(18) > 0
115 step_size = options(18); % Step size.
116 else
117 step_size = 1/L; % Default
118 end
119 x = x(:)'; % Force x to be a row vector
120 nparams = length(x);
121
122 % Set up strings for evaluating potential function and its gradient.
123 f = fcnchk(f, length(varargin));
124 gradf = fcnchk(gradf, length(varargin));
125
126 % Check the gradient evaluation.
127 if (options(9))
128 % Check gradients
129 feval('gradchek', x, f, gradf, varargin{:});
130 end
131
132 samples = zeros(nsamples, nparams); % Matrix of returned samples.
133 if nargout >= 2
134 en_save = 1;
135 energies = zeros(nsamples, 1);
136 else
137 en_save = 0;
138 end
139 if nargout >= 3
140 diagnostics = 1;
141 diagn_pos = zeros(nsamples, nparams);
142 diagn_mom = zeros(nsamples, nparams);
143 diagn_acc = zeros(nsamples, 1);
144 else
145 diagnostics = 0;
146 end
147
148 n = - nomit + 1;
149 Eold = feval(f, x, varargin{:}); % Evaluate starting energy.
150 nreject = 0;
151 if (~persistence | isempty(HMC_MOM))
152 p = randn(1, nparams); % Initialise momenta at random
153 else
154 p = HMC_MOM; % Initialise momenta from stored state
155 end
156 lambda = 1;
157
158 % Main loop.
159 while n <= nsamples
160
161 xold = x; % Store starting position.
162 pold = p; % Store starting momenta
163 Hold = Eold + 0.5*(p*p'); % Recalculate Hamiltonian as momenta have changed
164
165 if ~persistence
166 % Choose a direction at random
167 if (rand < 0.5)
168 lambda = -1;
169 else
170 lambda = 1;
171 end
172 end
173 % Perturb step length.
174 epsilon = lambda*step_size*(1.0 + 0.1*randn(1));
175
176 % First half-step of leapfrog.
177 p = p - 0.5*epsilon*feval(gradf, x, varargin{:});
178 x = x + epsilon*p;
179
180 % Full leapfrog steps.
181 for m = 1 : L - 1
182 p = p - epsilon*feval(gradf, x, varargin{:});
183 x = x + epsilon*p;
184 end
185
186 % Final half-step of leapfrog.
187 p = p - 0.5*epsilon*feval(gradf, x, varargin{:});
188
189 % Now apply Metropolis algorithm.
190 Enew = feval(f, x, varargin{:}); % Evaluate new energy.
191 p = -p; % Negate momentum
192 Hnew = Enew + 0.5*p*p'; % Evaluate new Hamiltonian.
193 a = exp(Hold - Hnew); % Acceptance threshold.
194 if (diagnostics & n > 0)
195 diagn_pos(n,:) = x;
196 diagn_mom(n,:) = p;
197 diagn_acc(n,:) = a;
198 end
199 if (display > 1)
200 fprintf(1, 'New position is\n');
201 disp(x);
202 end
203
204 if a > rand(1) % Accept the new state.
205 Eold = Enew; % Update energy
206 if (display > 0)
207 fprintf(1, 'Finished step %4d Threshold: %g\n', n, a);
208 end
209 else % Reject the new state.
210 if n > 0
211 nreject = nreject + 1;
212 end
213 x = xold; % Reset position
214 p = pold; % Reset momenta
215 if (display > 0)
216 fprintf(1, ' Sample rejected %4d. Threshold: %g\n', n, a);
217 end
218 end
219 if n > 0
220 samples(n,:) = x; % Store sample.
221 if en_save
222 energies(n) = Eold; % Store energy.
223 end
224 end
225
226 % Set momenta for next iteration
227 if persistence
228 p = -p;
229 % Adjust momenta by a small random amount.
230 p = alpha.*p + salpha.*randn(1, nparams);
231 else
232 p = randn(1, nparams); % Replace all momenta.
233 end
234
235 n = n + 1;
236 end
237
238 if (display > 0)
239 fprintf(1, '\nFraction of samples rejected: %g\n', ...
240 nreject/(nsamples));
241 end
242 if diagnostics
243 diagn.pos = diagn_pos;
244 diagn.mom = diagn_mom;
245 diagn.acc = diagn_acc;
246 end
247 % Store final momentum value in global so that it can be retrieved later
248 HMC_MOM = p;
249 return
250
251 % Return complete state of sampler (including momentum)
252 function state = get_state(f)
253
254 global HMC_MOM
255 state.randstate = rand('state');
256 state.randnstate = randn('state');
257 state.mom = HMC_MOM;
258 return
259
260 % Set complete state of sampler (including momentum) or just set randn
261 % and rand with integer argument.
262 function set_state(f, x)
263
264 global HMC_MOM
265 if isnumeric(x)
266 rand('state', x);
267 randn('state', x);
268 HMC_MOM = [];
269 else
270 if ~isstruct(x)
271 error('Second argument to hmc must be number or state structure');
272 end
273 if (~isfield(x, 'randstate') | ~isfield(x, 'randnstate') ...
274 | ~isfield(x, 'mom'))
275 error('Second argument to hmc must contain correct fields')
276 end
277 rand('state', x.randstate);
278 randn('state', x.randnstate);
279 HMC_MOM = x.mom;
280 end
281 return