
There are a few things going on here.

  1. init_value = 0 should always fail. The gradient of jax.lax.reduce_window is only defined for jax.lax.max when init_value is -inf. If you want to use a different init value, you need to define the gradient yourself (see the first sketch after this list).

  2. When you jit-compile the function with JAX array types for init_value (jnp.float32(...) and jnp.array(...)), those arrays are traced. This breaks the logic of jax.lax.reduce_window.

The gradient computation of jax.lax.reduce_window has static branching logic that depends on the actual scalar value of init_value. Since the values in traced arrays are not accessible, the code fails when you call jit on top of grad (see the second sketch after this list).

  3. jnp.inf is an …
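
Here is a minimal sketch of point 1 (the function name `max_pool_sum`, the 1-D input, and the window shape are made up for illustration): with a concrete init_value of -inf, `jax.lax.reduce_window` with `jax.lax.max` differentiates fine, while any other init value would require attaching your own gradient (e.g. via `jax.custom_vjp`).

```python
import jax
import jax.numpy as jnp
from jax import lax

def max_pool_sum(x):
    # Window-wise max over a 1-D array; -inf is the identity element for max,
    # which is the only init_value the built-in gradient rule handles.
    pooled = lax.reduce_window(
        x,
        -jnp.inf,                     # concrete scalar, not a traced array
        lax.max,
        window_dimensions=(2,),
        window_strides=(2,),
        padding="VALID",
    )
    return pooled.sum()

x = jnp.arange(8, dtype=jnp.float32)
print(jax.jit(jax.grad(max_pool_sum))(x))  # works

# Replacing -jnp.inf with 0.0 makes jax.grad fail, because the generic
# reduce_window primitive has no differentiation rule; in that case you
# would define the gradient yourself, e.g. with jax.custom_vjp.
```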
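
And a minimal sketch of point 2 (same made-up `max_pool_sum`, with init_value now passed as an argument so that jit traces it): once init_value is a tracer, its scalar value cannot be inspected at trace time, so the static check for the differentiable max path fails.

```python
import jax
import jax.numpy as jnp
from jax import lax

def max_pool_sum(x, init_value):
    pooled = lax.reduce_window(
        x, init_value, lax.max,
        window_dimensions=(2,), window_strides=(2,), padding="VALID",
    )
    return pooled.sum()

x = jnp.arange(8, dtype=jnp.float32)

# Works: init_value stays a concrete Python scalar at trace time.
print(jax.jit(jax.grad(max_pool_sum), static_argnums=1)(x, -float("inf")))

# Fails: init_value is traced inside jit, so its value is unavailable to the
# static branching logic and the gradient of generic reduce_window is undefined.
# jax.jit(jax.grad(max_pool_sum))(x, jnp.float32(-jnp.inf))
```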
