I need some help writing the following piece of code in a JAX-friendly way. My NumPy version works, but my JAX attempts run into non-static size issues (the arrays were converted to `jnp`). Any guidance on this is much appreciated. Thank you!
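A minimal sketch of the kind of failure I mean (illustrative code, not my exact attempt): slice bounds in JAX must be static at trace time, so a dynamic-length slice like `arr[-it:]` fails under `jit`.

```python
import jax
import jax.numpy as jnp

# Illustrative sketch: a slice whose length depends on a traced value
# cannot be compiled, because JAX requires static slice bounds.
def attempt(arr, it):
    return jnp.sum(arr[-it:])  # slice length depends on `it`

try:
    jax.jit(attempt)(jnp.arange(10.0), 3)
except Exception as e:
    print("fails at trace time:", type(e).__name__)
```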
Replies: 1 comment 1 reply
Your operation looks like a convolution, in which case it will be much more performant to compute it via an actual convolution primitive than via a loop or scan. Here's an example of how to compute the same values using a variant of your NumPy function and with `jax.numpy.convolve`:

```python
import numpy as np
import jax
import jax.numpy as jnp

def f_np(arr1, arr2):
    """NumPy implementation using a loop."""
    assert arr1.shape == arr2.shape
    n, m, nt = arr1.shape
    result = np.zeros(arr1.shape)
    for it in range(1, nt):
        result[:, :, it] = np.sum(arr1[:, :, -it:] * arr2[:, :, :it], axis=2)
    return result

@jax.vmap
@jax.vmap
def f_jax(arr1, arr2):
    """JAX implementation using jnp.convolve."""
    assert arr1.shape == arr2.shape
    nt, = arr1.shape
    conv_result = jnp.convolve(arr1[::-1], arr2)
    return jnp.zeros(nt).at[1:].set(conv_result[:nt - 1])

nt = 50
arr1 = np.random.rand(10, 500, nt)
arr2 = np.random.rand(10, 500, nt)

numpy_result = f_np(arr1, arr2)
jax_result = f_jax(arr1, arr2)

# rtol=1E-6 because JAX uses float32 by default (note I had to change
# the tolerance accordingly).
np.testing.assert_allclose(numpy_result, jax_result, rtol=1E-6)
```
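If this kernel is called repeatedly, it can also be wrapped in `jax.jit` so later calls reuse the compiled program. A minimal sketch with smaller illustrative shapes (`f_jax_jitted` is a name of my choosing, not from the code above):

```python
import numpy as np
import jax
import jax.numpy as jnp

# Same convolution-based kernel as above, additionally jit-compiled.
@jax.jit
@jax.vmap
@jax.vmap
def f_jax_jitted(arr1, arr2):
    nt, = arr1.shape
    conv_result = jnp.convolve(arr1[::-1], arr2)
    return jnp.zeros(nt).at[1:].set(conv_result[:nt - 1])

rng = np.random.default_rng(0)
a = jnp.asarray(rng.random((4, 8, 16)))
b = jnp.asarray(rng.random((4, 8, 16)))

out = f_jax_jitted(a, b)
print(out.shape)  # (4, 8, 16)
```

This works because every shape inside the function is static: `nt` comes from the traced array's (fixed) shape, not from a runtime value.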