Skip to content

Commit 5f5396a

Browse files
Micki-Doschulz
authored andcommitted
Change proposed sample in hmc to weighted mean, Fix weight assignment error in mcmc_stepgit add -A ()
1 parent 18d1115 commit 5f5396a

File tree

3 files changed

+15
-10
lines changed

3 files changed

+15
-10
lines changed

Diff for: ext/ahmc_impl/ahmc_sampler_impl.jl

+12-4
Original file line numberDiff line numberDiff line change
@@ -125,10 +125,12 @@ function BAT.mcmc_propose!!(mc_state::HMCState)
125125

126126
p_accept = AdvancedHMC.stat(proposal.transition).acceptance_rate
127127

128-
x_proposed[:] = f_transform(z_proposed)
128+
x_proposed[:], ladj = with_logabsdet_jacobian(f_transform, z_proposed)
129129
logd_x_proposed = logdensityof(target, x_proposed)
130130
samples.logd[proposed_x_idx] = logd_x_proposed
131131

132+
sample_z.logd[proposed_z_idx] = logd_x_proposed + ladj
133+
132134
return mc_state, accepted, p_accept
133135
end
134136

@@ -200,6 +202,7 @@ function _bat_transition(
200202
termination = AdvancedHMC.Termination(false, false)
201203
zcand = z0
202204
proposed_zs = Vector[]
205+
accept_probs = Float64[]
203206

204207
j = 0
205208
while !AdvancedHMC.isterminated(termination) && j < τ.termination_criterion.max_depth
@@ -213,14 +216,18 @@ function _bat_transition(
213216
AdvancedHMC.build_tree(rng, τ, h, tree.zright, sampler, v, j, H0)
214217
treeleft, treeright = tree, tree′
215218
end
219+
220+
# This acceptance prob. is specific to AdvancedHMC.MultinomialTS
221+
p_tmp = min(1, exp(sampler′.ℓw - sampler.ℓw))
222+
push!(accept_probs, p_tmp)
223+
push!(proposed_zs, sampler′.zcand.θ)
224+
216225
if !AdvancedHMC.isterminated(termination′)
217226
j = j + 1
218227
if AdvancedHMC.mh_accept(rng, sampler, sampler′)
219228
zcand = sampler′.zcand
220229
end
221230
end
222-
push!(proposed_zs, sampler′.zcand.θ)
223-
224231
tree = AdvancedHMC.combine(treeleft, treeright)
225232
sampler = AdvancedHMC.combine(zcand, sampler, sampler′)
226233
termination =
@@ -245,7 +252,8 @@ function _bat_transition(
245252
AdvancedHMC.stat.integrator),
246253
)
247254

248-
z_proposed = proposed_zs[end]
255+
accept_total = sum(accept_probs)
256+
z_proposed = iszero(accept_total) ? sum(proposed_zs) / length(proposed_zs) : sum(accept_probs .* proposed_zs) / accept_total
249257
p_accept = tstat.acceptance_rate
250258

251259
return AdvancedHMC.Transition(zcand, tstat), z_proposed, p_accept

Diff for: src/samplers/mcmc/mcmc_state.jl

+3-5
Original file line numberDiff line numberDiff line change
@@ -159,15 +159,13 @@ function mcmc_step!!(mcmc_state::MCMCState)
159159

160160
chain_state, accepted, p_accept = mcmc_propose!!(chain_state)
161161

162-
mcmc_state_new = mcmc_tune_post_step!!(mcmc_state, p_accept)
163-
164-
chain_state = mcmc_state_new.chain_state
165-
166162
current = _current_sample_idx(chain_state)
167163
proposed = _proposed_sample_idx(chain_state)
168-
169164
_accept_reject!(chain_state, accepted, p_accept, current, proposed)
170165

166+
mcmc_state_new = mcmc_tune_post_step!!(mcmc_state, p_accept)
167+
168+
chain_state = mcmc_state_new.chain_state
171169
mcmc_state_final = @set mcmc_state_new.chain_state = chain_state
172170

173171
return mcmc_state_final

Diff for: test/samplers/mcmc/test_hmc.jl

-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ import AdvancedHMC
2626
@testset "MCMC iteration" begin
2727
v_init = bat_initval(target, InitFromTarget(), context).result
2828
# Note: No @inferred, since MCMCChainState is not type stable (yet) with HamiltonianMC
29-
# TODO: MD, reactivate
3029
@test BAT.MCMCChainState(samplingalg, target, 1, unshaped(v_init, varshape(target)), deepcopy(context)) isa BAT.HMCState
3130
mcmc_state = BAT.MCMCState(samplingalg, target, 1, unshaped(v_init, varshape(target)), deepcopy(context))
3231
nsteps = 10^4

0 commit comments

Comments
 (0)