Parallel setup / Auto Sharding mode for general tensor operation? #32494
Hello, this is an attempt to put together a simple working guide on parallel operations:

```python
import multiprocessing, os

PUcount = multiprocessing.cpu_count()  # not trustworthy on servers
os.environ["XLA_FLAGS"] = f'--xla_force_host_platform_device_count={PUcount}'  # expose N CPU cores as devices

import jax, jax.numpy as jnp, jax.lax as lax
from jax.tree_util import Partial

Pmode = 1
if Pmode:
    from jax.sharding import Mesh, PartitionSpec, NamedSharding, Sharding
    # mesh describing how devices behave, given the tensor axis to be treated in parallel
    axsn = ('cathair',)
    mesh = Mesh(jax.devices(),
                axis_names=axsn,
                axis_types=(jax.sharding.AxisType.Auto,),
                )
    # partition spec
    spec = PartitionSpec(*axsn)  # jax.P is the alias of jax.sharding.PartitionSpec
    # sharding object
    sharding = jax.sharding.NamedSharding(mesh, spec)
    tileme = Partial(jax.device_put, src=sharding, may_alias=True)  # dispatch for concurrent compute
else:
    tileme = jax.jit(lambda arg: arg)

Y = jnp.ones((32, 64, 7))
print(tileme(Y).sharding)
# -> SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=device)   (this is unexpected..)
```
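A hedged side note (the name `Yp` below is just illustrative): `jax.device_put` places its output according to the `device` argument, while `src` appears to only describe the source placement, so binding `device=sharding` is what should produce a `NamedSharding` result:

```python
# Sketch (assumption): passing the sharding as device= is what places Y on the
# mesh; with src= the array stays wherever it already was (a single CPU device).
Yp = jax.device_put(Y, device=sharding, may_alias=True)
print(Yp.sharding)  # expected: NamedSharding(..., spec=PartitionSpec('cathair',))
```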
Case 1 - using `psum` to force multi-device compute:

```python
O = lax.psum(tileme(Y), tuple(range(1, 2)))
print(O.sharding)  # same surprise: SingleDeviceSharding..
```
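For reference, a minimal sanity check, under the assumption that `lax.psum` with integer axes reduces over positional array dimensions, which would make the call above comparable to an ordinary sum over axis 1:

```python
# Sketch (assumption): psum over positional axis 1 should match jnp.sum(Y, axis=1);
# the output placement still follows the input, hence the SingleDeviceSharding.
O_ref = jnp.sum(Y, axis=1)
print(O.shape, O_ref.shape, bool(jnp.allclose(O, O_ref)))  # expected: (32, 7) (32, 7) True
```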
Case 2 - given a dynamic tensor-contraction function, how can the mesh be mapped automatically?

```python
precision = jax.lax.Precision('high')

@jax.jit
def f(x: jax.Array):
    A = jnp.ones((x.shape[-1], 3), dtype=jnp.float32)
    return lax.dot_general(tileme(x), tileme(A),
                           dimension_numbers=(((x.ndim - 1,), (0,)), ((), ())),
                           precision=precision)

O = f(tileme(Y))
print(O.shape, tileme(O).sharding)  # same SingleDeviceSharding outcome
```
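A possible alternative, sketched under assumptions rather than verified: declare the sharding at the `jit` boundary via `in_shardings` instead of calling `tileme` inside the traced function, and let XLA propagate it (`g` is just an illustrative name):

```python
# Sketch (assumption): with the input sharding declared at the jit boundary,
# the compiler should propagate a NamedSharding through the contraction.
g = jax.jit(lambda x: lax.dot_general(x, jnp.ones((x.shape[-1], 3), jnp.float32),
                                      dimension_numbers=(((x.ndim - 1,), (0,)), ((), ())),
                                      precision=precision),
            in_shardings=sharding)
print(g(Y).sharding)  # expected: a NamedSharding rather than SingleDeviceSharding
```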
Case 3 - sharding a pytree-like structure. Say I want to initialise a pytree whose sizes are known:

```python
mytree = [jnp.ones((4, 3, 2)), jnp.ones((7, 5, 1))]
ade = jax.jit(lambda x: x)  # can be replaced by a wide set of functions
result = jax.tree_util.tree_map(ade, mytree)
```

How can case 3 be parallelized, given #11394? Thanks in advance!
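Regarding the case 3 question just above, a minimal sketch under the assumption that `jax.device_put` accepts a pytree of arrays together with a matching pytree of shardings; since the leaf shapes (4, 3, 2) and (7, 5, 1) do not divide evenly over 16 devices, the specs below simply replicate each leaf:

```python
# Sketch (assumption): place each leaf explicitly, then tree_map as before.
leaf_shardings = [NamedSharding(mesh, PartitionSpec()),  # replicate the (4, 3, 2) leaf
                  NamedSharding(mesh, PartitionSpec())]  # replicate the (7, 5, 1) leaf
mytree_placed = jax.device_put(mytree, leaf_shardings)
result = jax.tree_util.tree_map(ade, mytree_placed)
print([leaf.sharding for leaf in result])
```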
Update: why does JAX not automatically let some devices sit idle when needed? (It feels quite rigid outside the ML-training scope.) I tried writing a `jax.device_put` macro that adapts the device count to the tensor dimensions:

```python
import jax, jax.numpy as jnp, jax.lax as lax, numpy as np
from jax.tree_util import Partial
from jax.sharding import Mesh
from typing import Sequence

uint = lambda el, b=1: el // b + b * bool(el % b)        # for b=1 (the only use below) this is ceil(el)
lmax = lambda *l: int(.5 * (sum(l) + abs(l[0] - l[1])))  # max of two values, via (a + b + |a - b|) / 2

def tileme(info: Sequence[int] | int = (8,), devices=np.array(jax.devices())):
    """Automatic dispatch with convenient parameters.
    devices = None or list of devices
    """
    axsn = tuple(str(el) for el in range(len(info)))  # ('cathair',)
    # if not devices:
    #     devices = jax.devices()
    #     maxDev = (len(jax.devices()),)
    # else:
    devices = np.asarray(devices)
    maxDev = np.array(devices).shape
    mesh = Mesh(devices[tuple(slice(lmax(el, el // uint(el / al))) for el, al in zip(info, maxDev))],
                axis_names=axsn,
                axis_types=(jax.sharding.AxisType.Auto,),
                )
    spec = jax.P(*axsn)
    sharding = jax.sharding.NamedSharding(mesh, spec)
    return Partial(jax.device_put, device=sharding, may_alias=True)

precision = jax.lax.Precision('high')

@jax.jit
def f(x: jax.Array):
    A = jnp.ones((x.shape[-1], 3), dtype=jnp.float32)
    return jax.lax.dot_general(tileme(info=x.shape[-1:])(x), tileme(info=x.shape[-1:])(A),
                               dimension_numbers=(((x.ndim - 1,), (0,)), ((), ())),
                               precision=precision)
```

This works on:

```python
z = jnp.ones((4, 3, 2, 7))
f(z)
```
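A variant, sketched under assumptions only (the names `named` and `f_constrained` are illustrative): inside a jitted function, `jax.lax.with_sharding_constraint` is the usual way to request a layout, rather than calling `jax.device_put` there:

```python
# Sketch (assumption): request the layout from inside the traced function via
# with_sharding_constraint instead of device_put.
named = jax.sharding.NamedSharding(Mesh(np.array(jax.devices()), ('0',)), jax.P('0'))

@jax.jit
def f_constrained(x):
    x = jax.lax.with_sharding_constraint(x, named)  # ask XLA to shard the leading axis
    A = jnp.ones((x.shape[-1], 3), dtype=jnp.float32)
    return jax.lax.dot_general(x, A, (((x.ndim - 1,), (0,)), ((), ())))

print(f_constrained(jnp.ones((32, 7))).sharding)  # expected: a NamedSharding
```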
It seems I must explicitly state every bit of the tensor shape as axis info of the `jax.device_put` call, yet that should not be necessary (since a `jax.Array`, like a `numpy.array`, is ultimately one long contiguous buffer indexed with regular strides). Same exercise for case 1:

```python
Y = jnp.ones((32, 64, 7))
O = lax.psum(tileme(info=Y.shape[1:2])(Y), tuple(range(1, 2)))
O.sharding  # NamedSharding(mesh=Mesh('0': 16, axis_types=(Auto,)), spec=PartitionSpec('0',), memory_kind=device)
```

Now the sharding works, so case 1 is solved! Just a gentle nudge for @jakevdp, @emilyfertig, @superbobry, @hawkinsp, @cgarciae on the topic.
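One more hedged check on the "restating the shape" complaint (`Ys` and `O2` are illustrative names): once the array is placed with `device=sharding`, later jitted operations are expected to propagate that sharding without a second `device_put`:

```python
# Sketch (assumption): a single placement of Y should be enough; the jitted
# reduction below should keep the NamedSharding from its input.
Ys = tileme(info=Y.shape[1:2])(Y)               # one-time placement, as above
O2 = jax.jit(lambda a: jnp.sum(a, axis=1))(Ys)  # same reduction as the psum call
print(O2.sharding)                              # expected: still a NamedSharding
```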
Hello, a little update on the matter: case 2 is actually solved by compiling through `jit`, which will parallelize if the XLA flag is set before importing JAX:

```python
import os

PUcount = 16
os.environ["XLA_FLAGS"] = f'--xla_force_host_platform_device_count={PUcount}'

import jax, jax.numpy as jnp, jax.lax as lax
from jax.tree_util import Partial
from jax.sharding import Mesh, PartitionSpec, NamedSharding, Sharding

# mesh describing how devices behave, given the tensor axis to be treated in parallel
axsn = ('rig',)
maxDev = len(jax.devices())
mesh = Mesh(jax.devices(),
            axis_names=axsn,
            axis_types=(jax.sharding.AxisType.Auto,),
            )
jax.sharding.set_mesh(mesh)  # without this, no implicit parallel dispatch

precision = jax.lax.Precision('high')

@jax.jit
def f(x: jax.Array):
    A = jnp.ones((x.shape[-1], 3), dtype=jnp.float32)
    return lax.dot_general(x, A,
                           dimension_numbers=(((x.ndim - 1,), (0,)), ((), ())),
                           precision=precision)

z2 = jnp.ones((4, 321, 21, 764)); y2 = f(z2); print(y2.sharding)
# NamedSharding(mesh=Mesh('rig': 16, axis_types=(Auto,)), spec=PartitionSpec(), memory_kind=device)
z = jnp.ones((4, 3, 2, 7)); y = f(z); print(y.sharding)
```

Case 3 is a problem similar to the #32993 (comment) discussion about ragged arrays. This is still in development.
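The `y2` result above reports `spec=PartitionSpec()`, i.e. a replicated output; a minimal sketch of forcing a genuinely split layout by placing the input on the mesh first, assuming an input whose leading axis divides over the 16 devices (the `(32, 764)` shape and the names `x_demo`, `y_demo` are hypothetical):

```python
# Sketch (assumption): an explicitly sharded input should make the jitted
# contraction keep the 'rig' axis in its output spec.
x_demo = jax.device_put(jnp.ones((32, 764)), NamedSharding(mesh, PartitionSpec('rig')))
y_demo = f(x_demo)
print(y_demo.sharding)  # expected: 'rig' on the leading axis of the PartitionSpec
```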