-
Notifications
You must be signed in to change notification settings - Fork 417
Update patch to ensure maxtext images are downgraded to 0.7.0 #2506
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
49b6ae9
to
36bd7e2
Compare
maxtext_jax_ai_image.Dockerfile
Outdated
python3 -m pip install 'google-tunix>=0.1.2'; \ | ||
# TODO: Once tunix stopped pinning jax 0.7.1, we should remove our 0.7.0 version pin (b/450286600) | ||
python3 -m pip install 'jax==0.7.0' 'jaxlib==0.7.0'; \ | ||
python3 -m pip install 'jax==0.7.0' 'jaxlib==0.7.0' 'libtpu==0.0.19'; \ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there any way to skip this line when MODE=nightly
? Or only do so with stable mode?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since we use mode=stable_stack for all images built off of JAII, we distinguish them normally via the image passed to BASEIMAGE. And the only way to distinguish nightly vs stable versions would be if it has "nightly" in the text
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Got it, do you think this code snippet will help in both cases? It forces to upgrade JAX version to the newest in the nightly version, so the libtpu could match (b/451429959).
if [[ "$JAX_AI_IMAGE_BASEIMAGE" == *"nightly"* ]]; then \
echo "Nightly image detected"; \
python3 -m pip install --upgrade jax[tpu]; \
else \
echo "Non-nightly image"; \
python3 -m pip install jax[tpu]==0.7.0; \
fi
cc @gobbleturk @bvandermoon Basically we are (1) upgrading JAX to latest 0.8.0 for nightly images (2) downgrading JAX to 0.7.0 for github CI stable stack images. Both cases are after the Tunix installation so we can keep JAX version. Not sure if this is a worthwhile patch given this problem is transient.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For nightly images I think we have to install via these commands:
pip uninstall -y jax jaxlib libtpu
pip install -U --pre jax jaxlib libtpu requests -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This works, was able to force install jax nightly only for nightly jaii images: https://paste.googleplex.com/5962851466477568
9262cf0
to
fcbbf5a
Compare
maxtext_jax_ai_image.Dockerfile
Outdated
python3 -m pip install 'google-tunix>=0.1.2'; \ | ||
# TODO: Once tunix stopped pinning jax 0.7.1, we should remove our 0.7.0 version pin (b/450286600) | ||
python3 -m pip install 'jax==0.7.0' 'jaxlib==0.7.0'; \ | ||
python3 -m pip install 'jax==0.7.0' 'jaxlib==0.7.0' 'libtpu==0.0.19'; \ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why not just install jax[tpu]
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
jax[tpu] would install jax 0.8.0, which hasn't been tested yet with maxtext for compatability. And there is known issues with 0.7.2 and 0.7.1, so maxtext would be on 0.7.0 until 0.8.0 compatability is confirmed
Remove extra comment Force install jax nightly only in jaii nightly images Remove duplicate block
f87f376
to
6ecb399
Compare
Description
Maxtext unit tests are having a mismatch between jax version and libtpu version. They initially pull the 0.7.2 JAII, then downgrade only the jax and jaxlib versions to 0.7.2, but not the corresponding libtpu version. This is causing pallas libtpu error for unit tests using pathways backend.
Notice 1: The downgrade to jax 0.7.0 is because google-tunix has a requirement to be <= 0.7.1 jax. Once tunix removes the jax version pin, we will revert these changes to the normal flow.
Tests
Building a maxtext image and will verify the pip freeze: https://paste.googleplex.com/6454746957348864
Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-review
label.