We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent a3e235f commit 83bc3a0Copy full SHA for 83bc3a0
blackjax/mcmc/integrators.py
@@ -374,9 +374,8 @@ def format_isokinetic_state_output(
374
375
def generate_isokinetic_integrator(coefficients):
376
def isokinetic_integrator(
377
- logdensity_fn: Callable, *args, **kwargs
+ logdensity_fn: Callable, sqrt_diag_cov: ArrayTree = 1.0
378
) -> GeneralIntegrator:
379
- sqrt_diag_cov = kwargs.get("sqrt_diag_cov", 1.0)
380
position_update_fn = euclidean_position_update_fn(logdensity_fn)
381
one_step = generalized_two_stage_integrator(
382
esh_dynamics_momentum_update_one_step(sqrt_diag_cov),
0 commit comments