[Feature] Add JAX integration for FlyDSL kernels #257
wenchenvincent wants to merge 9 commits into ROCm:main
Conversation
Pull request overview
Adds a first-class JAX integration layer for FlyDSL, enabling both eager execution on jax.Array inputs and jax.jit execution via XLA custom calls, while also making PyTorch an optional dependency for importing FlyDSL.
Changes:
- Introduces `python/flydsl/jax/` (adapter, primitive lowering, FFI bridge, and C trampoline) to run FlyDSL kernels from JAX (eager + `jax.jit`).
- Makes PyTorch optional by guarding torch-specific registrations in `jit_argument.py`.
- Exposes `JitFunction.get_last_artifact()` to allow external integrations (the JAX bridge) to retrieve compiled artifacts.
Reviewed changes
Copilot reviewed 13 out of 13 changed files in this pull request and generated 8 comments.
| File | Description |
|---|---|
| tests/unit/test_jax_integration.py | Adds unit tests for adapter/primitive/bridge behavior (skips if JAX missing). |
| tests/test_jax_vecadd.py | Adds a JAX vector-add integration script/test (currently unguarded JAX import). |
| python/flydsl/jax/__init__.py | Public entrypoints for from_jax, jax_kernel, and lazy wrappers. |
| python/flydsl/jax/adapter.py | Implements JaxTensorAdaptor via DLPack for eager-mode @flyc.jit calls. |
| python/flydsl/jax/primitive.py | Implements a JAX primitive + StableHLO CustomCallOp lowering for jax.jit. |
| python/flydsl/jax/ffi_bridge.py | Compiles/registers FlyDSL kernels as XLA custom call targets + loads/builds trampoline. |
| python/flydsl/jax/_xla_bridge.c | C trampoline bridging XLA GPU custom-call ABI to FlyDSL ptr-packing ABI. |
| python/flydsl/compiler/jit_function.py | Adds get_last_artifact() and tracks last compilation result. |
| python/flydsl/compiler/jit_argument.py | Makes torch optional; keeps PyTorch tensor support when installed. |
| python/flydsl/compiler/__init__.py | Convenience re-export flydsl.compiler.from_jax. |
| examples/04-vectorAdd-jax.py | JAX version of vector add (eager + jax.jit). |
| examples/05-tiledCopy-jax.py | JAX version of tiled copy (eager + jax.jit). |
| examples/06-tiledMma-jax.py | JAX version of tiled MMA (eager + jax.jit). |
```python
# 0 = API_VERSION_ORIGINAL (CPU: fn(out, ins))
# 1 = API_VERSION_STATUS_RETURNING (CPU: fn(out, ins, status))
# 2 = API_VERSION_STATUS_RETURNING_UNIFIED (GPU: fn(stream, buffers, opaque, opaque_len))
# 4 = API_VERSION_TYPED_FFI
# We use 2 for GPU custom calls with the old untyped convention.
# backend_config carries the opaque bytes (slot index for the C trampoline).
i32_type = jax_ir.IntegerType.get_signless(32)
call = stablehlo.CustomCallOp(
    result_types,
    list(args),
    call_target_name=target_name,
    api_version=jax_ir.IntegerAttr.get(i32_type, 2),
```
stablehlo.CustomCallOp is emitted with api_version=2, but the target is registered in ffi_bridge._register_with_xla() using api_version=0. XLA/JAX expects these API versions to match; otherwise the runtime will call the bridge with a different calling convention than the C trampoline implements. Align the api_version used in both places (and update the docstring/comments accordingly).
Suggested change:
```diff
-# 0 = API_VERSION_ORIGINAL (CPU: fn(out, ins))
-# 1 = API_VERSION_STATUS_RETURNING (CPU: fn(out, ins, status))
+# 0 = API_VERSION_ORIGINAL (fn(out, ins))
+# 1 = API_VERSION_STATUS_RETURNING (fn(out, ins, status))
 # 2 = API_VERSION_STATUS_RETURNING_UNIFIED (GPU: fn(stream, buffers, opaque, opaque_len))
 # 4 = API_VERSION_TYPED_FFI
-# We use 2 for GPU custom calls with the old untyped convention.
+# We use 0 here to match the api_version used in ffi_bridge._register_with_xla().
 # backend_config carries the opaque bytes (slot index for the C trampoline).
 i32_type = jax_ir.IntegerType.get_signless(32)
 call = stablehlo.CustomCallOp(
     result_types,
     list(args),
     call_target_name=target_name,
-    api_version=jax_ir.IntegerAttr.get(i32_type, 2),
+    api_version=jax_ir.IntegerAttr.get(i32_type, 0),
```
python/flydsl/jax/ffi_bridge.py (Outdated)
```python
# api_version=0: old custom-call convention
#   void fn(stream, void** buffers, const char* opaque, size_t opaque_len)
_xla_client.register_custom_call_target(
    target.name, capsule, xla_platform_name, api_version=0,
)
```
_register_with_xla() registers the custom call target with api_version=0, but the lowering in primitive.py sets api_version=2 on the StableHLO CustomCallOp. This mismatch will typically lead to the wrong runtime ABI being used for the callback. Update the registration to use the same API version as the emitted CustomCallOp (and keep the inline comments/docstring consistent).
```python
def _ensure_bridge_lib() -> ctypes.CDLL:
    """Load ``_xla_bridge.so``, compiling from source if necessary."""
    if not _BRIDGE_SO.exists():
        if not _BRIDGE_C.exists():
            raise FileNotFoundError(
                f"Cannot find XLA bridge source: {_BRIDGE_C}\n"
                f"Please rebuild or reinstall flydsl."
            )
        subprocess.check_call(
            ["gcc", "-shared", "-fPIC", "-O2", "-o", str(_BRIDGE_SO), str(_BRIDGE_C)],
            cwd=str(_THIS_DIR),
        )
    lib = ctypes.CDLL(str(_BRIDGE_SO))
```
Building _xla_bridge.so at import time via a hard-coded gcc invocation is brittle operationally (requires a compiler at runtime, write permissions to the installed package dir, and fails on non-GNU toolchains). If you do keep this fallback, add the proper thread library flags (-pthread / -lpthread) since _xla_bridge.c uses pthreads, and consider moving compilation to the package build step or caching in a writable user cache dir instead of the source tree.
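A portable variant of that fallback can be sketched as below. This is an illustrative helper, not FlyDSL's actual code: the function name `build_bridge_cmd` is hypothetical, and it simply shows the reviewer's suggestions in one place (locate a compiler instead of hard-coding `gcc`, and add `-pthread` since the bridge uses pthreads):

```python
import shutil


def build_bridge_cmd(src, out, compiler=None):
    """Construct a compile command for the bridge shared library.

    Hypothetical sketch: falls back through common compiler names rather
    than assuming gcc, and adds -pthread because _xla_bridge.c uses pthreads.
    """
    compiler = compiler or shutil.which("cc") or shutil.which("gcc") or shutil.which("clang")
    if compiler is None:
        raise RuntimeError("No C compiler found; install one or use a prebuilt wheel.")
    # -shared/-fPIC build a position-independent shared object; -pthread sets
    # both the preprocessor defines and the link flags for the pthread library.
    return [compiler, "-shared", "-fPIC", "-O2", "-pthread", "-o", str(out), str(src)]
```

The output path would ideally live in a writable user cache directory (e.g. via `platformdirs` or `~/.cache`) rather than the installed package tree, as the review note suggests.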
```python
with _lock:
    if target_name in _registered_targets:
        return target_name
```
compile_and_register() does a membership check under _lock and then releases the lock before compilation/registration. Two threads lowering the same shapes can race and both compile + register the same target_name, potentially causing duplicate registrations or nondeterministic behavior. Consider holding the lock through the whole register path, or store an “in-progress” sentinel (e.g., a threading.Event/Future) so only one thread compiles while others wait.
python/flydsl/jax/_xla_bridge.c (Outdated)
```c
pthread_mutex_unlock(&g_lock);

g_targets[idx].func = (flydsl_func_t)func_ptr;
g_targets[idx].n_buffers = n_buffers;
g_targets[idx].n_scalars = n_scalars;
for (int i = 0; i < n_scalars; i++)
    g_targets[idx].scalar_vals[i] = scalar_vals ? scalar_vals[i] : 0;
```
flydsl_xla_register() increments g_n_targets under the mutex, then releases the lock and only afterwards populates g_targets[idx]. xla_bridge_dispatch() can run concurrently, observe idx < g_n_targets, and read a partially initialized slot (including a NULL/garbage func pointer). Populate the slot while still holding g_lock (or use an atomic publish pattern) to make registration + dispatch thread-safe.
Suggested change:
```diff
-pthread_mutex_unlock(&g_lock);
 g_targets[idx].func = (flydsl_func_t)func_ptr;
 g_targets[idx].n_buffers = n_buffers;
 g_targets[idx].n_scalars = n_scalars;
 for (int i = 0; i < n_scalars; i++)
     g_targets[idx].scalar_vals[i] = scalar_vals ? scalar_vals[i] : 0;
+pthread_mutex_unlock(&g_lock);
```
python/flydsl/jax/_xla_bridge.c (Outdated)
```c
int idx = 0;
if (opaque_len >= sizeof(int))
    memcpy(&idx, opaque, sizeof(int));
```
In xla_bridge_dispatch(), if opaque_len < sizeof(int) you leave idx as 0 and will dispatch slot 0 even when the opaque payload is missing/invalid. This can silently call the wrong kernel. Consider treating short/invalid opaque as an error and returning early (and ideally validate opaque_len == sizeof(int) since you control the encoding).
Suggested change:
```diff
-int idx = 0;
-if (opaque_len >= sizeof(int))
-    memcpy(&idx, opaque, sizeof(int));
+/* Expect opaque to contain exactly one int index. Treat mismatched
+ * or missing opaque as an error to avoid dispatching the wrong slot.
+ */
+if (opaque == NULL || opaque_len != sizeof(int))
+    return;
+int idx;
+memcpy(&idx, opaque, sizeof(int));
```
```c
for (int i = 0; i < ns; i++) {
    storage[nb + i] = (void*)t->scalar_vals[i];
    packed[nb + i] = &storage[nb + i];
}
```
Scalar packing currently does storage[nb + i] = (void*)t->scalar_vals[i]; which relies on casting integers to pointers and on the callee interpreting the in-memory representation of a void* as an integer scalar. This is non-portable/undefined behavior and can break for large/negative values or non-integer scalars. Prefer storing scalars in a dedicated int64_t scalar_storage[] array on the stack (or similar) and set packed[...] to point at the scalar bytes (matching how FlyDSL packs scalar args via ctypes).
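The suggested fix can be mirrored from Python with `ctypes`: keep scalar values in a dedicated `int64` array and point the packed argument slots at those entries, instead of reinterpreting pointer bits as integers. This is a hypothetical sketch (`pack_args` is not a FlyDSL function), but it follows the same shape as packing scalar args via ctypes:

```python
import ctypes


def pack_args(buffer_ptrs, scalar_vals):
    """Build a void** argument array: buffer pointers first, then pointers
    into a dedicated int64 scalar storage array (no int -> void* casts).

    Returns (packed, storage); storage must outlive any use of packed.
    """
    n = len(buffer_ptrs) + len(scalar_vals)
    storage = (ctypes.c_int64 * len(scalar_vals))(*scalar_vals)
    packed = (ctypes.c_void_p * n)()
    for i, p in enumerate(buffer_ptrs):
        packed[i] = p  # raw device/host buffer addresses pass through as-is
    base = ctypes.addressof(storage)
    for i in range(len(scalar_vals)):
        # Each packed slot points at the scalar's bytes, so the callee
        # dereferences a real int64, not a reinterpreted pointer value.
        packed[len(buffer_ptrs) + i] = base + i * ctypes.sizeof(ctypes.c_int64)
    return packed, storage
```

The key property is that negative or large scalars round-trip exactly, because the callee reads an `int64_t` object rather than the in-memory representation of a `void*`.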
tests/test_jax_vecadd.py (Outdated)
```python
import jax
import jax.numpy as jnp
import numpy as np

import flydsl.compiler as flyc
import flydsl.expr as fx
from flydsl.jax import from_jax
```
This file imports jax unconditionally at module import time. Because it’s under tests/ and named test_*.py, pytest will import it during collection, causing the whole test run to error in environments where JAX isn’t installed (even if JAX support is meant to be optional). Add a try/except ImportError + pytest.skip(..., allow_module_level=True) guard (like tests/unit/test_jax_integration.py), or move/rename it so it’s not collected by pytest when JAX is absent.
Suggested change:
```diff
-import jax
-import jax.numpy as jnp
 import numpy as np
+
 import flydsl.compiler as flyc
 import flydsl.expr as fx
-from flydsl.jax import from_jax
+
+try:
+    import jax
+    import jax.numpy as jnp
+    from flydsl.jax import from_jax
+except ImportError:
+    import pytest
+    pytest.skip("JAX not installed; skipping JAX vecAdd test.", allow_module_level=True)
```
Force-pushed from a9976b0 to a58e87a
Add two levels of JAX integration so FlyDSL GPU kernels can be called
from JAX code on AMD GPUs:
Level 1 -- Eager mode (from_jax):
Wrap jax.Array objects via DLPack so they can be passed directly to
@flyc.jit functions. Works with all dtypes including float8 variants.
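The Level 1 DLPack handoff can be illustrated with NumPy standing in for `jax.Array` (a hedged sketch of the protocol itself, not FlyDSL's adapter code; both array types export `__dlpack__`, which is what makes the zero-copy wrap possible):

```python
import numpy as np

# NumPy stands in for jax.Array here: both implement the DLPack protocol,
# so a consumer can wrap the producer's buffer without copying.
a = np.arange(4, dtype=np.float32)
view = np.from_dlpack(a)  # zero-copy view over a's buffer (NumPy >= 1.22)

# The two arrays alias the same memory: no data was duplicated.
assert np.shares_memory(a, view)
```

In FlyDSL's case the consumer side is the adapter turning the exported capsule into the bare pointers a `@flyc.jit` kernel expects.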
Level 2 -- jax.jit integration (jax_kernel):
Register compiled FlyDSL kernels as XLA custom-call targets via a C
trampoline that bridges XLA's GPU calling convention (stream, buffers,
opaque) to FlyDSL's bare-pointer convention. Supports tensor buffers,
baked scalar arguments, and compile-time constants.
New package python/flydsl/jax/:
- adapter.py: JaxTensorAdaptor with DLPack, float8, layout dynamism
- primitive.py: JAX primitive with abstract eval, eager impl, and
StableHLO CustomCallOp lowering for rocm/gpu platforms
- ffi_bridge.py: Compilation, XLA target registration, opaque encoding
- _xla_bridge.c: Thread-safe C trampoline with scalar insertion
Also makes PyTorch an optional dependency by guarding the torch import
in jit_argument.py behind try/except. The JitArgumentRegistry, compiler
decorators, and JAX integration all work without torch installed.
Adds JitFunction.get_last_artifact() public API for external integrations
to retrieve compiled kernels without accessing private attributes.
Tested on MI300X with JAX 0.8.2, all examples produce correct results.
Signed-off-by: Wen Chen <Wen.Chen@amd.com>
JAX equivalents of examples 01-03, demonstrating both eager mode (from_jax + @flyc.jit) and jax.jit mode (jax_kernel + XLA custom call):
- 04-vectorAdd-jax.py: Vector addition with layout algebra
- 05-tiledCopy-jax.py: Tiled copy with partitioned tensors
- 06-tiledMma-jax.py: Single-tile GEMM (64x64x8) using MFMA instructions

All produce correct results on MI300X with JAX 0.8.2.

Signed-off-by: Wen Chen <Wen.Chen@amd.com>
30 unit tests in tests/unit/test_jax_integration.py:
- Adapter: 18 tests (dtypes, shapes, fly_types, fly_ptrs, cache signature, layout dynamic, alignment, error cases, float8)
- Primitive: 5 tests (eager single/multi output, constexpr/scalar forwarding, abstract eval)
- C trampoline: 5 tests (dispatch, scalar insertion, slot allocation, bounds rejection)
- Registration dedup: 2 tests (name hashing stability)

End-to-end test in tests/test_jax_vecadd.py:
- 102,400-element vectorized add using JAX arrays, max error = 0.00

All 30 unit tests pass in ~4s on JAX 0.8.2 without torch.

Signed-off-by: Wen Chen <Wen.Chen@amd.com>
1. _xla_bridge.c: Populate slot data inside the mutex-protected section so concurrent dispatchers never see uninitialized data (Comment 5). Add strict opaque validation: return early if NULL or wrong size instead of silently defaulting to slot 0 (Comment 6). Use a dedicated int64_t scalar_storage[] array instead of casting integers to void* pointers, avoiding undefined behavior (Comment 7).
2. ffi_bridge.py: Hold the lock through the full compile+register path to prevent concurrent threads from duplicating work (Comment 4). Use 'cc' instead of 'gcc' for portability, add the -lpthread flag (Comment 3). Clarify api_version cross-references between the XLA registration API (0=untyped) and StableHLO CustomCallOp (2=STATUS_RETURNING_UNIFIED), which refer to the same calling convention (Comments 1+2).
3. test_jax_vecadd.py: Add a try/except ImportError guard with pytest.skip so the test suite doesn't break when JAX is absent (Comment 8).

Signed-off-by: Wen Chen <Wen.Chen@amd.com>
Fix the jax.jit path when FLYDSL_RUNTIME_ENABLE_CACHE is enabled:
- Force fresh compilation in compile_and_register by temporarily clearing JitFunction's _mem_cache and _call_state_cache, preventing stale artifacts from a prior eager call (with different adaptor options like mark_layout_dynamic) from being reused
- Set _last_compiled on cache-hit paths so get_last_artifact() always returns the correct artifact
- Add auto-rebuild of _xla_bridge.so when the .c source is newer

Address all 8 Copilot review comments:
1-2. Clarify api_version cross-references (registration=0, StableHLO=2)
3. Use 'cc' instead of 'gcc', add -lpthread
4. Hold lock through full compile+register path
5. Populate C trampoline slot inside mutex
6. Strict opaque validation in dispatch
7. Use int64_t scalar_storage[] instead of void* cast
8. Add pytest.skip guard in test_jax_vecadd.py

Known limitation: the jax.jit path for kernels using make_buffer_tensor (e.g. 04-vectorAdd-jax.py) requires FLYDSL_RUNTIME_ENABLE_CACHE=0 when the same kernel is also called via the eager path with mark_layout_dynamic in the same process. This is due to cache key collisions in JitFunction, where mark_layout_dynamic changes the MLIR type but not the cache signature. The 05-tiledCopy and 06-tiledMma examples work correctly with caching enabled.

Signed-off-by: Wen Chen <Wen.Chen@amd.com>
mark_layout_dynamic changes the MLIR memref type (adding dynamic dimensions) but previously did not affect __cache_signature__, causing cache key collisions between dynamic and static layouts. This led to the wrong compiled artifact being reused when the same kernel was called with and without mark_layout_dynamic in the same process.

Fix: Track _dynamic_leading_dim and _dynamic_divisibility on both TensorAdaptor (torch) and JaxTensorAdaptor (JAX), and include them in __cache_signature__. Adaptors without mark_layout_dynamic have None values, which don't collide with any (int, int) pair.

Also removes the cache-clearing workaround from ffi_bridge.py since the root cause is now fixed. All examples pass with JIT caching enabled (no FLYDSL_RUNTIME_ENABLE_CACHE=0 needed).

Signed-off-by: Wen Chen <Wen.Chen@amd.com>
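The None-vs-(int, int) non-collision this commit relies on can be shown in a minimal sketch. The class and field names below mirror the commit message but are illustrative, not the real FlyDSL adaptor classes:

```python
class AdaptorSig:
    """Minimal sketch of a cache signature carrying optional
    dynamic-layout fields (illustrative, not FlyDSL's real classes)."""

    def __init__(self, dtype, dynamic_leading_dim=None, dynamic_divisibility=None):
        self.dtype = dtype
        self._dynamic_leading_dim = dynamic_leading_dim
        self._dynamic_divisibility = dynamic_divisibility

    def __cache_signature__(self):
        # None never equals any int, so a static-layout adaptor can never
        # share a cache key with a dynamic-layout one for the same dtype.
        return (self.dtype, self._dynamic_leading_dim, self._dynamic_divisibility)
```

With the dynamic fields in the tuple, a kernel compiled for a dynamic layout is keyed separately from its static-layout sibling, which is exactly the collision the commit fixes.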
- Add a JAX ROCm installation step (jax, jaxlib, jax-rocm7-pjrt, jax-rocm7-plugin) to the test-whl.yaml CI workflow, marked continue-on-error so JAX installation failures don't block the rest of the test suite.
- Add an import guard to all three JAX examples (04, 05, 06): if JAX is not installed, print "SKIP: JAX not installed" and exit 0.
- Update run_tests.sh to detect "SKIP:" output from examples and report them as SKIP instead of PASS or FAIL.

Signed-off-by: Wen Chen <Wen.Chen@amd.com>
- ffi_bridge.py: remove unused `import os`
- primitive.py: remove unused `Union` from typing imports

Signed-off-by: Wen Chen <Wen.Chen@amd.com>
Resolve conflict in jit_argument.py:
- Keep the torch-optional structure (try/except import torch)
- Incorporate main's cache key refactor: add raw_cache_signature(), remove shape/strides from __cache_signature__ (now handled by JitFunction)
- Keep the dynamic layout fields in the cache signature

Update the JAX adapter cache signature to match (remove shape/strides). Update the test to expect the same cache signature for same-dtype, different-shape inputs.

Note: requires a rebuild (bash scripts/build.sh) for the BufferLDST -> BufferCopy rename in the C++ bindings.

Signed-off-by: Wen Chen <Wen.Chen@amd.com>
Force-pushed from a58e87a to 3625a6d
@wenchenvincent Is this for packing FlyDSL into the JAX jit system and as an XLA backend? FlyDSL is still in fast development and not yet ready to be part of XLA; many APIs are still in flux.

@coderfeli It is not meant to make FlyDSL part of XLA, but to provide a thin integration layer so that users can call FlyDSL kernels as custom functions from JAX.
Motivation
FlyDSL currently only supports PyTorch tensors. This PR adds JAX support with two levels of integration:
- Eager mode (from_jax): wrap JAX arrays and pass them directly to @flyc.jit functions
- jax.jit mode (jax_kernel): call FlyDSL kernels inside jax.jit via XLA custom calls

PyTorch is also made an optional dependency: FlyDSL can now be imported and used without torch.
Technical Details
New package python/flydsl/jax/:
- adapter.py: JaxTensorAdaptor wrapping jax.Array via DLPack, all dtypes including float8
- primitive.py: JAX primitive with abstract eval, eager impl, and StableHLO CustomCallOp lowering
- ffi_bridge.py: compiles kernels and registers them as XLA custom-call targets
- _xla_bridge.c: thread-safe C trampoline bridging XLA's GPU calling convention to FlyDSL's bare-pointer convention, with baked scalar argument support

Modified files:
- compiler/jit_argument.py: torch import guarded behind try/except
- compiler/jit_function.py: added get_last_artifact() public API for external integrations

Examples: JAX versions of vectorAdd, tiledCopy, and tiledMma (both eager and jax.jit)
Test Plan
Test Result
All tests pass on MI300X, JAX 0.8.2. All examples produce correct results (max diff: 0.00e+00).
Submission Checklist