Hello,

A little update on the matter:
Case 2 is actually solved by compiling through jit, which will parallelize the work as long as you configure JAX right before importing it:

import os
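PUcount = 16  # number of virtual CPU devices (processing units) to expose
# set this before importing jax so that XLA picks it up
os.environ["XLA_FLAGS"] = f'--xla_force_host_platform_device_count={PUcount}'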
import jax, jax.numpy as jnp, jax.lax as lax
from jax.tree_util import Partial

from jax.sharding import Mesh, PartitionSpec, NamedSharding, Sharding
# the mesh describes how the devices are laid out along the tensor axis processed in parallel
axsn = ('rig',)  # name of the single mesh axis
maxDev = len(jax.devices())
mesh = Mesh(jax.devices(),
            axis_names=axsn,
            axis_types=(jax.sharding.AxisType.Auto,),
            )
jax.sharding.set_mesh(mesh) # without this, no implicit p…
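For anyone who wants to check the jit behaviour on a JAX version without AxisType or set_mesh, here is a minimal sketch that relies only on explicit input sharding. It is not the setup above: the device count of 8 is an arbitrary choice, and the function per_row_norm is a hypothetical stand-in for the real workload. The input is placed with NamedSharding along the 'rig' axis, and jit propagates that sharding through the computation, so each device handles its own block of rows.

```python
import os

# Assumption: 8 virtual CPU devices (arbitrary). Set before importing jax.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = jax.devices()
n_dev = len(devices)

# 1-D mesh with a single axis named 'rig', mirroring the snippet above.
mesh = Mesh(np.array(devices), axis_names=("rig",))

# Split the leading (row) axis of the input across the 'rig' mesh axis.
row_sharding = NamedSharding(mesh, P("rig"))

x = jnp.arange(n_dev * 4 * 3, dtype=jnp.float32).reshape(n_dev * 4, 3)
x = jax.device_put(x, row_sharding)

@jax.jit
def per_row_norm(a):
    # Hypothetical workload: elementwise ops plus a reduction along the
    # unsharded axis, so each device can process its rows independently.
    return jnp.sqrt(jnp.sum(a * a, axis=1))

y = per_row_norm(x)

# Inspect how jit laid out the result across devices.
print(y.sharding)
```

Printing x.sharding and y.sharding (or calling jax.debug.visualize_array_sharding(y)) is a quick way to confirm that the output stays split across the devices instead of being gathered onto one.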

Answer selected by DiagRisker