Skip to content

Commit 68565f2

Browse files
Micki-Doschulz
authored andcommitted
Remove proposed sample extraction for HMC altogether
1 parent 5f5396a commit 68565f2

File tree

3 files changed

+12
-85
lines changed

3 files changed

+12
-85
lines changed

ext/ahmc_impl/ahmc_sampler_impl.jl

+5-85
Original file line numberDiff line numberDiff line change
@@ -119,9 +119,11 @@ function BAT.mcmc_propose!!(mc_state::HMCState)
119119
z_phase = AdvancedHMC.phasepoint(hamiltonian, vec(z_current[:]), rand(rng, hamiltonian.metric, hamiltonian.kinetic))
120120
# Note: `RiemannianKinetic` requires an additional position argument, but including this causes issues. So only support the other kinetics.
121121

122-
proposal.transition, z_proposed_hmc, p_accept = _bat_transition(rng, τ, hamiltonian, z_phase)
123-
accepted = z_current[:] != proposal.transition.z.θ
124-
z_proposed[:] = accepted ? proposal.transition.z.θ : z_proposed_hmc
122+
proposal.transition = AdvancedHMC.transition(rng, τ, hamiltonian, z_phase)
123+
p_accept = AdvancedHMC.stat(proposal.transition).acceptance_rate
124+
125+
z_proposed[:] = proposal.transition.z.θ
126+
accepted = z_current[:] != z_proposed[:]
125127

126128
p_accept = AdvancedHMC.stat(proposal.transition).acceptance_rate
127129

@@ -176,85 +178,3 @@ function BAT.set_mc_state_transform!!(mc_state::HMCState, f_transform_new::Funct
176178
mc_state_new = @set mc_state_new.f_transform = f_transform_new
177179
return mc_state_new
178180
end
179-
180-
181-
# Copied from AdvancedHMC.jl, but also return proposed point
182-
function _bat_transition(
183-
rng::AbstractRNG,
184-
τ::AdvancedHMC.Trajectory{TS,I,TC},
185-
h::AdvancedHMC.Hamiltonian,
186-
z0::AdvancedHMC.PhasePoint,
187-
) where {
188-
TS<:AdvancedHMC.AbstractTrajectorySampler,
189-
I<:AdvancedHMC.AbstractIntegrator,
190-
TC<:AdvancedHMC.DynamicTerminationCriterion,
191-
}
192-
H0 = AdvancedHMC.energy(z0)
193-
tree = AdvancedHMC.BinaryTree(
194-
z0,
195-
z0,
196-
AdvancedHMC.TurnStatistic.termination_criterion, z0),
197-
zero(H0),
198-
zero(Int),
199-
zero(H0),
200-
)
201-
sampler = TS(rng, z0)
202-
termination = AdvancedHMC.Termination(false, false)
203-
zcand = z0
204-
proposed_zs = Vector[]
205-
accept_probs = Float64[]
206-
207-
j = 0
208-
while !AdvancedHMC.isterminated(termination) && j < τ.termination_criterion.max_depth
209-
v = rand(rng, [-1, 1])
210-
if v == -1
211-
tree′, sampler′, termination′ =
212-
AdvancedHMC.build_tree(rng, τ, h, tree.zleft, sampler, v, j, H0)
213-
treeleft, treeright = tree′, tree
214-
else
215-
tree′, sampler′, termination′ =
216-
AdvancedHMC.build_tree(rng, τ, h, tree.zright, sampler, v, j, H0)
217-
treeleft, treeright = tree, tree′
218-
end
219-
220-
# This acceptance prob. is specific to AdvancedHMC.MultinomialTS
221-
p_tmp = min(1, exp(sampler′.ℓw - sampler.ℓw))
222-
push!(accept_probs, p_tmp)
223-
push!(proposed_zs, sampler′.zcand.θ)
224-
225-
if !AdvancedHMC.isterminated(termination′)
226-
j = j + 1
227-
if AdvancedHMC.mh_accept(rng, sampler, sampler′)
228-
zcand = sampler′.zcand
229-
end
230-
end
231-
tree = AdvancedHMC.combine(treeleft, treeright)
232-
sampler = AdvancedHMC.combine(zcand, sampler, sampler′)
233-
termination =
234-
termination *
235-
termination′ *
236-
AdvancedHMC.isterminated.termination_criterion, h, tree, treeleft, treeright)
237-
end
238-
239-
H = AdvancedHMC.energy(zcand)
240-
tstat = AdvancedHMC.merge(
241-
(
242-
n_steps = tree.nα,
243-
is_accept = true,
244-
acceptance_rate = tree.sum_α / tree.nα,
245-
log_density = zcand.ℓπ.value,
246-
hamiltonian_energy = H,
247-
hamiltonian_energy_error = H - H0,
248-
max_hamiltonian_energy_error = tree.ΔH_max,
249-
tree_depth = j,
250-
numerical_error = termination.numerical,
251-
),
252-
AdvancedHMC.stat.integrator),
253-
)
254-
255-
accept_total = sum(accept_probs)
256-
z_proposed = iszero(accept_total) ? sum(proposed_zs) / length(proposed_zs) : sum(accept_probs .* proposed_zs) / accept_total
257-
p_accept = tstat.acceptance_rate
258-
259-
return AdvancedHMC.Transition(zcand, tstat), z_proposed, p_accept
260-
end

src/samplers/mcmc/mcmc_state.jl

+2
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,8 @@ function mcmc_step!!(mcmc_state::MCMCState)
161161

162162
current = _current_sample_idx(chain_state)
163163
proposed = _proposed_sample_idx(chain_state)
164+
165+
# This does not change `sample_z` in the chain_state, that happens in the next mcmc step in `_cleanup_samples()`.
164166
_accept_reject!(chain_state, accepted, p_accept, current, proposed)
165167

166168
mcmc_state_new = mcmc_tune_post_step!!(mcmc_state, p_accept)

src/samplers/mcmc/mcmc_tuning/mcmc_ram_tuner.jl

+5
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,11 @@ function mcmc_tune_post_step!!(
7676
mc_state::MCMCChainState,
7777
p_accept::Real,
7878
)
79+
80+
if current_sample_z(mc_state).v == proposed_sample_z(mc_state)
81+
return mc_state, tuner_state
82+
end
83+
7984
(; f_transform, sample_z) = mc_state
8085
(; target_acceptance, gamma) = tuner_state.tuning
8186
b = f_transform.b

0 commit comments

Comments
 (0)