Skip to content

Commit

Permalink
Merge branch 'lwawrzyniak/jax-ffi-update2' into 'main'
Browse files Browse the repository at this point in the history
Update jax_callable

See merge request omniverse/warp!1089
  • Loading branch information
nvlukasz committed Feb 14, 2025
2 parents 05db5a3 + 26bdfe6 commit cde8dea
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 33 deletions.
69 changes: 37 additions & 32 deletions warp/jax_experimental/ffi.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,8 +290,8 @@ def register_jax_kernel_callback():
ffi_callbacks = {}


def jax_callable(func, num_outputs=1, vmap_method="broadcast_all"):
return FfiCallable(func, num_outputs, vmap_method)
def jax_callable(func, num_outputs=1, vmap_method="broadcast_all", graph_compatible=True):
return FfiCallable(func, num_outputs, vmap_method, graph_compatible)


class FfiArg:
Expand Down Expand Up @@ -322,21 +322,21 @@ def __init__(self, name, type):
raise TypeError(f"Invalid type for argument '{name}', expected array or scalar, got {type}")


class FfiCall:
def __init__(self, callable, static_inputs):
self.callable = callable
class FfiCallDesc:
def __init__(self, static_inputs):
self.static_inputs = static_inputs


class FfiCallable:
call_id = 0
call_descriptors = {}

def __init__(self, func, num_outputs, vmap_method):
def __init__(self, func, num_outputs, vmap_method, graph_compatible):
self.func = func
self.name = make_full_qualified_name(func)
self.num_outputs = num_outputs
self.vmap_method = vmap_method
self.graph_compatible = graph_compatible
self.has_static_args = False
self.call_id = 0
self.call_descriptors = {}

# get arguments and annotations
argspec = get_full_arg_spec(func)
Expand All @@ -356,14 +356,17 @@ def __init__(self, func, num_outputs, vmap_method):
if arg_type is not None:
raise TypeError("Function must not return a value")
else:
self.args.append(FfiArg(arg_name, arg_type))
arg = FfiArg(arg_name, arg_type)
if not arg.is_array:
self.has_static_args = True
self.args.append(arg)

self.input_args = self.args[: self.num_inputs]
self.output_args = self.args[self.num_inputs :]

# register the callback
FFI_CCALLFUNC = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.POINTER(XLA_FFI_CallFrame))
callback_func = FFI_CCALLFUNC(FfiCallable.ffi_callback)
callback_func = FFI_CCALLFUNC(lambda call_frame: self.ffi_callback(call_frame))
ffi_callbacks[self.name] = callback_func
ffi_ccall_address = ctypes.cast(callback_func, ctypes.c_void_p)
ffi_capsule = jax.ffi.pycapsule(ffi_ccall_address.value)
Expand Down Expand Up @@ -443,15 +446,16 @@ def __call__(self, *args, output_dims=None, vmap_method=None):
module = wp.get_module(self.func.__module__)
module.load(device)

# save call data to be retrieved by callback
call_id = FfiCallable.call_id
FfiCallable.call_descriptors[call_id] = FfiCall(self, static_inputs)
FfiCallable.call_id += 1

return call(*args, call_id=call_id)
if self.has_static_args:
# save call data to be retrieved by callback
call_id = self.call_id
self.call_descriptors[call_id] = FfiCallDesc(static_inputs)
self.call_id += 1
return call(*args, call_id=call_id)
else:
return call(*args)

@staticmethod
def ffi_callback(call_frame):
def ffi_callback(self, call_frame):
try:
# TODO Try-catch around the body and return XLA_FFI_Error on error.
extension = call_frame.contents.extension_start
Expand All @@ -465,25 +469,26 @@ def ffi_callback(call_frame):
metadata_ext.contents.metadata.contents.api_version.major_version = 0
metadata_ext.contents.metadata.contents.api_version.minor_version = 1
# Turn on CUDA graphs for this handler.
metadata_ext.contents.metadata.contents.traits = (
XLA_FFI_Handler_TraitsBits.COMMAND_BUFFER_COMPATIBLE
)
if self.graph_compatible:
metadata_ext.contents.metadata.contents.traits = (
XLA_FFI_Handler_TraitsBits.COMMAND_BUFFER_COMPATIBLE
)
return None

attrs = decode_attrs(call_frame.contents.attrs)

# retrieve call info
call_id = int(attrs["call_id"])
call_desc = FfiCallable.call_descriptors[call_id]
if self.has_static_args:
# retrieve call info
attrs = decode_attrs(call_frame.contents.attrs)
call_id = int(attrs["call_id"])
call_desc = self.call_descriptors[call_id]

num_inputs = call_frame.contents.args.size
inputs = ctypes.cast(call_frame.contents.args.args, ctypes.POINTER(ctypes.POINTER(XLA_FFI_Buffer)))

num_outputs = call_frame.contents.rets.size
outputs = ctypes.cast(call_frame.contents.rets.rets, ctypes.POINTER(ctypes.POINTER(XLA_FFI_Buffer)))

assert num_inputs == call_desc.callable.num_inputs
assert num_outputs == call_desc.callable.num_outputs
assert num_inputs == self.num_inputs
assert num_outputs == self.num_outputs

device = wp.device_from_jax(get_jax_device())
cuda_stream = get_stream_from_callframe(call_frame.contents)
Expand All @@ -494,7 +499,7 @@ def ffi_callback(call_frame):

# inputs
for i in range(num_inputs):
arg = call_desc.callable.input_args[i]
arg = self.input_args[i]
if arg.is_array:
buffer = inputs[i].contents
shape = buffer.dims[: buffer.rank - arg.dtype_ndim]
Expand All @@ -507,15 +512,15 @@ def ffi_callback(call_frame):

# outputs
for i in range(num_outputs):
arg = call_desc.callable.output_args[i]
arg = self.output_args[i]
buffer = outputs[i].contents
shape = buffer.dims[: buffer.rank - arg.dtype_ndim]
arr = wp.array(ptr=buffer.data, dtype=arg.type.dtype, shape=shape, device=device)
arg_list.append(arr)

# call the Python function with reconstructed arguments
with wp.ScopedStream(stream, sync_enter=False):
call_desc.callable.func(*arg_list)
self.func(*arg_list)

except Exception as e:
print(traceback.format_exc())
Expand Down
2 changes: 1 addition & 1 deletion warp/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1384,7 +1384,7 @@ def type_is_transformation(t):
return getattr(t, "_wp_generic_type_hint_", None) is Transformation


value_types = (int, float, builtins.bool) + scalar_types
value_types = (int, float, builtins.bool) + scalar_and_bool_types


# returns true for all value types (int, float, bool, scalars, vectors, matrices)
Expand Down

0 comments on commit cde8dea

Please sign in to comment.