Skip to content

Commit f07243a

Browse files
skye authored and
Google-ML-Automation committed
Default JAX_CPU_COLLECTIVES_IMPLEMENTATION to 'gloo'.
This enables CPU collectives by default, making multi-process CPU communication work without extra configuration. PiperOrigin-RevId: 724076284
1 parent 4b86ff2 commit f07243a

File tree

2 files changed

+10
-3
lines changed

2 files changed

+10
-3
lines changed

CHANGELOG.md

+2
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
2525
* Changes
2626
* `JAX_CPU_COLLECTIVES_IMPLEMENTATION` and `JAX_NUM_CPU_DEVICES` now work as
2727
env vars. Before they could only be specified via jax.config or flags.
28+
* `JAX_CPU_COLLECTIVES_IMPLEMENTATION` now defaults to `'gloo'`, meaning
29+
multi-process CPU communication works out-of-the-box.
2830
* The `jax[tpu]` TPU extra no longer depends on the `libtpu-nightly` package.
2931
This package may safely be removed if it is present on your machine; JAX now
3032
uses `libtpu` instead.

jax/_src/xla_bridge.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@
6262

6363
MIN_COMPUTE_CAPABILITY = 52
6464

65+
_DEFAULT_CPU_COLLECTIVES_IMPL = 'gloo'
66+
6567
# TODO(phawkins): Remove jax_xla_backend.
6668
_XLA_BACKEND = config.string_flag(
6769
'jax_xla_backend', '',
@@ -235,7 +237,9 @@ def make_cpu_client(
235237
Returns:
236238
The created CPU client.
237239
"""
238-
if collectives is None:
240+
# TODO(skyewm): use distributed.is_initialized() after
241+
# https://github.com/jax-ml/jax/pull/26172 goes in.
242+
if collectives is None and distributed.global_state.client is not None:
239243
collectives_impl = config.cpu_collectives_implementation.value
240244
if _CPU_ENABLE_GLOO_COLLECTIVES.value:
241245
collectives_impl = 'gloo'
@@ -244,6 +248,9 @@ def make_cpu_client(
244248
'"jax_cpu_collectives_implementation", "gloo")` instead.',
245249
DeprecationWarning,
246250
)
251+
if collectives_impl is None:
252+
collectives_impl = _DEFAULT_CPU_COLLECTIVES_IMPL
253+
247254
if collectives_impl == 'gloo':
248255
collectives = xla_client._xla.make_gloo_tcp_collectives(
249256
distributed_client=distributed.global_state.client,
@@ -252,8 +259,6 @@ def make_cpu_client(
252259
collectives = xla_client._xla.make_mpi_collectives()
253260
collectives.Init()
254261
atexit.register(collectives.Finalize)
255-
elif collectives_impl == 'megascale':
256-
raise ValueError('JAX_CPU_COLLECTIVES_IMPLEMENTATION must "gloo" or "mpi"')
257262
else:
258263
# Already validated by config module
259264
assert collectives_impl is None

0 commit comments

Comments
 (0)