
Great question!

Yes, you can switch the abstract mesh inside jit. For example:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import PartitionSpec as P, reshard

def f():
    mesh = jax.make_mesh((2, 2), ('x', 'y'), axis_types=(jax.sharding.AxisType.Explicit,) * 2)
    jax.set_mesh(mesh)
    np_inp = np.arange(16.).reshape(8, 2)
    abstract_mesh = mesh.abstract_mesh

    @jax.jit
    def f(x):
      x = jnp.sin(x)

      # We are switching the abstract mesh here!!
      with jax.sharding.use_abstract_mesh(
          abstract_mesh.update(axis_sizes=(4, 1), axis_names=('a', 'b'))):

        # Once you switch the abstract mesh, you need to reshard the inputs to the new mesh in that context.
        x = reshard(x, P(('a', 'b'), None))

        out = x * 2
        assert out…
```
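The reshard inside the `with` block presumably only works if the switched mesh still spans the same devices (2 × 2 = 4 × 1 = 4 here) and every sharded dimension divides evenly by the product of its mesh axes. The plain-Python sketch below illustrates that per-device shape arithmetic for the (8, 2) input. It is not JAX's API: `shard_shape` and `check_mesh_switch` are hypothetical helpers, and the starting layout `P('x', 'y')` is an assumption, since the snippet is cut off before showing how `np_inp` is passed in.

```python
import math
from typing import Dict, Sequence, Tuple, Union

# One entry per array dimension: None (replicated), one mesh axis name,
# or a tuple of axis names whose sizes are multiplied together.
SpecEntry = Union[None, str, Tuple[str, ...]]


def shard_shape(global_shape: Sequence[int],
                mesh_axes: Dict[str, int],
                spec: Sequence[SpecEntry]) -> Tuple[int, ...]:
  """Hypothetical helper: per-device shape of an array laid out by `spec`.

  `mesh_axes` maps axis name -> axis size, e.g. {'a': 4, 'b': 1}.
  Raises ValueError if a sharded dimension does not divide evenly.
  """
  local = []
  for dim, entry in zip(global_shape, spec):
    if entry is None:
      names = ()
    elif isinstance(entry, str):
      names = (entry,)
    else:
      names = tuple(entry)
    factor = math.prod(mesh_axes[n] for n in names)  # 1 when replicated
    if dim % factor != 0:
      raise ValueError(f"dimension {dim} not divisible by {factor} {names}")
    local.append(dim // factor)
  # Trailing dimensions not covered by `spec` are replicated.
  local.extend(global_shape[len(spec):])
  return tuple(local)


def check_mesh_switch(old_axes: Dict[str, int], new_axes: Dict[str, int]) -> None:
  """Hypothetical check: the switched mesh should span the same device count."""
  old, new = math.prod(old_axes.values()), math.prod(new_axes.values())
  if old != new:
    raise ValueError(f"device count changes from {old} to {new}")


old_mesh = {'x': 2, 'y': 2}  # mirrors jax.make_mesh((2, 2), ('x', 'y'))
new_mesh = {'a': 4, 'b': 1}  # mirrors abstract_mesh.update(axis_sizes=(4, 1), ...)
check_mesh_switch(old_mesh, new_mesh)  # 4 devices before and after, so no error
print(shard_shape((8, 2), old_mesh, ('x', 'y')))          # (4, 1)
print(shard_shape((8, 2), new_mesh, (('a', 'b'), None)))  # (2, 2)
```

Under these assumptions, each device holds a (4, 1) tile before the switch and a (2, 2) tile after `reshard(x, P(('a', 'b'), None))`, because the fused `('a', 'b')` axis splits dimension 0 four ways and dimension 1 stays replicated.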

@Prayer3th

@yashk2810

Answer selected by Prayer3th
Category
Q&A