Description:
I would like to propose the development of a direct exporter to convert JAX models to the ONNX format. Currently, the only available approach is to convert JAX models to TensorFlow with jax2tf and then convert the TensorFlow model to ONNX with tools such as tf2onnx. This workaround is functional, but it adds a TensorFlow dependency and an extra conversion step, each of which adds complexity and can affect fidelity and performance.
Motivation:
- Simplified Workflow: A direct exporter would eliminate the need for an intermediate conversion step, making it easier to deploy JAX models in ONNX-supported environments.
- Increased Compatibility: Direct conversion could improve the integration of JAX with other frameworks that rely on ONNX for model interoperability.
- Reduced Overhead: Removing the intermediate TensorFlow step would reduce the risk of accuracy or performance regressions that multi-step conversions sometimes introduce.
Challenges to Consider:
- Intermediate Representations: JAX traces Python functions into jaxprs and compiles them with XLA. These representations are designed for JIT compilation and composable transformations in research code, not for the static graph structure that ONNX expects.
- Operation Mapping: Not all JAX operations have a straightforward counterpart in ONNX. Developing a robust mapping (or extending the ONNX operator set) to cover all necessary operations, including control flow and custom gradients, would be a significant undertaking.
- Community and Engineering Resources: Building and maintaining such an exporter requires substantial investment and collaboration. It may be beneficial to initiate community discussions to assess interest and determine feasible approaches.
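To make the representation gap concrete, the sketch below uses `jax.make_jaxpr` to print the jaxpr of a small function with a data-dependent branch. The function `f` is a hypothetical example, and the exact printed output varies between JAX versions. Note that `lax.cond` is recorded as a single `cond` equation whose two branches are nested jaxprs; an exporter would have to turn these into something like the subgraphs of an ONNX `If` node.

```python
import jax
import jax.numpy as jnp
from jax import lax


def f(x, flag):
    # A data-dependent branch. Tracing does not pick a side: it records a
    # single `cond` primitive whose branches are nested jaxprs.
    y = lax.cond(flag, jnp.sin, jnp.cos, x)
    return jnp.tanh(y) * 2.0


closed = jax.make_jaxpr(f)(jnp.ones((2, 3)), True)
print(closed)  # human-readable jaxpr

# Top-level primitives an exporter would have to translate.
prims = [eqn.primitive.name for eqn in closed.jaxpr.eqns]
print(prims)

# The cond equation carries its branches as ClosedJaxprs in `params`;
# an ONNX exporter would need to lower each one into a subgraph.
cond_eqn = next(e for e in closed.jaxpr.eqns if e.primitive.name == "cond")
branch_prims = [[e.primitive.name for e in br.jaxpr.eqns]
                for br in cond_eqn.params["branches"]]
print(branch_prims)
```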
Potential Approaches:
- Mapping JAX Primitives: Investigate the possibility of mapping common JAX primitives directly to ONNX operators.
- Intermediate IR Translation: Explore adapting JAX’s intermediate representation into a format that is more easily convertible to ONNX’s static graph format.
- Community Collaboration: Engage with both the JAX and ONNX communities to identify priorities, share knowledge, and possibly establish a working group for this project.
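As a rough illustration of the primitive-mapping approach, the sketch below walks the top-level equations of a jaxpr and translates each primitive name to an ONNX operator type through a lookup table. The table `PRIMITIVE_TO_ONNX` and the function `jaxpr_to_onnx_nodes` are hypothetical: they emit plain `(op_type, inputs, outputs)` tuples rather than a real ONNX `ModelProto`, ignore primitive parameters such as `dot_general` dimension numbers, and deliberately raise on anything unmapped, including control-flow primitives like `cond`.

```python
import jax
import jax.numpy as jnp

# Hypothetical, deliberately incomplete table from JAX primitive names to
# ONNX operator types. A real exporter needs per-primitive handling of
# params (axes, permutations, dot_general dimension numbers), not just names.
PRIMITIVE_TO_ONNX = {
    "add": "Add",
    "sub": "Sub",
    "mul": "Mul",
    "div": "Div",
    "neg": "Neg",
    "exp": "Exp",
    "log": "Log",
    "sin": "Sin",
    "cos": "Cos",
    "tanh": "Tanh",
    "max": "Max",
    "transpose": "Transpose",
    "reduce_sum": "ReduceSum",
}


def jaxpr_to_onnx_nodes(closed_jaxpr):
    """Translate a flat jaxpr into (op_type, input_names, output_names) tuples."""
    names = {}

    def name_of(atom):
        # Literals (inline constants) expose `.val`; a real exporter would
        # emit them as ONNX initializers instead of named placeholders.
        if hasattr(atom, "val"):
            return f"const_{atom.val}"
        # Variables get stable names in order of first appearance.
        return names.setdefault(atom, f"v{len(names)}")

    jaxpr = closed_jaxpr.jaxpr
    inputs = [name_of(v) for v in jaxpr.invars]  # graph inputs are named first
    nodes = []
    for eqn in jaxpr.eqns:
        prim = eqn.primitive.name
        if prim not in PRIMITIVE_TO_ONNX:
            # Control flow (cond, while, scan), dtype casts, and custom
            # derivative wrappers would need dedicated lowering rules.
            raise NotImplementedError(f"no ONNX mapping for primitive {prim!r}")
        nodes.append((
            PRIMITIVE_TO_ONNX[prim],
            [name_of(v) for v in eqn.invars],
            [name_of(v) for v in eqn.outvars],
        ))
    outputs = [name_of(v) for v in jaxpr.outvars]
    return inputs, nodes, outputs


# Usage on a small hypothetical function.
def g(x, y):
    return jnp.tanh(x * y + x)


inputs, nodes, outputs = jaxpr_to_onnx_nodes(
    jax.make_jaxpr(g)(jnp.ones(3), jnp.ones(3))
)
print("inputs:", inputs)
for node in nodes:
    print(node)
print("outputs:", outputs)
```

A real exporter would also need dtype and shape propagation and recursive lowering of nested jaxprs (for example `cond` to `If`, `while` to `Loop`). Failing fast with `NotImplementedError` keeps gaps in operator coverage visible instead of silently producing an incorrect graph.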
Additional Context:
This feature request aims to bridge the gap between JAX’s dynamic computational paradigm and the static graph requirements of ONNX. A direct exporter would significantly enhance the usability of JAX in production environments that rely on ONNX for interoperability.