Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#sdy Add JAX backwards compatibility test. #379

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

copybara-service[bot]
Copy link

#sdy Add JAX backwards compatibility test.

This tests saving a module with one set of axis names, but loading it with another set of axis names.

This does also test the custom calls:

  • @Sharding
  • @xla.sdy.GlobalToLocalShape
  • @xla.sdy.LocalToGlobalShape

But note that there are a bunch of other custom calls that will be tested in the Shardy and XLA codebases. The way the testing utils is tested here doesn't allow me to set out_shardings for example. So JAX can rely on the existence of those tests as stability guarantees just like for StableHLO.

This tests saving a module with one set of axis names, but loading it with another set of axis names.

This does also test the custom calls:

- `@Sharding`
- `@xla.sdy.GlobalToLocalShape`
- `@xla.sdy.LocalToGlobalShape`

But note that there are a bunch of other custom calls that will be tested in the Shardy and XLA codebases. The way the testing utils is tested here doesn't allow me to set `out_shardings` for example. So JAX can rely on the existence of those tests as stability guarantees just like for StableHLO.

PiperOrigin-RevId: 725999900
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant