@@ -119,9 +119,11 @@ function BAT.mcmc_propose!!(mc_state::HMCState)
119119 z_phase = AdvancedHMC. phasepoint (hamiltonian, vec (z_current[:]), rand (rng, hamiltonian. metric, hamiltonian. kinetic))
120120 # Note: `RiemannianKinetic` requires an additional position argument, but including this causes issues. So only support the other kinetics.
121121
122- proposal. transition, z_proposed_hmc, p_accept = _bat_transition (rng, τ, hamiltonian, z_phase)
123- accepted = z_current[:] != proposal. transition. z. θ
124- z_proposed[:] = accepted ? proposal. transition. z. θ : z_proposed_hmc
122+ proposal. transition = AdvancedHMC. transition (rng, τ, hamiltonian, z_phase)
123+ p_accept = AdvancedHMC. stat (proposal. transition). acceptance_rate
124+
125+ z_proposed[:] = proposal. transition. z. θ
126+ accepted = z_current[:] != z_proposed[:]
125127
126128 p_accept = AdvancedHMC. stat (proposal. transition). acceptance_rate
127129
@@ -176,85 +178,3 @@ function BAT.set_mc_state_transform!!(mc_state::HMCState, f_transform_new::Funct
176178 mc_state_new = @set mc_state_new. f_transform = f_transform_new
177179 return mc_state_new
178180end
179-
180-
181- # Copied from AdvancedHMC.jl, but also return proposed point
182- function _bat_transition (
183- rng:: AbstractRNG ,
184- τ:: AdvancedHMC.Trajectory{TS,I,TC} ,
185- h:: AdvancedHMC.Hamiltonian ,
186- z0:: AdvancedHMC.PhasePoint ,
187- ) where {
188- TS<: AdvancedHMC.AbstractTrajectorySampler ,
189- I<: AdvancedHMC.AbstractIntegrator ,
190- TC<: AdvancedHMC.DynamicTerminationCriterion ,
191- }
192- H0 = AdvancedHMC. energy (z0)
193- tree = AdvancedHMC. BinaryTree (
194- z0,
195- z0,
196- AdvancedHMC. TurnStatistic (τ. termination_criterion, z0),
197- zero (H0),
198- zero (Int),
199- zero (H0),
200- )
201- sampler = TS (rng, z0)
202- termination = AdvancedHMC. Termination (false , false )
203- zcand = z0
204- proposed_zs = Vector[]
205- accept_probs = Float64[]
206-
207- j = 0
208- while ! AdvancedHMC. isterminated (termination) && j < τ. termination_criterion. max_depth
209- v = rand (rng, [- 1 , 1 ])
210- if v == - 1
211- tree′, sampler′, termination′ =
212- AdvancedHMC. build_tree (rng, τ, h, tree. zleft, sampler, v, j, H0)
213- treeleft, treeright = tree′, tree
214- else
215- tree′, sampler′, termination′ =
216- AdvancedHMC. build_tree (rng, τ, h, tree. zright, sampler, v, j, H0)
217- treeleft, treeright = tree, tree′
218- end
219-
220- # This acceptance prob. is specific to AdvancedHMC.MultinomialTS
221- p_tmp = min (1 , exp (sampler′. ℓw - sampler. ℓw))
222- push! (accept_probs, p_tmp)
223- push! (proposed_zs, sampler′. zcand. θ)
224-
225- if ! AdvancedHMC. isterminated (termination′)
226- j = j + 1
227- if AdvancedHMC. mh_accept (rng, sampler, sampler′)
228- zcand = sampler′. zcand
229- end
230- end
231- tree = AdvancedHMC. combine (treeleft, treeright)
232- sampler = AdvancedHMC. combine (zcand, sampler, sampler′)
233- termination =
234- termination *
235- termination′ *
236- AdvancedHMC. isterminated (τ. termination_criterion, h, tree, treeleft, treeright)
237- end
238-
239- H = AdvancedHMC. energy (zcand)
240- tstat = AdvancedHMC. merge (
241- (
242- n_steps = tree. nα,
243- is_accept = true ,
244- acceptance_rate = tree. sum_α / tree. nα,
245- log_density = zcand. ℓπ. value,
246- hamiltonian_energy = H,
247- hamiltonian_energy_error = H - H0,
248- max_hamiltonian_energy_error = tree. ΔH_max,
249- tree_depth = j,
250- numerical_error = termination. numerical,
251- ),
252- AdvancedHMC. stat (τ. integrator),
253- )
254-
255- accept_total = sum (accept_probs)
256- z_proposed = iszero (accept_total) ? sum (proposed_zs) / length (proposed_zs) : sum (accept_probs .* proposed_zs) / accept_total
257- p_accept = tstat. acceptance_rate
258-
259- return AdvancedHMC. Transition (zcand, tstat), z_proposed, p_accept
260- end
0 commit comments