@@ -119,9 +119,11 @@ function BAT.mcmc_propose!!(mc_state::HMCState)
119
119
z_phase = AdvancedHMC. phasepoint (hamiltonian, vec (z_current[:]), rand (rng, hamiltonian. metric, hamiltonian. kinetic))
120
120
# Note: `RiemannianKinetic` requires an additional position argument, but including this causes issues. So only support the other kinetics.
121
121
122
- proposal. transition, z_proposed_hmc, p_accept = _bat_transition (rng, τ, hamiltonian, z_phase)
123
- accepted = z_current[:] != proposal. transition. z. θ
124
- z_proposed[:] = accepted ? proposal. transition. z. θ : z_proposed_hmc
122
+ proposal. transition = AdvancedHMC. transition (rng, τ, hamiltonian, z_phase)
123
+ p_accept = AdvancedHMC. stat (proposal. transition). acceptance_rate
124
+
125
+ z_proposed[:] = proposal. transition. z. θ
126
+ accepted = z_current[:] != z_proposed[:]
125
127
126
128
p_accept = AdvancedHMC. stat (proposal. transition). acceptance_rate
127
129
@@ -176,85 +178,3 @@ function BAT.set_mc_state_transform!!(mc_state::HMCState, f_transform_new::Funct
176
178
mc_state_new = @set mc_state_new. f_transform = f_transform_new
177
179
return mc_state_new
178
180
end
179
-
180
-
181
- # Copied from AdvancedHMC.jl, but also return proposed point
182
- function _bat_transition (
183
- rng:: AbstractRNG ,
184
- τ:: AdvancedHMC.Trajectory{TS,I,TC} ,
185
- h:: AdvancedHMC.Hamiltonian ,
186
- z0:: AdvancedHMC.PhasePoint ,
187
- ) where {
188
- TS<: AdvancedHMC.AbstractTrajectorySampler ,
189
- I<: AdvancedHMC.AbstractIntegrator ,
190
- TC<: AdvancedHMC.DynamicTerminationCriterion ,
191
- }
192
- H0 = AdvancedHMC. energy (z0)
193
- tree = AdvancedHMC. BinaryTree (
194
- z0,
195
- z0,
196
- AdvancedHMC. TurnStatistic (τ. termination_criterion, z0),
197
- zero (H0),
198
- zero (Int),
199
- zero (H0),
200
- )
201
- sampler = TS (rng, z0)
202
- termination = AdvancedHMC. Termination (false , false )
203
- zcand = z0
204
- proposed_zs = Vector[]
205
- accept_probs = Float64[]
206
-
207
- j = 0
208
- while ! AdvancedHMC. isterminated (termination) && j < τ. termination_criterion. max_depth
209
- v = rand (rng, [- 1 , 1 ])
210
- if v == - 1
211
- tree′, sampler′, termination′ =
212
- AdvancedHMC. build_tree (rng, τ, h, tree. zleft, sampler, v, j, H0)
213
- treeleft, treeright = tree′, tree
214
- else
215
- tree′, sampler′, termination′ =
216
- AdvancedHMC. build_tree (rng, τ, h, tree. zright, sampler, v, j, H0)
217
- treeleft, treeright = tree, tree′
218
- end
219
-
220
- # This acceptance prob. is specific to AdvancedHMC.MultinomialTS
221
- p_tmp = min (1 , exp (sampler′. ℓw - sampler. ℓw))
222
- push! (accept_probs, p_tmp)
223
- push! (proposed_zs, sampler′. zcand. θ)
224
-
225
- if ! AdvancedHMC. isterminated (termination′)
226
- j = j + 1
227
- if AdvancedHMC. mh_accept (rng, sampler, sampler′)
228
- zcand = sampler′. zcand
229
- end
230
- end
231
- tree = AdvancedHMC. combine (treeleft, treeright)
232
- sampler = AdvancedHMC. combine (zcand, sampler, sampler′)
233
- termination =
234
- termination *
235
- termination′ *
236
- AdvancedHMC. isterminated (τ. termination_criterion, h, tree, treeleft, treeright)
237
- end
238
-
239
- H = AdvancedHMC. energy (zcand)
240
- tstat = AdvancedHMC. merge (
241
- (
242
- n_steps = tree. nα,
243
- is_accept = true ,
244
- acceptance_rate = tree. sum_α / tree. nα,
245
- log_density = zcand. ℓπ. value,
246
- hamiltonian_energy = H,
247
- hamiltonian_energy_error = H - H0,
248
- max_hamiltonian_energy_error = tree. ΔH_max,
249
- tree_depth = j,
250
- numerical_error = termination. numerical,
251
- ),
252
- AdvancedHMC. stat (τ. integrator),
253
- )
254
-
255
- accept_total = sum (accept_probs)
256
- z_proposed = iszero (accept_total) ? sum (proposed_zs) / length (proposed_zs) : sum (accept_probs .* proposed_zs) / accept_total
257
- p_accept = tstat. acceptance_rate
258
-
259
- return AdvancedHMC. Transition (zcand, tstat), z_proposed, p_accept
260
- end
0 commit comments