[Feature] Add JAX integration for FlyDSL kernels#257

Open
wenchenvincent wants to merge 9 commits into ROCm:main from wenchenvincent:feat/jax-integration

Conversation

@wenchenvincent

Motivation

FlyDSL currently only supports PyTorch tensors. This PR adds JAX support with two levels of integration:

  • Eager mode (from_jax): wrap JAX arrays and pass them directly to @flyc.jit functions
  • jax.jit mode (jax_kernel): call FlyDSL kernels inside jax.jit via XLA custom calls

PyTorch is also made an optional dependency — FlyDSL can now be imported and used without torch.

Technical Details

New package python/flydsl/jax/:

  • adapter.py — JaxTensorAdaptor wrapping jax.Array via DLPack; supports all dtypes, including float8
  • primitive.py — JAX primitive with abstract eval, eager impl, and StableHLO CustomCallOp lowering
  • ffi_bridge.py — Compiles kernels and registers them as XLA custom-call targets
  • _xla_bridge.c — Thread-safe C trampoline bridging XLA's GPU calling convention to FlyDSL's bare-pointer convention, with baked scalar argument support

Modified files:

  • compiler/jit_argument.py — torch import guarded behind try/except
  • compiler/jit_function.py — Added get_last_artifact() public API for external integrations

Examples: JAX versions of vectorAdd, tiledCopy, and tiledMma (both eager and jax.jit)

Test Plan

  • 30 unit tests covering adapter, primitive, C trampoline, and registration dedup
  • End-to-end integration test (102K-element vectorized add)
  • All 3 examples verified with both eager and jax.jit paths
  • Existing unit tests unaffected (34 passed, 2 pre-existing skips)

Test Result

All tests pass on MI300X, JAX 0.8.2. All examples produce correct results (max diff: 0.00e+00).

Submission Checklist

Copilot AI review requested due to automatic review settings March 21, 2026 08:34
Contributor

Copilot AI left a comment


Pull request overview

Adds a first-class JAX integration layer for FlyDSL, enabling both eager execution on jax.Array inputs and jax.jit execution via XLA custom calls, while also making PyTorch an optional dependency for importing FlyDSL.

Changes:

  • Introduces python/flydsl/jax/ (adapter, primitive lowering, FFI bridge, and C trampoline) to run FlyDSL kernels from JAX (eager + jax.jit).
  • Makes PyTorch optional by guarding torch-specific registrations in jit_argument.py.
  • Exposes JitFunction.get_last_artifact() to allow external integrations (JAX bridge) to retrieve compiled artifacts.

Reviewed changes

Copilot reviewed 13 out of 13 changed files in this pull request and generated 8 comments.

Show a summary per file
File Description
tests/unit/test_jax_integration.py Adds unit tests for adapter/primitive/bridge behavior (skips if JAX missing).
tests/test_jax_vecadd.py Adds a JAX vector-add integration script/test (currently unguarded JAX import).
python/flydsl/jax/__init__.py Public entrypoints for from_jax, jax_kernel, and lazy wrappers.
python/flydsl/jax/adapter.py Implements JaxTensorAdaptor via DLPack for eager-mode @flyc.jit calls.
python/flydsl/jax/primitive.py Implements a JAX primitive + StableHLO CustomCallOp lowering for jax.jit.
python/flydsl/jax/ffi_bridge.py Compiles/registers FlyDSL kernels as XLA custom call targets + loads/builds trampoline.
python/flydsl/jax/_xla_bridge.c C trampoline bridging XLA GPU custom-call ABI to FlyDSL ptr-packing ABI.
python/flydsl/compiler/jit_function.py Adds get_last_artifact() and tracks last compilation result.
python/flydsl/compiler/jit_argument.py Makes torch optional; keeps PyTorch tensor support when installed.
python/flydsl/compiler/__init__.py Convenience re-export flydsl.compiler.from_jax.
examples/04-vectorAdd-jax.py JAX version of vector add (eager + jax.jit).
examples/05-tiledCopy-jax.py JAX version of tiled copy (eager + jax.jit).
examples/06-tiledMma-jax.py JAX version of tiled MMA (eager + jax.jit).


Comment on lines +203 to +214
# 0 = API_VERSION_ORIGINAL (CPU: fn(out, ins))
# 1 = API_VERSION_STATUS_RETURNING (CPU: fn(out, ins, status))
# 2 = API_VERSION_STATUS_RETURNING_UNIFIED (GPU: fn(stream, buffers, opaque, opaque_len))
# 4 = API_VERSION_TYPED_FFI
# We use 2 for GPU custom calls with the old untyped convention.
# backend_config carries the opaque bytes (slot index for the C trampoline).
i32_type = jax_ir.IntegerType.get_signless(32)
call = stablehlo.CustomCallOp(
    result_types,
    list(args),
    call_target_name=target_name,
    api_version=jax_ir.IntegerAttr.get(i32_type, 2),

Copilot AI Mar 21, 2026


stablehlo.CustomCallOp is emitted with api_version=2, but the target is registered in ffi_bridge._register_with_xla() using api_version=0. XLA/JAX expects these API versions to match; otherwise the runtime will call the bridge with a different calling convention than the C trampoline implements. Align the api_version used in both places (and update the docstring/comments accordingly).

Suggested change — before:

    # 0 = API_VERSION_ORIGINAL (CPU: fn(out, ins))
    # 1 = API_VERSION_STATUS_RETURNING (CPU: fn(out, ins, status))
    # 2 = API_VERSION_STATUS_RETURNING_UNIFIED (GPU: fn(stream, buffers, opaque, opaque_len))
    # 4 = API_VERSION_TYPED_FFI
    # We use 2 for GPU custom calls with the old untyped convention.
    # backend_config carries the opaque bytes (slot index for the C trampoline).
    i32_type = jax_ir.IntegerType.get_signless(32)
    call = stablehlo.CustomCallOp(
        result_types,
        list(args),
        call_target_name=target_name,
        api_version=jax_ir.IntegerAttr.get(i32_type, 2),

Suggested change — after:

    # 0 = API_VERSION_ORIGINAL (fn(out, ins))
    # 1 = API_VERSION_STATUS_RETURNING (fn(out, ins, status))
    # 2 = API_VERSION_STATUS_RETURNING_UNIFIED (GPU: fn(stream, buffers, opaque, opaque_len))
    # 4 = API_VERSION_TYPED_FFI
    # We use 0 here to match the api_version used in ffi_bridge._register_with_xla().
    # backend_config carries the opaque bytes (slot index for the C trampoline).
    i32_type = jax_ir.IntegerType.get_signless(32)
    call = stablehlo.CustomCallOp(
        result_types,
        list(args),
        call_target_name=target_name,
        api_version=jax_ir.IntegerAttr.get(i32_type, 0),

Comment on lines +262 to +266
# api_version=0: old custom-call convention
# void fn(stream, void** buffers, const char* opaque, size_t opaque_len)
_xla_client.register_custom_call_target(
    target.name, capsule, xla_platform_name, api_version=0,
)

Copilot AI Mar 21, 2026


_register_with_xla() registers the custom call target with api_version=0, but the lowering in primitive.py sets api_version=2 on the StableHLO CustomCallOp. This mismatch will typically lead to the wrong runtime ABI being used for the callback. Update the registration to use the same API version as the emitted CustomCallOp (and keep the inline comments/docstring consistent).

Comment on lines +59 to +71
def _ensure_bridge_lib() -> ctypes.CDLL:
    """Load ``_xla_bridge.so``, compiling from source if necessary."""
    if not _BRIDGE_SO.exists():
        if not _BRIDGE_C.exists():
            raise FileNotFoundError(
                f"Cannot find XLA bridge source: {_BRIDGE_C}\n"
                f"Please rebuild or reinstall flydsl."
            )
        subprocess.check_call(
            ["gcc", "-shared", "-fPIC", "-O2", "-o", str(_BRIDGE_SO), str(_BRIDGE_C)],
            cwd=str(_THIS_DIR),
        )
    lib = ctypes.CDLL(str(_BRIDGE_SO))

Copilot AI Mar 21, 2026


Building _xla_bridge.so at import time via a hard-coded gcc invocation is brittle operationally (requires a compiler at runtime, write permissions to the installed package dir, and fails on non-GNU toolchains). If you do keep this fallback, add the proper thread library flags (-pthread / -lpthread) since _xla_bridge.c uses pthreads, and consider moving compilation to the package build step or caching in a writable user cache dir instead of the source tree.

Comment on lines +172 to +175
with _lock:
    if target_name in _registered_targets:
        return target_name

Copilot AI Mar 21, 2026


compile_and_register() does a membership check under _lock and then releases the lock before compilation/registration. Two threads lowering the same shapes can race and both compile + register the same target_name, potentially causing duplicate registrations or nondeterministic behavior. Consider holding the lock through the whole register path, or store an “in-progress” sentinel (e.g., a threading.Event/Future) so only one thread compiles while others wait.

Comment on lines +51 to +57
pthread_mutex_unlock(&g_lock);

g_targets[idx].func = (flydsl_func_t)func_ptr;
g_targets[idx].n_buffers = n_buffers;
g_targets[idx].n_scalars = n_scalars;
for (int i = 0; i < n_scalars; i++)
    g_targets[idx].scalar_vals[i] = scalar_vals ? scalar_vals[i] : 0;

Copilot AI Mar 21, 2026


flydsl_xla_register() increments g_n_targets under the mutex, then releases the lock and only afterwards populates g_targets[idx]. xla_bridge_dispatch() can run concurrently, observe idx < g_n_targets, and read a partially initialized slot (including a NULL/garbage func pointer). Populate the slot while still holding g_lock (or use an atomic publish pattern) to make registration + dispatch thread-safe.

Suggested change — before:

    pthread_mutex_unlock(&g_lock);

    g_targets[idx].func = (flydsl_func_t)func_ptr;
    g_targets[idx].n_buffers = n_buffers;
    g_targets[idx].n_scalars = n_scalars;
    for (int i = 0; i < n_scalars; i++)
        g_targets[idx].scalar_vals[i] = scalar_vals ? scalar_vals[i] : 0;

Suggested change — after:

    g_targets[idx].func = (flydsl_func_t)func_ptr;
    g_targets[idx].n_buffers = n_buffers;
    g_targets[idx].n_scalars = n_scalars;
    for (int i = 0; i < n_scalars; i++)
        g_targets[idx].scalar_vals[i] = scalar_vals ? scalar_vals[i] : 0;
    pthread_mutex_unlock(&g_lock);
Comment on lines +63 to +66
int idx = 0;
if (opaque_len >= sizeof(int))
    memcpy(&idx, opaque, sizeof(int));

Copilot AI Mar 21, 2026


In xla_bridge_dispatch(), if opaque_len < sizeof(int) you leave idx as 0 and will dispatch slot 0 even when the opaque payload is missing/invalid. This can silently call the wrong kernel. Consider treating short/invalid opaque as an error and returning early (and ideally validate opaque_len == sizeof(int) since you control the encoding).

Suggested change — before:

    int idx = 0;
    if (opaque_len >= sizeof(int))
        memcpy(&idx, opaque, sizeof(int));

Suggested change — after:

    /* Expect opaque to contain exactly one int index. Treat mismatched
     * or missing opaque as an error to avoid dispatching the wrong slot. */
    if (opaque == NULL || opaque_len != sizeof(int))
        return;
    int idx;
    memcpy(&idx, opaque, sizeof(int));
Comment on lines +84 to +87
for (int i = 0; i < ns; i++) {
    storage[nb + i] = (void*)t->scalar_vals[i];
    packed[nb + i] = &storage[nb + i];
}

Copilot AI Mar 21, 2026


Scalar packing currently does storage[nb + i] = (void*)t->scalar_vals[i]; which relies on casting integers to pointers and on the callee interpreting the in-memory representation of a void* as an integer scalar. This is non-portable/undefined behavior and can break for large/negative values or non-integer scalars. Prefer storing scalars in a dedicated int64_t scalar_storage[] array on the stack (or similar) and set packed[...] to point at the scalar bytes (matching how FlyDSL packs scalar args via ctypes).
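The dedicated-storage pattern the comment recommends can be demonstrated with ctypes, which is also how FlyDSL packs scalar args on the Python side. This is a sketch (`pack_args` is a hypothetical helper): scalars live in their own int64 array and the packed argument list holds pointers *to* that storage, never integers cast to pointers.

```python
import ctypes

def pack_args(buffer_ptrs, scalar_vals):
    """Pack buffer addresses followed by int64 scalars for a bare-pointer
    kernel ABI: each scalar gets real int64 storage, and packed[] holds
    the address of that storage slot."""
    nb, ns = len(buffer_ptrs), len(scalar_vals)
    scalar_storage = (ctypes.c_int64 * ns)(*scalar_vals)
    packed = (ctypes.c_void_p * (nb + ns))()
    for i, addr in enumerate(buffer_ptrs):
        packed[i] = addr                    # device buffer pointers, as-is
    step = ctypes.sizeof(ctypes.c_int64)
    for i in range(ns):
        # pointer to the scalar's bytes, not the value reinterpreted as void*
        packed[nb + i] = ctypes.addressof(scalar_storage) + i * step
    # caller must keep scalar_storage alive as long as packed is in use
    return packed, scalar_storage
```

Returning `scalar_storage` alongside `packed` keeps the backing memory alive; in the C trampoline the equivalent is a stack-local `int64_t scalar_storage[]` that outlives the kernel call.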

Comment on lines +12 to +20
import jax
import jax.numpy as jnp
import numpy as np

import flydsl.compiler as flyc
import flydsl.expr as fx
from flydsl.jax import from_jax



Copilot AI Mar 21, 2026


This file imports jax unconditionally at module import time. Because it’s under tests/ and named test_*.py, pytest will import it during collection, causing the whole test run to error in environments where JAX isn’t installed (even if JAX support is meant to be optional). Add a try/except ImportError + pytest.skip(..., allow_module_level=True) guard (like tests/unit/test_jax_integration.py), or move/rename it so it’s not collected by pytest when JAX is absent.

Suggested change — before:

    import jax
    import jax.numpy as jnp
    import numpy as np

    import flydsl.compiler as flyc
    import flydsl.expr as fx
    from flydsl.jax import from_jax

Suggested change — after:

    import numpy as np

    import flydsl.compiler as flyc
    import flydsl.expr as fx

    try:
        import jax
        import jax.numpy as jnp
        from flydsl.jax import from_jax
    except ImportError:
        import pytest
        pytest.skip("JAX not installed; skipping JAX vecAdd test.", allow_module_level=True)

Add two levels of JAX integration so FlyDSL GPU kernels can be called
from JAX code on AMD GPUs:

Level 1 -- Eager mode (from_jax):
  Wrap jax.Array objects via DLPack so they can be passed directly to
  @flyc.jit functions. Works with all dtypes including float8 variants.

Level 2 -- jax.jit integration (jax_kernel):
  Register compiled FlyDSL kernels as XLA custom-call targets via a C
  trampoline that bridges XLA's GPU calling convention (stream, buffers,
  opaque) to FlyDSL's bare-pointer convention. Supports tensor buffers,
  baked scalar arguments, and compile-time constants.

New package python/flydsl/jax/:
  - adapter.py: JaxTensorAdaptor with DLPack, float8, layout dynamism
  - primitive.py: JAX primitive with abstract eval, eager impl, and
    StableHLO CustomCallOp lowering for rocm/gpu platforms
  - ffi_bridge.py: Compilation, XLA target registration, opaque encoding
  - _xla_bridge.c: Thread-safe C trampoline with scalar insertion

Also makes PyTorch an optional dependency by guarding the torch import
in jit_argument.py behind try/except. The JitArgumentRegistry, compiler
decorators, and JAX integration all work without torch installed.
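The guard pattern described above can be sketched as follows. This is a minimal illustration of the try/except approach, not the actual jit_argument.py code; `HAS_TORCH` and `register_torch_adaptors` are hypothetical names:

```python
# Optional heavy dependency: import torch if available, otherwise degrade
# gracefully so the rest of the package still imports and works.
try:
    import torch
    HAS_TORCH = True
except ImportError:
    torch = None
    HAS_TORCH = False

def register_torch_adaptors(registry: dict) -> None:
    """Only register torch tensor handling when torch is importable;
    JAX-only environments skip this path entirely."""
    if not HAS_TORCH:
        return
    registry["torch.Tensor"] = "TensorAdaptor"  # placeholder registration
```

Code elsewhere then branches on `HAS_TORCH` rather than importing torch at module top level, which is what lets the JAX integration run in torch-free environments.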

Adds JitFunction.get_last_artifact() public API for external integrations
to retrieve compiled kernels without accessing private attributes.

Tested on MI300X with JAX 0.8.2, all examples produce correct results.

Signed-off-by: Wen Chen <Wen.Chen@amd.com>
JAX equivalents of examples 01-03, demonstrating both eager mode
(from_jax + @flyc.jit) and jax.jit mode (jax_kernel + XLA custom call):

- 04-vectorAdd-jax.py: Vector addition with layout algebra
- 05-tiledCopy-jax.py: Tiled copy with partitioned tensors
- 06-tiledMma-jax.py: Single-tile GEMM (64x64x8) using MFMA instructions

All produce correct results on MI300X with JAX 0.8.2.

Signed-off-by: Wen Chen <Wen.Chen@amd.com>
30 unit tests in tests/unit/test_jax_integration.py:
- Adapter: 18 tests (dtypes, shapes, fly_types, fly_ptrs, cache
  signature, layout dynamic, alignment, error cases, float8)
- Primitive: 5 tests (eager single/multi output, constexpr/scalar
  forwarding, abstract eval)
- C trampoline: 5 tests (dispatch, scalar insertion, slot allocation,
  bounds rejection)
- Registration dedup: 2 tests (name hashing stability)

End-to-end test in tests/test_jax_vecadd.py:
- 102,400-element vectorized add using JAX arrays, max error = 0.00

All 30 unit tests pass in ~4s on JAX 0.8.2 without torch.

Signed-off-by: Wen Chen <Wen.Chen@amd.com>
1. _xla_bridge.c: Populate slot data inside the mutex-protected section
   so concurrent dispatchers never see uninitialized data (Comment 5).
   Add strict opaque validation -- return early if NULL or wrong size
   instead of silently defaulting to slot 0 (Comment 6). Use dedicated
   int64_t scalar_storage[] array instead of casting integers to void*
   pointers, avoiding undefined behavior (Comment 7).

2. ffi_bridge.py: Hold the lock through the full compile+register path
   to prevent concurrent threads from duplicating work (Comment 4).
   Use 'cc' instead of 'gcc' for portability, add -lpthread flag
   (Comment 3). Clarify api_version cross-references between the XLA
   registration API (0=untyped) and StableHLO CustomCallOp
   (2=STATUS_RETURNING_UNIFIED) which refer to the same calling
   convention (Comments 1+2).

3. test_jax_vecadd.py: Add try/except ImportError guard with
   pytest.skip so the test suite doesn't break when JAX is absent
   (Comment 8).

Signed-off-by: Wen Chen <Wen.Chen@amd.com>
Signed-off-by: Wen Chen <Wen.Chen@amd.com>
Fix the jax.jit path when FLYDSL_RUNTIME_ENABLE_CACHE is enabled:
- Force fresh compilation in compile_and_register by temporarily
  clearing JitFunction's _mem_cache and _call_state_cache, preventing
  stale artifacts from a prior eager call (with different adaptor
  options like mark_layout_dynamic) from being reused
- Set _last_compiled on cache-hit paths so get_last_artifact() always
  returns the correct artifact
- Add auto-rebuild of _xla_bridge.so when .c source is newer

Address all 8 Copilot review comments:
1-2. Clarify api_version cross-references (registration=0, StableHLO=2)
3. Use 'cc' instead of 'gcc', add -lpthread
4. Hold lock through full compile+register path
5. Populate C trampoline slot inside mutex
6. Strict opaque validation in dispatch
7. Use int64_t scalar_storage[] instead of void* cast
8. Add pytest.skip guard in test_jax_vecadd.py

Known limitation: the jax.jit path for kernels using
make_buffer_tensor (e.g. 04-vectorAdd-jax.py) requires
FLYDSL_RUNTIME_ENABLE_CACHE=0 when the same kernel is also called
via the eager path with mark_layout_dynamic in the same process.
This is due to cache key collisions in JitFunction where
mark_layout_dynamic changes the MLIR type but not the cache
signature. The 05-tiledCopy and 06-tiledMma examples work correctly
with caching enabled.

Signed-off-by: Wen Chen <Wen.Chen@amd.com>
mark_layout_dynamic changes the MLIR memref type (adding dynamic
dimensions) but previously did not affect __cache_signature__, causing
cache key collisions between dynamic and static layouts. This led to
the wrong compiled artifact being reused when the same kernel was
called with and without mark_layout_dynamic in the same process.

Fix: Track _dynamic_leading_dim and _dynamic_divisibility on both
TensorAdaptor (torch) and JaxTensorAdaptor (JAX), and include them
in __cache_signature__. Adaptors without mark_layout_dynamic have
None values which don't collide with any (int, int) pair.
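The fix described above amounts to widening the cache-signature tuple. A small sketch of the idea (class and field names follow the commit message, but the real adaptors carry more state):

```python
class AdaptorSig:
    """Sketch: fold dynamic-layout markers into the cache signature so a
    dynamic-layout adaptor can never collide with a static one.
    None (no mark_layout_dynamic) never equals an (int, int) pair."""
    def __init__(self, dtype, dynamic_leading_dim=None, dynamic_divisibility=None):
        self.dtype = dtype
        self._dynamic_leading_dim = dynamic_leading_dim
        self._dynamic_divisibility = dynamic_divisibility

    def cache_signature(self):
        # Previously only dtype-like fields were hashed; including the
        # dynamic-layout fields separates the two compilation variants.
        return (self.dtype, self._dynamic_leading_dim, self._dynamic_divisibility)
```

With this, calling the same kernel with and without mark_layout_dynamic in one process produces two distinct cache keys, so each path gets its own compiled artifact.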

Also removes the cache-clearing workaround from ffi_bridge.py since
the root cause is now fixed.

All examples pass with JIT caching enabled (no FLYDSL_RUNTIME_ENABLE_CACHE=0 needed).

Signed-off-by: Wen Chen <Wen.Chen@amd.com>
- Add JAX ROCm installation step (jax, jaxlib, jax-rocm7-pjrt,
  jax-rocm7-plugin) to test-whl.yaml CI workflow. Marked as
  continue-on-error so JAX installation failures don't block
  the rest of the test suite.

- Add import guard to all three JAX examples (04, 05, 06): if JAX
  is not installed, print "SKIP: JAX not installed" and exit 0.

- Update run_tests.sh to detect "SKIP:" output from examples and
  report them as SKIP instead of PASS or FAIL.

Signed-off-by: Wen Chen <Wen.Chen@amd.com>
- ffi_bridge.py: remove unused `import os`
- primitive.py: remove unused `Union` from typing imports

Signed-off-by: Wen Chen <Wen.Chen@amd.com>
Resolve conflict in jit_argument.py:
- Keep torch-optional structure (try/except import torch)
- Incorporate main's cache key refactor: add raw_cache_signature(),
  remove shape/strides from __cache_signature__ (now handled by
  JitFunction)
- Keep dynamic layout fields in cache signature

Update JAX adapter cache signature to match (remove shape/strides).
Update test to expect same cache sig for same-dtype different-shape.

Note: requires rebuild (bash scripts/build.sh) for BufferLDST->BufferCopy
rename in C++ bindings.

Signed-off-by: Wen Chen <Wen.Chen@amd.com>
@coderfeli
Collaborator

@wenchenvincent is this for packing flydsl into JAX jit system and as xla backend? Currently flydsl is still in fast development mode and has not been ready enough to be part of xla. Many APIs still in dev.

@wenchenvincent
Author

@wenchenvincent is this for packing flydsl into JAX jit system and as xla backend? Currently flydsl is still in fast development mode and has not been ready enough to be part of xla. Many APIs still in dev.

@coderfeli It is not to make FlyDSL part of XLA, but to provide a thin integration layer so that users can call a FlyDSL kernel as a custom function in JAX.
