FR: batched mapping for named_axes.nmap
#71
Hm, interesting idea! One question is what the API for this should be. Some ideas:
It occurs to me that adding batched mapping directly to `nmap` may not be the right place for it. I'm thinking about this feature in the context of batching model evaluations on a grid of data inputs. You might imagine that an arbitrary neural network in Penzai looks something like this under the hood:

```python
def eval_nn(x):
    x = nmap(custom_layer_1)(x)
    x = nmap(custom_layer_2)(x)
    x = nmap(custom_layer_3)(x)
    return x
```

When we batch-nmap this over a grid, the batching will be pushed inside each `nmap` call. E.g.:

```python
def eval_nn(x):
    x = nmap(custom_layer_1)(x)
    x = nmap(custom_layer_2)(x)
    x = nmap(custom_layer_3)(x)
    return x

inputs = pz.nx.ones({"batch": 1_000, "x": 50, "y": 50})

with pz.nx.batch_nmap(batch_sizes={"batch": 10}):
    eval_nn(inputs)
```

will evaluate to:

```python
def eval_nn(x):
    x = nmap(custom_layer_1, batch=10)(x)
    x = nmap(custom_layer_2, batch=10)(x)
    x = nmap(custom_layer_3, batch=10)(x)
    return x
```

I made a little micro-benchmark, and piping the batched map through each layer is ~40% slower than batch-mapping on the outside (on JAX 0.4.31 with an NVIDIA GeForce RTX 4090), likely due to the transfer overheads between the host and device after each scan.
```python
# %%
import jax
import jax.numpy as jnp
import jax.random as random


def layer(arr):
    return jnp.matmul(arr, arr - jnp.mean(arr))


BATCH_SIZE = 20


def map_each_layer(batch_of_arrs):
    # Batch each layer separately: one lax.map (scan over chunks) per layer.
    batch_of_arrs = jax.lax.map(layer, batch_of_arrs, batch_size=BATCH_SIZE)
    batch_of_arrs = jax.lax.map(layer, batch_of_arrs, batch_size=BATCH_SIZE)
    batch_of_arrs = jax.lax.map(layer, batch_of_arrs, batch_size=BATCH_SIZE)
    batch_of_arrs = jax.lax.map(layer, batch_of_arrs, batch_size=BATCH_SIZE)
    batch_of_arrs = jax.lax.map(layer, batch_of_arrs, batch_size=BATCH_SIZE)
    return batch_of_arrs


def all_layers(arr):
    arr = layer(arr)
    arr = layer(arr)
    arr = layer(arr)
    arr = layer(arr)
    arr = layer(arr)
    return arr


def map_all_layers(batch_of_arrs):
    # Batch the whole stack of layers with a single lax.map on the outside.
    return jax.lax.map(all_layers, batch_of_arrs, batch_size=BATCH_SIZE)


map_each_layer_jit = jax.jit(map_each_layer)
map_all_layers_jit = jax.jit(map_all_layers)

batch_of_arrs = random.normal(random.PRNGKey(0), (500, 100, 100))

# %%
%time map_each_layer_jit(batch_of_arrs).block_until_ready()
%time map_all_layers_jit(batch_of_arrs).block_until_ready()

# %%
%timeit map_each_layer_jit(batch_of_arrs).block_until_ready()
%timeit map_all_layers_jit(batch_of_arrs).block_until_ready()
# %%
```
Instead, maybe we should have some standalone function for this, along the lines of `pz.nx.batch`:

```python
# Rough sketch of the proposed API:
def batch(fn, xs, batch_sizes=None):
    if batch_sizes is None:
        batch_sizes = {}
    batch_axes, axis_sizes = tuple(batch_sizes.keys()), tuple(batch_sizes.values())
    xs = xs.untag(*batch_axes).reshape((-1,) + axis_sizes)
    return pz.nx.stack(pz.nx.scan(fn, axes=batch_axes), axes=batch_axes)


# usage:
inputs = pz.nx.ones({"batch": 1_000, "x": 50, "y": 50})
pz.nx.batch(eval_nn, inputs, batch_sizes={"batch": 10})
```

What do you think?
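For reference, here is a rough positional-JAX sketch of what batching the whole function on the outside corresponds to. This is my own illustration rather than part of the proposal; `eval_nn_positional` is a hypothetical plain-array stand-in for `eval_nn`:

```python
import jax
import jax.numpy as jnp


def layer(x):
    return x @ (x - jnp.mean(x))


def eval_nn_positional(x):
    # Hypothetical positional version of eval_nn: a few layers applied in sequence.
    for _ in range(3):
        x = layer(x)
    return x


inputs = jnp.ones((1000, 50, 50))
# Scan over chunks of 10 grid points, vmapping the *whole* network within each
# chunk, instead of pushing a separate batched scan into every layer.
outputs = jax.lax.map(eval_nn_positional, inputs, batch_size=10)
```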
I wrote up a prototype of the functionality described above here: https://gist.github.com/amifalk/e21059da7f0c0ecb3db8240604413998. I realized there's no benefit to allowing different batch sizes for different axes, given that they're all evaluated independently. In the worst case, the remainder named array won't have the same shape as the batched array, so it will not be possible to broadcast them back together.
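To make the remainder issue concrete, here is a small illustration with made-up axis and batch sizes (the numbers are just for exposition and are not taken from the gist):

```python
# Hypothetical grid: two named axes with per-axis batch sizes.
axis_sizes = {"a": 7, "b": 5}
batch_sizes = {"a": 2, "b": 3}

# The fully-batched region only covers the largest sub-grid divisible by the
# batch sizes: 6 x 3 of the 7 x 5 grid.
full = {k: (axis_sizes[k] // batch_sizes[k]) * batch_sizes[k] for k in axis_sizes}
print(full)  # {'a': 6, 'b': 3}

# The leftover grid points form an L-shaped region (a 6 x 2 strip plus a
# 1 x 5 strip, 17 points total), which has no single rectangular named shape,
# so there is no "remainder" array that can be stacked back with the main result.
leftover = axis_sizes["a"] * axis_sizes["b"] - full["a"] * full["b"]
print(leftover)  # 17
```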
I agree that having a single batch size is the right call. I wonder if the best solution here would be to aim to match the semantics of `jax.lax.map`.

One question is what the name of this should be.

What do you think? (This makes me think of a related question: would it be useful to provide …)
Agreed w.r.t. all points about emulating `jax.lax.map`.
Sometimes `nmap`'ed computations don't all fit in memory at once, and there are not enough devices to shard the computation over (this limitation is particularly salient when using Penzai, because adding an arbitrary number of named axes is so darn convenient :) ). JAX now supports semantics for batched vmapping with `jax.lax.map`. This would be awesome to add to Penzai!
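For context, a minimal sketch of the `jax.lax.map` behavior being referred to (my own example, with arbitrary shapes): when `batch_size` is passed, `lax.map` vmaps the function over chunks of that size and scans over the chunks, so peak memory scales with the chunk size rather than the full leading axis.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Per-example computation with a largish intermediate.
    return jnp.sum(x @ x.T)


xs = jnp.ones((1000, 50, 50))

all_at_once = jax.vmap(f)(xs)                  # materializes all 1000 intermediates
in_chunks = jax.lax.map(f, xs, batch_size=10)  # 10 at a time: vmap within a chunk, scan over chunks

assert jnp.allclose(all_at_once, in_chunks)
```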