Using `reduce_window` with a `jnp.float32(0.0)` init value:

```python
import jax
import jax.numpy as jnp
from jax import jit


def maxpool_step(x, kernel_size, stride):
    return jax.lax.reduce_window(
        x,
        jnp.float32(0.0),
        jax.lax.max,
        window_dimensions=(1, 1, kernel_size, kernel_size),
        window_strides=(1, 1, stride, stride),
        padding="VALID",
    )


_grad = jax.grad(lambda x: jnp.sum(maxpool_step(x, kernel_size=2, stride=2)))
_grad_jit = jit(_grad)  # => ValueError: Linearization failed ...
```

Interestingly, removing the jit compilation from the above makes the error go away, even though the function is the same. Below is the working version, in which both `grad` and `jit(grad)` work:

```python
def maxpool_step(x, kernel_size, stride):
    return jax.lax.reduce_window(
        x,
        -jnp.inf,  # 0.0, jnp.float32(0.0), jnp.array(0.0) or jnp.array(-jnp.inf) all cause errors
        jax.lax.max,
        window_dimensions=(1, 1, kernel_size, kernel_size),
        window_strides=(1, 1, stride, stride),
        padding="VALID",
    )
```

My questions are:
There are a few things going on here.
1. `init_value = 0` should always fail. The gradient of `jax.lax.reduce_window` is only defined for `jax.lax.max` when `init_value` is `-inf`. If you want to use a different init value, you need to define the gradient yourself (a sketch of this is at the end of this reply).
2. When you jit-compile the function using JAX array types for `init_value` (`jnp.float32(...)` and `jnp.array(...)`), the arrays are traced. This breaks the logic of `jax.lax.reduce_window`.
3. The gradient computation of `jax.lax.reduce_window` has static branching logic that depends on the actual scalar value of `init_value`. Since values in traced arrays are not accessible, the code fails when you call `jit` on top of `grad`.
4. `jnp.inf` is an ordinary Python float, so `-jnp.inf` is a concrete scalar whose value is available at trace time, which is why the `-jnp.inf` version works even under `jit`.
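For reference, here is a quick self-contained check (the input shape and values are chosen purely for illustration) showing that the `-jnp.inf` version from the question both differentiates and jit-compiles:

```python
import jax
import jax.numpy as jnp


def maxpool_step(x, kernel_size, stride):
    # -jnp.inf is a plain Python float, so its value is visible at trace time.
    return jax.lax.reduce_window(
        x,
        -jnp.inf,
        jax.lax.max,
        window_dimensions=(1, 1, kernel_size, kernel_size),
        window_strides=(1, 1, stride, stride),
        padding="VALID",
    )


x = jnp.arange(16.0).reshape(1, 1, 4, 4)
grad_fn = jax.jit(jax.grad(lambda x: jnp.sum(maxpool_step(x, kernel_size=2, stride=2))))
print(grad_fn(x))  # 1.0 at each window's maximum, 0.0 elsewhere
```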
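If you really do need a non-standard init value, one option is to wrap the pooling in `jax.custom_vjp` and write the backward pass by hand. The sketch below is only illustrative: the name `maxpool_init_zero`, the `jnp.kron` upsampling trick, and the tie handling are illustrative choices, it only covers non-overlapping windows (`stride == kernel_size` with evenly divisible spatial dims, matching the 2x2 / stride-2 pooling in the question), and it assumes non-negative inputs so that an init value of `0.0` still computes a true max-pool.

```python
import functools

import jax
import jax.numpy as jnp


@functools.partial(jax.custom_vjp, nondiff_argnums=(1, 2))
def maxpool_init_zero(x, kernel_size, stride):
    # Forward pass with init_value = 0.0; JAX defines no gradient rule for this.
    return jax.lax.reduce_window(
        x,
        0.0,
        jax.lax.max,
        window_dimensions=(1, 1, kernel_size, kernel_size),
        window_strides=(1, 1, stride, stride),
        padding="VALID",
    )


def _maxpool_fwd(x, kernel_size, stride):
    y = maxpool_init_zero(x, kernel_size, stride)
    return y, (x, y)  # save input and pooled output as residuals


def _maxpool_bwd(kernel_size, stride, res, g):
    x, y = res
    # Hand-written backward pass: broadcast each pooled value and its cotangent
    # back over its (non-overlapping) window, then keep the gradient only at
    # positions that achieved the window maximum (ties share the gradient here).
    tile = jnp.ones((1, 1, kernel_size, kernel_size), x.dtype)
    y_up = jnp.kron(y, tile)
    g_up = jnp.kron(g, tile)
    return ((x == y_up).astype(x.dtype) * g_up,)


maxpool_init_zero.defvjp(_maxpool_fwd, _maxpool_bwd)

# jit(grad(...)) now works because JAX never differentiates reduce_window itself.
x = jnp.arange(16.0).reshape(1, 1, 4, 4)
print(jax.jit(jax.grad(lambda x: jnp.sum(maxpool_init_zero(x, 2, 2))))(x))
```

The same pattern should extend to other init values or reductions; only the forward call and the hand-written backward rule need to change.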