Skip to content

Commit 1d959ea

Browse files
committed
Adjust for AdvancedHMC.jl API changes
1 parent 590a45c commit 1d959ea

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

ext/ahmc_impl/ahmc_sampler_impl.jl

+8-2
Original file line numberDiff line numberDiff line change
@@ -115,8 +115,14 @@ function BAT.mcmc_propose!!(mc_state::HMCState)
115115

116116
hamiltonian = proposal.hamiltonian
117117

118-
# Current location in the phase space for Hamiltonian MonteCarlo
119-
z_phase = AdvancedHMC.phasepoint(hamiltonian, vec(z_current[:]), rand(rng, hamiltonian.metric, hamiltonian.kinetic))
118+
@static if isdefined(AdvancedHMC, :rand_momentum) #isdefined(AdvancedHMC.rand_momentum, Tuple{AbstractRNG, AdvancedHMC.AbstractMetric, AdvancedHMC.AbstractKinetic, AbstractVecOrMat})
119+
# For AdvnacedHMC.jl v >= 0.7
120+
momentum = rand_momentum(rng, hamiltonian.metric, hamiltonian.kinetic, z_current[:])
121+
else
122+
momentum = rand(rng, hamiltonian.metric, hamiltonian.kinetic)
123+
end
124+
125+
z_phase = AdvancedHMC.phasepoint(hamiltonian, vec(z_current[:]), momentum)
120126
# Note: `RiemannianKinetic` requires an additional position argument, but including this causes issues. So only support the other kinetics.
121127

122128
proposal.transition = AdvancedHMC.transition(rng, τ, hamiltonian, z_phase)

0 commit comments

Comments
 (0)