
Feature Request: Direct ONNX Exporter for JAX #26430

Open
johnnynunez opened this issue Feb 10, 2025 · 9 comments
Labels: enhancement (New feature or request)

Comments


johnnynunez commented Feb 10, 2025

Description:
I would like to propose the development of a direct exporter to convert JAX models to the ONNX format. Currently, the only available approach involves converting JAX models to TensorFlow using jax2tf and then converting the TensorFlow model to ONNX via tools like tf2onnx. While this workaround is functional, it introduces extra complexity and potential issues with fidelity and performance.

Motivation:

  • Simplified Workflow: A direct exporter would eliminate the need for an intermediate conversion step, making it easier to deploy JAX models in ONNX-supported environments.
  • Increased Compatibility: Direct conversion could improve the integration of JAX with other frameworks that rely on ONNX for model interoperability.
  • Reduced Overhead: Removing the extra conversion layer could mitigate potential issues related to model accuracy and performance that sometimes occur during multi-step conversions.

Challenges to Consider:

  • Intermediate Representations: JAX uses XLA and a dynamic intermediate representation (jaxpr) that is optimized for JIT compilation and research rather than for the static graph structure ONNX expects.
  • Operation Mapping: Not all JAX operations have a straightforward counterpart in ONNX. Developing a robust mapping (or extending the ONNX operator set) to cover all necessary operations, including control flow and custom gradients, would be a significant undertaking.
  • Community and Engineering Resources: Building and maintaining such an exporter requires substantial investment and collaboration. It may be beneficial to initiate community discussions to assess interest and determine feasible approaches.

Potential Approaches:

  • Mapping JAX Primitives: Investigate the possibility of mapping common JAX primitives directly to ONNX operators.
  • Intermediate IR Translation: Explore adapting JAX’s intermediate representation into a format that is more easily convertible to ONNX’s static graph format.
  • Community Collaboration: Engage with both the JAX and ONNX communities to identify priorities, share knowledge, and possibly establish a working group for this project.

Additional Context:

This feature request aims to bridge the gap between JAX’s dynamic computational paradigm and the static graph requirements of ONNX. A direct exporter would significantly enhance the usability of JAX in production environments that rely on ONNX for interoperability.
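A first concrete step toward the primitive-mapping approach is simply enumerating the primitives a model's jaxpr contains. A minimal sketch (the model here is a toy function, not anything from this thread):

```python
# Enumerate the JAX primitives used by a function, as a first step toward
# building a primitive-to-ONNX operator mapping. Requires only jax.
import jax
import jax.numpy as jnp

def model(x):
    return jnp.tanh(x @ x.T) + 1.0

closed = jax.make_jaxpr(model)(jnp.ones((2, 3)))
prims = sorted({eqn.primitive.name for eqn in closed.jaxpr.eqns})
print(prims)  # each listed primitive would need an ONNX counterpart
```

Every name this prints is a primitive an exporter would have to handle (e.g. `tanh` → ONNX `Tanh`, `dot_general` → some combination of `MatMul`/`Gemm`).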

@johnnynunez added the "enhancement" (New feature or request) label Feb 10, 2025
@Artoriuz

ONNX became the standard distribution format in some fields and it's generally very well supported by the various inference engines.

This is the number 1 thing keeping me away from switching my stuff to Flax, and I think there might be others in the same boat.

gnecula (Collaborator) commented Feb 17, 2025

An exporter to ONNX is possible, but is also a major undertaking with significant ongoing maintenance work. One can get an idea of what it would take by reading the jax2tf.py code, which converts JAX primitives to TF ops. (Note that we are planning to deprecate this code since it has been possible for the last 2 years to import StableHLO into TF, and that is what we recommend as the interoperation mechanism).

If somebody wants to embark on this project, I would recommend first studying what it would take to convert StableHLO to ONNX, because StableHLO is a better-defined interoperation layer than anything inside JAX (primitives, jaxpr).
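For anyone sizing up that route, the StableHLO such a converter would consume is easy to obtain from a jitted function. A sketch (API names as in recent jax releases; older versions may differ):

```python
# Dump the StableHLO (MLIR text) that jax lowers a function to; a
# StableHLO->ONNX converter would start from this representation.
import jax
import jax.numpy as jnp

def f(x):
    return jnp.tanh(x) * 2.0

lowered = jax.jit(f).lower(jnp.ones((4,), dtype=jnp.float32))
stablehlo_text = lowered.as_text()  # MLIR module using the stablehlo dialect
print(stablehlo_text)
```

The printed module contains ops such as `stablehlo.tanh` and `stablehlo.multiply`, which a converter would map to ONNX `Tanh` and `Mul` nodes.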

@johnnynunez (Author)

> ONNX became the standard distribution format in some fields and it's generally very well supported by the various inference engines.
>
> This is the number 1 thing keeping me away from switching my stuff to Flax, and I think there might be others in the same boat.

same


limarta commented Feb 27, 2025

@gnecula I don't know much about jax2tf.py, but was the idea to first translate jaxprs to TF ops and then complete the translation with something like tensorflow-onnx?

> If somebody wants to embark in this project, I would recommend studying first what it would take to convert StableHLO to ONNX, because StableHLO is a better-defined interoperation layer than anything inside the JAX (primitives, Jaxpr).

To clarify: an attempt at this would forgo TensorFlow entirely by translating the emitted StableHLO to ONNX, and no tool for that currently exists?

> An exporter to ONNX is possible, but is also a major undertaking with significant ongoing maintenance work.

I recently wrote a jaxpr->ONNX script for my use case. It sort of works, but I can only imagine the additional complexity when doing it starting with StableHLO....
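The handler-per-primitive structure of such a script can be sketched without the onnx package by emitting plain dicts in place of real NodeProto objects. This is a hypothetical skeleton for illustration, not limarta's actual code; `PRIM_TO_ONNX` and `jaxpr_to_nodes` are made-up names:

```python
# Hypothetical skeleton of a jaxpr -> ONNX translator: one handler per
# primitive, dispatched while walking the jaxpr equations in order.
import jax
import jax.numpy as jnp

# Tiny illustrative subset of a primitive-name -> ONNX-op-type mapping.
PRIM_TO_ONNX = {"tanh": "Tanh", "add": "Add", "mul": "Mul"}

def jaxpr_to_nodes(fn, *example_args):
    """Emit one pseudo-ONNX node (a dict) per jaxpr equation."""
    closed = jax.make_jaxpr(fn)(*example_args)
    nodes = []
    for eqn in closed.jaxpr.eqns:
        op_type = PRIM_TO_ONNX.get(eqn.primitive.name)
        if op_type is None:
            raise NotImplementedError(f"no handler for {eqn.primitive.name}")
        nodes.append({"op_type": op_type,
                      "inputs": [str(v) for v in eqn.invars],
                      "outputs": [str(v) for v in eqn.outvars]})
    return nodes

nodes = jaxpr_to_nodes(lambda x: jnp.tanh(x) + x, jnp.ones((3,)))
```

A real translator would additionally track shapes and dtypes, handle literals and parameters, and serialize the nodes with `onnx.helper` into a GraphProto.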


enpasos commented Feb 28, 2025

Hi @johnnynunez,

I was also recently looking for a way to export JAX models directly to ONNX and was surprised that no direct solution existed in the JAX ecosystem.

So, I decided to give it a try and started working on a project for this:
🔗 https://github.com/enpasos/jax2onnx

I'd love to hear your thoughts or any feedback on how it could be improved! 🚀

There’s also some discussion about the topic in a Flax issue:
🔗 google/flax#4430


limarta commented Feb 28, 2025

@enpasos Ha, looks like an exporter is in real demand! This is my (mostly LLM generated) attempt to translate jaxprs to ONNX operators: https://gist.github.com/limarta/855a88cc1c0163487a9dc369891147ab. Admittedly I cut corners here and there to get it working for my use case, so it is definitely incomplete.


enpasos commented Mar 1, 2025

@limarta Congrats! 🎉 Your elegant approach of leveraging JAX's autogenerated expression tree (jaxpr) significantly reduces the burden on users. I definitely wouldn’t want to manually construct ONNX graphs for complex functions like the gamma sampler in your gist!

What do you think about combining our approaches? The to_onnx expression tree builder in jax2onnx could integrate your jaxpr-based method for functions composed of JAX primitives. This might result in a more complete and maintainable pipeline for exporting JAX models to ONNX.

Would love to hear your thoughts on this! 🚀


enpasos commented Mar 4, 2025

What about using JAXPR, as demonstrated by @limarta, as the low-level intermediate representation? Since JAXPR is a standard JAX functionality, it provides a solid, built-in way to represent computations. By leveraging primitive handlers—the @limarta way!—we could establish a straightforward mapping to ONNX as a baseline.

Building on this, we could incorporate optimization patterns to refine the ONNX representation. The initial implementation effort would be relatively low, and it could be developed incrementally. I believe we could replace the current mechanism in jax2onnx with this approach while achieving the same results.

From a user perspective, the process could be as simple as calling a conversion function with an nnx.Module instance or a JAX function, along with input shapes (and optionally an export strategy if needed).

For baseline QA, we could validate correctness by comparing output values between the original JAX component and its ONNX counterpart. Test cases would cover both the primitive handlers and each optimization pattern to ensure robustness.
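That output comparison could look like the sketch below; `run_onnx` stands in for a real ONNX Runtime session so the example stays self-contained (the placeholder runner is hypothetical, not an actual export):

```python
# Validate an export by comparing outputs of the original JAX function
# and its (here: simulated) ONNX counterpart on the same input.
import numpy as np
import jax.numpy as jnp

def outputs_match(jax_fn, run_onnx, x, atol=1e-5):
    expected = np.asarray(jax_fn(x))
    actual = np.asarray(run_onnx(np.asarray(x)))
    return np.allclose(expected, actual, atol=atol)

def model(x):
    return jnp.tanh(x) * 2.0

# Placeholder for an onnxruntime.InferenceSession wrapper around the export.
fake_onnx_runner = lambda x: np.tanh(x) * 2.0

ok = outputs_match(model, fake_onnx_runner, jnp.linspace(-1.0, 1.0, 8))
```

In a real test suite, `fake_onnx_runner` would be replaced by a function that feeds the input to the exported model via onnxruntime, and the check would run once per primitive handler and per optimization pattern.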

I hope this aligns with @johnnynunez’s vision! Looking forward to your thoughts. 🚀


enpasos commented Mar 7, 2025

🚧 jax2onnx Update: Jaxpr-Based Redesign

I've started refactoring jax2onnx, now basing the approach directly on jaxpr. This significantly simplifies the process for users and aligns closely with recent community suggestions.

To effectively leverage @limarta's elegant jaxpr-to-ONNX mapping, I've introduced temporary monkey-patching to allow seamless registration of high-level callables (e.g., Flax nnx components) as custom JAX primitives. Additionally, I've retained the plugin architecture from the original jax2onnx, ensuring continued flexibility and extensibility.

I hope this meets the community's expectations and further develops the great ideas discussed, especially those demonstrated by @limarta. I've just pushed my current development snapshot to the repository to encourage collaboration and gather early feedback.

Community feedback, ideas, and contributions are warmly welcome!
