
Your operation looks like a convolution. If so, computing it with an actual convolution primitive will generally be much faster than computing it with a Python loop or a scan.
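To see why the loop is a convolution, fix one `(i, j)` pair and write `a` for `arr1[i, j, :]`, `b` for `arr2[i, j, :]`, `N` for `nt` and `r` for `result[i, j, :]` (this notation is introduced only for the derivation). Reversing `a` turns each loop entry into one entry of a full discrete convolution:

```latex
\begin{aligned}
r_t &= \sum_{k=0}^{t-1} a_{N-t+k}\, b_k, \qquad 1 \le t \le N-1, \quad r_0 = 0,\\
\tilde a_j &:= a_{N-1-j} \;\Rightarrow\; a_{N-t+k} = \tilde a_{t-1-k},\\
r_t &= \sum_{k=0}^{t-1} \tilde a_{t-1-k}\, b_k = (\tilde a * b)_{t-1}, \qquad (x * y)_n = \sum_k x_{n-k}\, y_k .
\end{aligned}
```

Here `(x * y)_n` is the full convolution, which is what `jnp.convolve(x, y)` returns with its default `mode="full"`. For `n = t - 1 <= N - 2` the only in-range terms are `k = 0, ..., t - 1`, so the limits match the loop exactly.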

Here's an example that computes the same values two ways: with a variant of your NumPy function, and with jax.numpy.convolve:

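import numpy as np
import jax
import jax.numpy as jnp

def f_np(arr1, arr2):
  """NumPy implementation using a loop."""
  assert arr1.shape == arr2.shape
  n, m, nt = arr1.shape
  result = np.zeros(arr1.shape)
  # result[:, :, it] is the dot product of the last `it` entries of arr1
  # with the first `it` entries of arr2; result[:, :, 0] stays zero.
  for it in range(1, nt):
    result[:, :, it] = np.sum(arr1[:, :, -it:] * arr2[:, :, :it], axis=2)
  return result

# Each jax.vmap maps over one leading axis (n, then m), so the body
# below only ever sees 1-D arrays of length nt.
@jax.vmap
@jax.vmap
def f_jax(arr1, arr2):
  """JAX implementation using jnp.convolve."""
  assert arr1.shape == arr2.shape
  nt = arr1.shape[0]
  # Full convolution of the reversed arr1 with arr2: its entry it - 1
  # equals result[it] from the loop above.
  full = jnp.convolve(arr1[::-1], arr2)
  # Shift by one with a leading zero to match result[:, :, 0] == 0.
  return jnp.concatenate([jnp.zeros(1, dtype=full.dtype), full[:nt - 1]])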
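As a quick sanity check of the index mapping, the sketch below compares the loop and the convolution on a single pair of 1-D arrays, which is what each call sees inside the two `jax.vmap` layers. The helper names `shifted_dot_loop` and `shifted_dot_conv` are hypothetical, chosen only for this illustration. For `a = [1, 2, 3]` and `b = [4, 5, 6]` the loop gives `r[1] = 3*4 = 12` and `r[2] = 2*4 + 3*5 = 23`, and the convolution route should agree.

```python
import numpy as np
import jax.numpy as jnp


def shifted_dot_loop(a, b):
    """Hypothetical 1-D reference: r[t] = sum(a[-t:] * b[:t]) for t >= 1, r[0] = 0."""
    nt = a.shape[0]
    r = np.zeros(nt)
    for t in range(1, nt):
        r[t] = np.sum(a[-t:] * b[:t])
    return r


def shifted_dot_conv(a, b):
    """Hypothetical 1-D helper: same values via a full convolution of reversed a with b."""
    nt = a.shape[0]
    full = jnp.convolve(a[::-1], b)  # default mode="full", length 2 * nt - 1
    # Entry t of the result is entry t - 1 of the full convolution; entry 0 is zero.
    return jnp.concatenate([jnp.zeros(1, dtype=full.dtype), full[: nt - 1]])


# Small hand-checkable case.
a_small = np.array([1.0, 2.0, 3.0])
b_small = np.array([4.0, 5.0, 6.0])
print(shifted_dot_loop(a_small, b_small))
print(shifted_dot_conv(a_small, b_small))

# Random case, compared with a float32-friendly tolerance.
rng = np.random.default_rng(0)
a = rng.standard_normal(8).astype(np.float32)
b = rng.standard_normal(8).astype(np.float32)
print(np.allclose(shifted_dot_loop(a, b), shifted_dot_conv(a, b), atol=1e-4))
```

The tolerance is loose because JAX computes in float32 by default, while the NumPy loop accumulates into a float64 array.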

Answer selected by venkatou