Running pyright on code adapted from the HMC docs reports a typing error on the `init` call.
import blackjax
from jax import numpy as jnp

def f(x, key):
    return 0

hmc = blackjax.hmc(
    f, 1.0, jnp.array(1.0), 1
)
x = hmc.init(jnp.array(1.0))
yields
- error: Argument missing for parameter "rng_key" (reportCallIssue)
1 error, 0 warnings, 0 informations
This is based on the code from https://blackjax-devs.github.io/blackjax/autoapi/blackjax/mcmc/hmc/index.html#blackjax.mcmc.hmc.init.
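One possible explanation, offered as a guess rather than a confirmed diagnosis: the `init` attribute may be typed through a callable protocol whose `rng_key` parameter is annotated as optional but has no default value. The sketch below reproduces that pattern with the standard library only. `InitLike`, `Sampler`, and `init_impl` are hypothetical names made up for illustration and are not blackjax's actual definitions.

```python
from typing import Any, NamedTuple, Optional, Protocol


class InitLike(Protocol):
    # Hypothetical protocol: rng_key is Optional but has no default value,
    # so a static checker treats it as a required argument.
    def __call__(self, position: Any, rng_key: Optional[Any]) -> Any: ...


class Sampler(NamedTuple):
    # Hypothetical container standing in for the object returned by blackjax.hmc.
    init: InitLike


def init_impl(position: Any, rng_key: Optional[Any] = None) -> Any:
    # The concrete function does provide a default, so the call works at runtime.
    return {"position": position}


algo = Sampler(init=init_impl)

# Runs fine, but a type checker validates this call against InitLike,
# not init_impl, and would be expected to report rng_key as missing.
state = algo.init(1.0)

# Passing the argument explicitly satisfies the declared signature.
state_explicit = algo.init(1.0, None)
```

If this is indeed the cause, passing `None` explicitly as the second argument would be a workaround on the caller's side, while giving the protocol parameter a default would address it in the type annotations.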