62
62
63
63
MIN_COMPUTE_CAPABILITY = 52
64
64
65
+ _DEFAULT_CPU_COLLECTIVES_IMPL = 'gloo'
66
+
65
67
# TODO(phawkins): Remove jax_xla_backend.
66
68
_XLA_BACKEND = config .string_flag (
67
69
'jax_xla_backend' , '' ,
@@ -235,7 +237,9 @@ def make_cpu_client(
235
237
Returns:
236
238
The created CPU client.
237
239
"""
238
- if collectives is None :
240
+ # TODO(skyewm): use distributed.is_initialized() after
241
+ # https://github.com/jax-ml/jax/pull/26172 goes in.
242
+ if collectives is None and distributed .global_state .client is not None :
239
243
collectives_impl = config .cpu_collectives_implementation .value
240
244
if _CPU_ENABLE_GLOO_COLLECTIVES .value :
241
245
collectives_impl = 'gloo'
@@ -244,6 +248,9 @@ def make_cpu_client(
244
248
'"jax_cpu_collectives_implementation", "gloo")` instead.' ,
245
249
DeprecationWarning ,
246
250
)
251
+ if collectives_impl is None :
252
+ collectives_impl = _DEFAULT_CPU_COLLECTIVES_IMPL
253
+
247
254
if collectives_impl == 'gloo' :
248
255
collectives = xla_client ._xla .make_gloo_tcp_collectives (
249
256
distributed_client = distributed .global_state .client ,
@@ -252,8 +259,6 @@ def make_cpu_client(
252
259
collectives = xla_client ._xla .make_mpi_collectives ()
253
260
collectives .Init ()
254
261
atexit .register (collectives .Finalize )
255
- elif collectives_impl == 'megascale' :
256
- raise ValueError ('JAX_CPU_COLLECTIVES_IMPLEMENTATION must "gloo" or "mpi"' )
257
262
else :
258
263
# Already validated by config module
259
264
assert collectives_impl is None
0 commit comments