The operative difference here is that in the dictionary-based code, you're constructing 64 array buffers on the device, then concatenating them with jnp.array([parameters[k] for k in keys]). In the array-based code, you're constructing a single array on the device. This relative lack of data movement makes the array version of the code faster.
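As a rough sketch of the two call patterns described above: the original code is not shown here, so the 64 key names, the values, and the `sum` reduction are assumptions chosen purely for illustration. Only `parameters`, `keys`, and the `jnp.array([parameters[k] for k in keys])` expression come from the discussion.

```python
import jax
import jax.numpy as jnp

# Hypothetical parameter names and values, for illustration only.
keys = [f"p{i}" for i in range(64)]
parameters = {k: float(i) for i, k in enumerate(keys)}

@jax.jit
def from_dict(parameters):
    # Each dict value arrives as its own scalar buffer; they are
    # concatenated into one array inside the jitted function.
    values = jnp.array([parameters[k] for k in keys])
    return values.sum()

@jax.jit
def from_array(values):
    # The input is already a single device array.
    return values.sum()

# Stack the values once, outside the jitted function.
values = jnp.array([parameters[k] for k in keys])

result_dict = from_dict(parameters)
result_array = from_array(values)
```

Both functions compute the same result; they differ only in how many separate inputs have to be handed to the compiled function on each call.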

Here's a simpler demonstration of the same thing, where we avoid the dict question and just compare passing 64 individual values vs. those same 64 values in an array:

import jax
import jax.numpy as jnp

@jax.jit
def f1(x_array):
  return x_array * 2

@jax.jit
def f2(x_list):
  return jnp.array(x_list) * 2

x_list = list(range(64))
x_array = jnp.array(x_list)

_ = f1(x_array)  # warm-up call; triggers JIT compilation
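One way to finish the comparison is to time both functions after a warm-up call. This is a minimal sketch, not the original benchmark: the `time_call` helper and the repeat count are assumptions, and absolute numbers will vary by backend and machine.

```python
import timeit

import jax
import jax.numpy as jnp

@jax.jit
def f1(x_array):
    return x_array * 2

@jax.jit
def f2(x_list):
    return jnp.array(x_list) * 2

x_list = list(range(64))
x_array = jnp.array(x_list)

# Warm-up calls so that compilation time is excluded from the timings.
f1(x_array).block_until_ready()
f2(x_list).block_until_ready()

def time_call(fn, arg, number=200):
    # Hypothetical helper: mean seconds per call, waiting for the
    # asynchronously dispatched result to be ready each time.
    total = timeit.timeit(lambda: fn(arg).block_until_ready(), number=number)
    return total / number

t_array = time_call(f1, x_array)
t_list = time_call(f2, x_list)
print(f"single array input:   {t_array * 1e6:.1f} us per call")
print(f"64 individual inputs: {t_list * 1e6:.1f} us per call")
```

The `block_until_ready()` calls matter because JAX dispatches work asynchronously; without them, the timings would measure only dispatch overhead rather than the full call.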

Answer selected by fernando-garcia-cortez