
FR: batched mapping for named_axes.nmap #71

Open
amifalk opened this issue Jul 26, 2024 · 5 comments
Labels
feature-request New feature or request

Comments


amifalk commented Jul 26, 2024

Sometimes nmap'ed computations don't all fit in memory at once, and there are not enough devices to shard the computation over (this limitation is particularly salient when using Penzai because adding an arbitrary number of named axes is so darn convenient :) )

JAX now supports semantics for batched mapping via jax.lax.map(..., batch_size=...). This would be awesome to add to Penzai!

@danieldjohnson danieldjohnson added the feature-request New feature or request label Jul 31, 2024
@danieldjohnson
Collaborator

Hm, interesting idea!

One question is what the API for this should be. Some ideas:

  • Add a keyword argument to nmap: Something like nmap(fn, batch_sizes={"foo":2, "bar":4}). This would be easy to implement, but a disadvantage is that you'd need to plumb through the batch sizes for each axis into each computation. For instance, if you wanted to add a new batch axis but map over it, you'd need to modify all of the calls to nmap inside the function you are calling, which would be fairly annoying.
  • Bind it to the named array: Named arrays could store a batch size for each of their axes, similar to how JAX arrays store shardings. Then any nmap call could automatically pick up and use the batch size. I think this would be pretty complex to implement, though. It also doesn't compose well with tag/untag, although I guess the rule could just be that untagging an axis resets its batching size?
  • Use a context manager: We could have a context manager that determines how batch sizes for each axis name are determined, and then all nmap calls inside the context manager would read from it. That would make it easier to specify batch sizes on a per-name level. However, it might not work well with JAX tracing, since tracers would have to be aware of the context manager somehow.

@amifalk

amifalk commented Jul 31, 2024

It occurs to me that adding batched mapping directly to nmap may not be the right thing to do here.

I'm thinking about this feature in the context of batching model evaluations on a grid of data inputs.
This is important when serving models with many incoming requests or when performing a grid search over hyperparameters, to give two examples. In this context, I only ever want to batch over the outermost call.

You might imagine that an arbitrary neural network in Penzai looks something like this under the hood:

def eval_nn(x): 
    x = nmap(custom_layer_1)(x)
    x = nmap(custom_layer_2)(x)
    x = nmap(custom_layer_3)(x)
    return x 

When we batch-nmap this over a grid, the batching will be pushed inside each nmap if we use any global configuration applied to the nmap operator (as in the binding or context manager proposal).

e.g:

def eval_nn(x): 
    x = nmap(custom_layer_1)(x)
    x = nmap(custom_layer_2)(x)
    x = nmap(custom_layer_3)(x)
    return x 

inputs = pz.nx.ones({"batch": 1_000, "x": 50, "y": 50})

with pz.nx.batch_nmap(batch_sizes={"batch": 10}):
    eval_nn(inputs)

will evaluate to:

def eval_nn(x): 
    x = nmap(custom_layer_1, batch=10)(x)
    x = nmap(custom_layer_2, batch=10)(x)
    x = nmap(custom_layer_3, batch=10)(x)
    return x 

I made a little micro-benchmark, and piping the map through each layer is ~40% slower than batch-mapping once on the outside, on JAX 0.4.31 with an NVIDIA GeForce RTX 4090, likely due to the transfer overheads between the host and device after each scan.

#%%
import jax
import jax.numpy as jnp
import jax.random as random


def layer(arr):
    return jnp.matmul(arr, arr - jnp.mean(arr))     
    
BATCH_SIZE = 20
    
def map_each_layer(batch_of_arrs):
    batch_of_arrs = jax.lax.map(layer, batch_of_arrs, batch_size=BATCH_SIZE)
    batch_of_arrs = jax.lax.map(layer, batch_of_arrs, batch_size=BATCH_SIZE)
    batch_of_arrs = jax.lax.map(layer, batch_of_arrs, batch_size=BATCH_SIZE)
    batch_of_arrs = jax.lax.map(layer, batch_of_arrs, batch_size=BATCH_SIZE)
    batch_of_arrs = jax.lax.map(layer, batch_of_arrs, batch_size=BATCH_SIZE)

    return batch_of_arrs

def all_layers(arr):
    arr = layer(arr)
    arr = layer(arr)
    arr = layer(arr)
    arr = layer(arr)
    arr = layer(arr)    
    return arr    
    
def map_all_layers(batch_of_arrs):
    return jax.lax.map(all_layers, batch_of_arrs, batch_size=BATCH_SIZE)


map_each_layer_jit = jax.jit(map_each_layer)
map_all_layers_jit = jax.jit(map_all_layers)

batch_of_arrs = random.normal(random.PRNGKey(0), (500, 100, 100))
#%%
%time map_each_layer_jit(batch_of_arrs).block_until_ready() 
%time map_all_layers_jit(batch_of_arrs).block_until_ready() 
#%%
%timeit map_each_layer_jit(batch_of_arrs).block_until_ready() 
%timeit map_all_layers_jit(batch_of_arrs).block_until_ready() 
# %%
143 ms ± 72.2 µs per loop (mean ± std. dev. of 7 runs, 10 loops each) # internal piping of scan
105 ms ± 68.4 µs per loop (mean ± std. dev. of 7 runs, 10 loops each) # one map on the outside

Instead, maybe we should have some function pz.nx.batch (this could also be a decorator) that batch-nmaps only over the specified axes and does not pipe the mapping into the internal functions.

Like lax.map, this would just be syntactic sugar over lax.scan. Here's some rough pseudo-code:

def batch(fn, batch_sizes=None):
    if batch_sizes is None:
        batch_sizes = {}
    batch_axes = tuple(batch_sizes.keys())
    axis_sizes = tuple(batch_sizes.values())

    def wrapped(xs):
        # split each batch axis into (num_chunks,) + chunk sizes
        xs = xs.untag(*batch_axes).reshape((-1,) + axis_sizes)
        # scan over the chunks, then stack the results back together
        return pz.nx.stack(pz.nx.scan(fn, xs), axes=batch_axes)

    return wrapped

# usage: 
inputs = pz.nx.ones({"batch": 1_000, "x": 50, "y": 50})
pz.nx.batch(eval_nn, batch_sizes={"batch": 10})(inputs)
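In plain JAX (without named axes), the outer-batching behavior the pseudocode aims for can be sketched as a reshape into chunks, one lax.map over the chunks with the function vmapped inside each chunk, and a reshape back. This is a hedged illustration, not Penzai API; `batch_outer` and `eval_nn` here are made-up names, and it assumes the leading axis size divides evenly by `batch_size` (no remainder handling):

```python
import jax
import jax.numpy as jnp

def batch_outer(fn, batch_size):
    """Apply `fn` over the leading axis in sequential chunks of `batch_size`.

    A plain-JAX sketch of the proposal: the batching stays on the outside,
    so code inside `fn` never needs to know about the chunking.
    Assumes the leading axis size is divisible by `batch_size`.
    """
    def wrapped(xs):
        n = xs.shape[0]
        # (n, ...) -> (n // batch_size, batch_size, ...)
        chunked = xs.reshape((n // batch_size, batch_size) + xs.shape[1:])
        # One sequential map over chunks; fn is vmapped within each chunk.
        out = jax.lax.map(jax.vmap(fn), chunked)
        # (num_chunks, batch_size, ...) -> (n, ...)
        return out.reshape((n,) + out.shape[2:])
    return wrapped

def eval_nn(x):  # stand-in for the nmap'ed network above
    return x * 2.0 + 1.0

inputs = jnp.ones((1000, 50, 50))
outputs = batch_outer(eval_nn, batch_size=10)(inputs)
print(outputs.shape)  # (1000, 50, 50)
```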

What do you think?

@amifalk

amifalk commented Aug 4, 2024

I wrote up a prototype of the functionality described above here: https://gist.github.com/amifalk/e21059da7f0c0ecb3db8240604413998

I realized there's no benefit to allowing different batch sizes for different axes given that they're all evaluated independently. In the worst case, the remainder named array won't have the same shape as the batch array, so it will not be possible to broadcast them back together.

@danieldjohnson

I agree that having a single scan at the outside seems better than having a number of smaller scans in the inside for this use case. Thanks for running the benchmark, and for taking a stab at the implementation!

I wonder if the best solution here would be to aim to match the semantics of jax.lax.map with a named-axis version, similar to the relationship between jax.lax.scan and pz.nx.scan. If so, this would suggest that:

  • the mapped-over axis should be a single axis (to keep it simple, similar to the APIs for jax.lax.map and pz.nx.scan)
  • however, unlike jax.lax.map, the mapped-over axis should be a named axis (like pz.nx.scan)
  • the batch_size option should be optional (similar to jax.lax.map), and default to 1
  • the axis being mapped over should be removed from the inputs, and added back to the outputs, instead of still appearing inside with a smaller size and requiring f to keep it as-is
  • ideally, it should be OK to have positional axes, and they should just be passed through transparently
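Putting those bullets together, a hypothetical signature could look like the following (pseudocode only; pz.nx.map does not exist at the time of writing, and the names and defaults are illustrative):

```
# Hypothetical API sketch, mirroring jax.lax.map with a named axis:
#
# out = pz.nx.map(
#     f,                  # sees its inputs WITHOUT the mapped axis
#     xs,                 # NamedArray; positional axes pass through untouched
#     axis="batch",       # a single named axis to map over
#     batch_size=1,       # optional, like jax.lax.map; defaults to 1
# )
#
# The "batch" axis is removed from f's inputs and re-added to the outputs,
# so f never has to thread a smaller batch axis through its body.
```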

One question is what the name of this should be:

  • pz.nx.map would fit with pz.nx.scan, but might get confused with pz.nx.nmap?
  • pz.nx.serial_map is more explicit but more verbose.

What do you think?

(This makes me think of a related question: would it be useful to provide pz.nx.vmap which is like jax.vmap but maps over a single named axis in parallel, keeping all the other named axes? This would pull the batching out instead of pushing it inward to the inner functions. I'm not sure if there's much of a reason to have this, though. Perhaps avoiding name conflicts?)

@amifalk

amifalk commented Aug 8, 2024

Agreed w.r.t. all points emulating jax.lax.map. My gut says the name pz.nx.map is fine given that it will have a different function signature than nmap, but I don't have a strong feeling between that and serial_map.

On pz.nx.vmap, I would guess the compiled XLA wouldn't change, and it strikes me as a bit of an anti-pattern that would make functions more brittle and might confuse new users (once you nmap, you never go back!).
