Python dictionary lookups in the context of JAX: deeper reasons why array-based code is faster than dictionary-based code #33402
Outside of JAX, it is well known that passing parameters in a dictionary can be slower than simply passing them in a single vector (array). Still, when function inputs (parameters) have names (e.g., in codes for physics or other natural sciences), it is often handy to take a dictionary as the input. That way, the code becomes more readable and easier to keep track of, for example:

Recently I rewrote my entire code to be purely array-based (using `jax.numpy`), without any dictionaries, and this drastically improved its run time. Is there a reason, specific to JAX (gradients, tracing, JIT, etc.), why my code runs so much faster with plain arrays and no dictionaries? What can be said about dictionary lookups in the context of JAX?

I decided to test whether the dictionary lookup was the biggest slowdown, so I wrote a deliberately artificial function to stress-test this:

Observe that the two functions differ in that the dictionary version first unpacks the dictionary into an array (in other words, the dictionary lookup occurs once per call). After running each of them 1000 times (both were jitted and warmed up before the benchmark), I found the following running times:

Fantastic! But now I'm puzzled about why this is the case.
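Part of the question ("what can be said about dictionary lookups in JAX?") can be checked directly. Under `jax.jit`, a dictionary argument is treated as a pytree: the Python-level lookups such as `params["a"]` run only while the function is being traced, and the compiled program receives one input buffer per dictionary value. The minimal sketch below assumes hypothetical parameter names `a`, `b`, `c` and a hypothetical `trace_count` counter, used only to show how often the Python body actually executes and how many leaves JAX sees.

```python
import jax
import jax.numpy as jnp

# Hypothetical named parameters, for illustration only.
params = {"a": jnp.float32(1.0), "b": jnp.float32(2.0), "c": jnp.float32(3.0)}

trace_count = 0  # hypothetical counter of Python executions of the body

@jax.jit
def model(params, x):
    global trace_count
    trace_count += 1  # Python side effect: only runs while tracing
    # These dict lookups are resolved at trace time; the compiled
    # program itself just receives one buffer per dictionary value.
    return params["a"] * x**2 + params["b"] * x + params["c"]

x = jnp.float32(2.0)
for _ in range(5):
    y = model(params, x).block_until_ready()

n_leaves = len(jax.tree_util.tree_leaves(params))
print(trace_count, n_leaves, float(y))  # 1 3 11.0
```

So the lookup itself is not repeated per call once the function is compiled; what does scale with the dictionary size is the number of separate input buffers.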
Can you share your benchmarking code? (In particular, I'm not sure what …)
The operative difference here is that in the dictionary-based code, you're constructing 64 array buffers on the device, then concatenating them with `jnp.array([parameters[k] for k in keys])`. In the array-based code, you're constructing a single array on the device. This relative lack of data movement makes the array version of the code faster.

Here's a simpler demonstration of the same thing, where we avoid the dict question and just compare passing 64 individual values vs. those same 64 values in an array:

```python
import jax
import jax.numpy as jnp

@jax.jit
def f1(x_array):
    # One argument: a single array buffer.
    return x_array * 2

@jax.jit
def f2(x_list):
    # A list is a pytree: each of the 64 elements is a separate argument.
    return jnp.array(x_list) * 2

x_list = list(range(64))
x_array = jnp.array(x_list)

_ = f1(x_array).block_until_ready()  # compile / warm up
%timeit f1(x_array).block_until_ready()  # IPython magic
# 12.6 µs ± 3.08 µs per loop (mean ± std. dev. of 7 runs, 100000 loops each)

_ = f2(x_list).block_until_ready()  # compile / warm up
%timeit f2(x_list).block_until_ready()  # IPython magic
# 284 µs ± 18 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
```
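If named parameters are wanted for readability, one possible compromise (a suggestion, not something stated in the answer above) is to flatten the dictionary into a single array once, outside the hot loop, with `jax.flatten_util.ravel_pytree`, and pass that one array into the jitted function. The returned `unravel` function rebuilds the dictionary inside the function, where the name-based lookups are again resolved at trace time. The parameter names and the `energy` function below are hypothetical, for illustration only.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Hypothetical named parameters, for illustration only.
params = {
    "mass": jnp.float32(1.5),
    "charge": jnp.float32(-1.0),
    "spring_k": jnp.float32(4.0),
}

# Flatten once: `flat` is a single 1-D array (one device buffer);
# `unravel` maps an array of that shape back to the original dict.
flat, unravel = ravel_pytree(params)

@jax.jit
def energy(flat_params, x):
    p = unravel(flat_params)  # name-based access, resolved at trace time
    return 0.5 * p["spring_k"] * x**2 + p["charge"] * x / p["mass"]

x = jnp.float32(2.0)
e = energy(flat, x)

# The gradient is also a single flat array; unravel it for named access.
g = unravel(jax.grad(energy)(flat, x))
print(flat.shape, float(e), float(g["spring_k"]))
```

The design choice here is that the jitted function sees exactly one array argument (as in `f1` above), while the code that reads parameters still uses their names.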