@Rohan-Bierneni Rohan-Bierneni commented Oct 15, 2025

Description

The MaxText unit tests hit a mismatch between the jax version and the libtpu version. The build first pulls the 0.7.2 JAII, then downgrades only jax and jaxlib to 0.7.0, but not the corresponding libtpu version. This causes a Pallas libtpu error in unit tests that use the Pathways backend.

Notice 1: The downgrade to jax 0.7.0 is needed because google-tunix requires jax <= 0.7.1. Once tunix removes its jax version pin, we will revert these changes to the normal flow.

Tests

I am building a MaxText image and will verify the pip freeze output: https://paste.googleplex.com/6454746957348864
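
One way to make the pip freeze check repeatable is sketched below. It assumes the three pins added in this PR (`jax==0.7.0`, `jaxlib==0.7.0`, `libtpu==0.0.19`) and reads a freeze listing from stdin. `check_pins` is a hypothetical helper name used only for illustration; it is not part of the MaxText setup script.

```bash
#!/usr/bin/env bash
# Sketch: confirm that a `pip freeze` listing contains the three pins from this PR.
# check_pins is a hypothetical helper, not an existing MaxText function.
check_pins() {
  local freeze pin
  freeze="$(cat)"   # read the whole freeze listing from stdin
  for pin in 'jax==0.7.0' 'jaxlib==0.7.0' 'libtpu==0.0.19'; do
    # -x matches the entire line, -F treats the pin as a literal string
    if ! grep -qxF -- "$pin" <<<"$freeze"; then
      echo "missing or mismatched pin: $pin" >&2
      return 1
    fi
  done
  echo "pins match"
}
```

Inside the built image this could be run as `python3 -m pip freeze | check_pins`; a non-zero exit status would indicate that one of the three packages is absent or at a different version.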

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@Rohan-Bierneni Rohan-Bierneni force-pushed the rbierneni-fix-jaii-0.7.0-pin branch from 49b6ae9 to 36bd7e2 Compare October 15, 2025 23:25
  python3 -m pip install 'google-tunix>=0.1.2'; \
  # TODO: Once tunix stopped pinning jax 0.7.1, we should remove our 0.7.0 version pin (b/450286600)
- python3 -m pip install 'jax==0.7.0' 'jaxlib==0.7.0'; \
+ python3 -m pip install 'jax==0.7.0' 'jaxlib==0.7.0' 'libtpu==0.0.19'; \
Collaborator

Is there any way to skip this line when MODE=nightly? Or only do so with stable mode?

Collaborator Author

Since we use mode=stable_stack for all images built off of JAII, we normally distinguish them by the image passed to BASEIMAGE. The only way to tell nightly from stable versions would be to check whether the image name contains "nightly".

Collaborator

@hengtaoguo hengtaoguo Oct 16, 2025

Got it. Do you think this code snippet would help in both cases? On nightly images it forces an upgrade of JAX to the newest version so that libtpu matches (b/451429959).

        if [[ "$JAX_AI_IMAGE_BASEIMAGE" == *"nightly"* ]]; then \
            echo "Nightly image detected"; \
            python3 -m pip install --upgrade 'jax[tpu]'; \
        else \
            echo "Non-nightly image"; \
            python3 -m pip install 'jax[tpu]==0.7.0'; \
        fi

cc @gobbleturk @bvandermoon Basically we are (1) upgrading JAX to the latest 0.8.0 for nightly images and (2) downgrading JAX to 0.7.0 for GitHub CI stable stack images. Both steps run after the Tunix installation, so our JAX version is kept. I am not sure this patch is worthwhile given that the problem is transient.

Collaborator Author

For nightly images I think we have to install via these commands:
pip uninstall -y jax jaxlib libtpu
pip install -U --pre jax jaxlib libtpu requests -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

Collaborator Author

This works. I was able to force-install jax nightly only on nightly JAII images: https://paste.googleplex.com/5962851466477568
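
For readers without access to the paste, the branching discussed in this thread can be sketched roughly as follows. This is not the merged code: `jax_install_cmds` is a hypothetical helper that only prints the commands it would choose, combining the nightly install commands quoted earlier in this thread with the stable pin from the diff. `python3 -m pip` is used throughout for consistency with the diff.

```bash
#!/usr/bin/env bash
# Sketch only: print (do not run) the install commands for a given base image name.
# jax_install_cmds is a hypothetical helper; the actual script may differ.
jax_install_cmds() {
  local base_image="$1"
  if [[ "$base_image" == *"nightly"* ]]; then
    # Nightly JAII image: replace jax, jaxlib and libtpu together with pre-releases.
    echo "python3 -m pip uninstall -y jax jaxlib libtpu"
    echo "python3 -m pip install -U --pre jax jaxlib libtpu requests -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html"
  else
    # Stable JAII image: pin all three packages so libtpu matches jax 0.7.0.
    echo "python3 -m pip install 'jax==0.7.0' 'jaxlib==0.7.0' 'libtpu==0.0.19'"
  fi
}
```

Calling `jax_install_cmds "$JAX_AI_IMAGE_BASEIMAGE"` would show which path a given image takes. Printing instead of executing keeps the sketch safe to try outside a container build.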

@Rohan-Bierneni Rohan-Bierneni force-pushed the rbierneni-fix-jaii-0.7.0-pin branch from 9262cf0 to fcbbf5a Compare October 17, 2025 17:52
  python3 -m pip install 'google-tunix>=0.1.2'; \
  # TODO: Once tunix stopped pinning jax 0.7.1, we should remove our 0.7.0 version pin (b/450286600)
- python3 -m pip install 'jax==0.7.0' 'jaxlib==0.7.0'; \
+ python3 -m pip install 'jax==0.7.0' 'jaxlib==0.7.0' 'libtpu==0.0.19'; \
Collaborator

Why not just install jax[tpu]?

Collaborator Author

@Rohan-Bierneni Rohan-Bierneni Oct 17, 2025

jax[tpu] would install jax 0.8.0, which hasn't been tested with MaxText for compatibility yet. There are also known issues with 0.7.2 and 0.7.1, so MaxText will stay on 0.7.0 until 0.8.0 compatibility is confirmed.
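
The constraints mentioned in this PR can be written down as a small check. The sketch below encodes only what is stated here: google-tunix requires jax <= 0.7.1, and 0.7.1 and 0.7.2 have known issues. `version_le` and `jax_version_ok` are hypothetical helper names, and the comparison assumes a `sort` that supports `-V` (as GNU coreutils does). Releases older than 0.7.0 would also pass, since the thread does not discuss them.

```bash
#!/usr/bin/env bash
# Sketch: does a jax version satisfy the constraints described in this PR?
# version_le and jax_version_ok are hypothetical helpers for illustration.

# Succeeds when $1 <= $2 under version ordering (relies on `sort -V`).
version_le() {
  [ "$(printf '%s\n%s\n' "$1" "$2" | sort -V | head -n1)" = "$1" ]
}

jax_version_ok() {
  local v="$1"
  # google-tunix requires jax <= 0.7.1 (per the PR description).
  version_le "$v" "0.7.1" || return 1
  # 0.7.1 and 0.7.2 have known issues (per the comment above).
  case "$v" in
    0.7.1|0.7.2) return 1 ;;
  esac
  return 0
}
```

For example, `jax_version_ok 0.7.0 && echo ok` succeeds, while 0.7.1, 0.7.2 and 0.8.0 are all rejected under these assumptions.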

Remove extra comment

Force install jax nightly only in jaii nightly images

Remove duplicate block
@Rohan-Bierneni Rohan-Bierneni force-pushed the rbierneni-fix-jaii-0.7.0-pin branch from f87f376 to 6ecb399 Compare October 17, 2025 18:18
4 participants