Feature Request: Direct ONNX Exporter for JAX #26430
Comments
ONNX became the standard distribution format in some fields and it's generally very well supported by the various inference engines. This is the number one thing keeping me from switching my stuff to Flax, and I think there might be others in the same boat.
An exporter to ONNX is possible, but it is also a major undertaking with significant ongoing maintenance work. One can get an idea of what it would take by reading the jax2tf.py code, which converts JAX primitives to TF ops. (Note that we are planning to deprecate this code, since it has been possible for the last 2 years to import StableHLO into TF, and that is what we recommend as the interoperation mechanism.) If somebody wants to embark on this project, I would recommend first studying what it would take to convert StableHLO to ONNX, because StableHLO is a better-defined interoperation layer than anything inside JAX (primitives, jaxprs).
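For context on the layer being recommended: a minimal sketch, using JAX's public ahead-of-time lowering API, of how a hypothetical StableHLO→ONNX converter would obtain the StableHLO module it needs to translate (the toy function `f` is an illustration, not part of any existing tool):

```python
import jax
import jax.numpy as jnp

def f(x):
    # Toy model: a StableHLO->ONNX converter would see this as
    # stablehlo.tanh and stablehlo.multiply ops.
    return jnp.tanh(x) * 2.0

# Lower through JAX's public AOT API; the resulting MLIR text is the
# well-defined interoperation layer recommended above.
lowered = jax.jit(f).lower(jnp.ones((3,), dtype=jnp.float32))
stablehlo_text = lowered.as_text()
print("stablehlo.tanh" in stablehlo_text)
```

A converter built this way would parse the StableHLO ops rather than JAX-internal jaxprs, which is the point being made above.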
same
@gnecula I don't know much about this area. To clarify: an attempt at this would forgo TensorFlow entirely by translating the emitted StableHLO to ONNX, and no tool for that currently exists?
I recently wrote a jaxpr->ONNX script for my use case. It sort of works, but I can only imagine the additional complexity of doing this starting from StableHLO…
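For readers unfamiliar with the jaxpr route mentioned here, a tiny sketch of the starting point such a script walks: each equation in the traced jaxpr is one primitive application that would be mapped to an ONNX node.

```python
import jax
import jax.numpy as jnp

# Trace a function to its jaxpr; a jaxpr->ONNX script iterates over
# these equations and emits one ONNX node per JAX primitive.
closed = jax.make_jaxpr(lambda x: jnp.exp(x) + 1.0)(jnp.ones((2,)))
prim_names = [eqn.primitive.name for eqn in closed.jaxpr.eqns]
print(prim_names)  # expected: ['exp', 'add']
```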
Hi @johnnynunez, I was also recently looking for a way to export JAX models directly to ONNX and was surprised that no direct solution existed in the JAX ecosystem. So I decided to give it a try and started working on a project for this. I'd love to hear your thoughts or any feedback on how it could be improved! 🚀 There’s also some discussion of the topic in a Flax issue.
@enpasos Ha, looks like an exporter is in real demand! This is my (mostly LLM-generated) attempt to translate jaxprs to ONNX operators: https://gist.github.com/limarta/855a88cc1c0163487a9dc369891147ab. Admittedly I cut corners here and there to get it working for my use case, so it is definitely incomplete.
@limarta Congrats! 🎉 Your elegant approach of leveraging JAX's autogenerated expression tree (jaxpr) significantly reduces the burden on users. I definitely wouldn’t want to manually construct ONNX graphs for complex functions like the gamma sampler in your gist! What do you think about combining our approaches? The to_onnx expression tree builder in jax2onnx could integrate your jaxpr-based method for functions composed of JAX primitives. This might result in a more complete and maintainable pipeline for exporting JAX models to ONNX. Would love to hear your thoughts on this! 🚀
What about using jaxpr, as demonstrated by @limarta, as the low-level intermediate representation? Since jaxpr is standard JAX functionality, it provides a solid, built-in way to represent computations. By leveraging primitive handlers—the @limarta way!—we could establish a straightforward mapping to ONNX as a baseline. Building on this, we could incorporate optimization patterns to refine the ONNX representation. This would start with relatively low implementation effort and could be developed incrementally. I believe we could replace the current mechanism in jax2onnx with this approach while achieving the same results.

From a user perspective, the process could be as simple as calling a conversion function with an nnx.Module instance or a JAX function, along with input shapes (and optionally an export strategy if needed). For baseline QA, we could validate correctness by comparing output values between the original JAX component and its ONNX counterpart. Test cases would cover both the primitive handlers and each optimization pattern to ensure robustness.

I hope this aligns with @johnnynunez’s vision! Looking forward to your thoughts. 🚀
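As a rough illustration of the primitive-handler idea, here is a minimal sketch. The handler table and function names are hypothetical; a real exporter would emit `onnx.helper.make_node` objects and build a full graph rather than recording tuples:

```python
import jax
import jax.numpy as jnp

# Hypothetical handler table: JAX primitive name -> ONNX op type.
PRIMITIVE_TO_ONNX = {"mul": "Mul", "add": "Add", "tanh": "Tanh"}

def jaxpr_to_nodes(fn, *example_args):
    """Translate each jaxpr equation to an (op_type, inputs, outputs) tuple.

    A real exporter would construct onnx.NodeProto objects here and
    validate the resulting graph numerically against the JAX original,
    as proposed above.
    """
    closed = jax.make_jaxpr(fn)(*example_args)
    nodes = []
    for eqn in closed.jaxpr.eqns:
        op_type = PRIMITIVE_TO_ONNX.get(eqn.primitive.name)
        if op_type is None:
            raise NotImplementedError(f"no handler for {eqn.primitive.name}")
        nodes.append((op_type,
                      [str(v) for v in eqn.invars],
                      [str(v) for v in eqn.outvars]))
    return nodes

nodes = jaxpr_to_nodes(lambda x: jnp.tanh(x * x), jnp.ones((2,)))
print([op for op, _, _ in nodes])  # expected: ['Mul', 'Tanh']
```

Unsupported primitives fail loudly, which makes the incremental, handler-by-handler development path described above natural.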
🚧 jax2onnx Update: Jaxpr-Based Redesign

I've started refactoring jax2onnx, now basing the approach directly on jaxpr. This significantly simplifies the process for users and aligns closely with recent community suggestions. To effectively leverage @limarta's elegant jaxpr-to-ONNX mapping, I've introduced temporary monkey-patching to allow seamless registration of high-level callables (e.g., Flax …).

I hope this meets the community's expectations and further develops the great ideas discussed, especially those demonstrated by @limarta. I've just pushed my current development snapshot to the repository to encourage collaboration and gather early feedback. Community feedback, ideas, and contributions are warmly welcome!
Description:
I would like to propose the development of a direct exporter to convert JAX models to the ONNX format. Currently, the only available approach involves converting JAX models to TensorFlow using jax2tf and then converting the TensorFlow model to ONNX via tools like tf2onnx. While this workaround is functional, it introduces extra complexity and potential issues with fidelity and performance.
Motivation:
Challenges to Consider:
Potential Approaches:
Intermediate IR Translation: Explore adapting JAX’s intermediate representation into a format that is more easily convertible to ONNX’s static graph format.
Additional Context:
This feature request aims to bridge the gap between JAX’s dynamic computational paradigm and the static graph requirements of ONNX. A direct exporter would significantly enhance the usability of JAX in production environments that rely on ONNX for interoperability.