Help with upgrading legacy custom calls to FFI #33405
Hello! It seems like you're very close. You might try something like:

    operands = (
        a, b, c, d, system_depths,
        as_mhlo_constant(num_systems, np.int64),
        as_mhlo_constant(stride, np.int64),
    )
    out = build_ffi_lowering_function(kernel, result_types=out_types)(ctx, *operands)

I'm not positive that'll work out of the box since … For reference, you can take a look at …
Good call re … Let's see; one thing that jumps out to me is when you call … Another possibility is …
I maintain a few projects that extend JAX via custom Cython code, including Veros and mpi4jax. These projects have seen test failures when upgrading to JAX 0.8.0, since jax.interpreters.mlir.custom_call has now been removed.

I've tried upgrading to jax.ffi but am getting stuck on a few things. Simply replacing custom_call with build_ffi_lowering_function doesn't seem to do the trick (see traceback below).
For the sake of this discussion, you can assume that I want to upgrade this file. Here's what I tried so far.
Happy for any pointers!