From a9f9d7d3c4a0fe267d304a0001abb5f476aed7df Mon Sep 17 00:00:00 2001 From: Lukasz Wawrzyniak Date: Fri, 14 Feb 2025 10:37:51 -0500 Subject: [PATCH 1/2] Add graph_compatible argument to jax_callable() --- warp/jax_experimental/ffi.py | 69 +++++++++++++++++++----------------- 1 file changed, 37 insertions(+), 32 deletions(-) diff --git a/warp/jax_experimental/ffi.py b/warp/jax_experimental/ffi.py index b76b744e3..c62688879 100644 --- a/warp/jax_experimental/ffi.py +++ b/warp/jax_experimental/ffi.py @@ -290,8 +290,8 @@ def register_jax_kernel_callback(): ffi_callbacks = {} -def jax_callable(func, num_outputs=1, vmap_method="broadcast_all"): - return FfiCallable(func, num_outputs, vmap_method) +def jax_callable(func, num_outputs=1, vmap_method="broadcast_all", graph_compatible=True): + return FfiCallable(func, num_outputs, vmap_method, graph_compatible) class FfiArg: @@ -322,21 +322,21 @@ def __init__(self, name, type): raise TypeError(f"Invalid type for argument '{name}', expected array or scalar, got {type}") -class FfiCall: - def __init__(self, callable, static_inputs): - self.callable = callable +class FfiCallDesc: + def __init__(self, static_inputs): self.static_inputs = static_inputs class FfiCallable: - call_id = 0 - call_descriptors = {} - - def __init__(self, func, num_outputs, vmap_method): + def __init__(self, func, num_outputs, vmap_method, graph_compatible): self.func = func self.name = make_full_qualified_name(func) self.num_outputs = num_outputs self.vmap_method = vmap_method + self.graph_compatible = graph_compatible + self.has_static_args = False + self.call_id = 0 + self.call_descriptors = {} # get arguments and annotations argspec = get_full_arg_spec(func) @@ -356,14 +356,17 @@ def __init__(self, func, num_outputs, vmap_method): if arg_type is not None: raise TypeError("Function must not return a value") else: - self.args.append(FfiArg(arg_name, arg_type)) + arg = FfiArg(arg_name, arg_type) + if not arg.is_array: + self.has_static_args = True + self.args.append(arg) self.input_args = self.args[: self.num_inputs] self.output_args = self.args[self.num_inputs :] # register the callback FFI_CCALLFUNC = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.POINTER(XLA_FFI_CallFrame)) - callback_func = FFI_CCALLFUNC(FfiCallable.ffi_callback) + callback_func = FFI_CCALLFUNC(lambda call_frame: self.ffi_callback(call_frame)) ffi_callbacks[self.name] = callback_func ffi_ccall_address = ctypes.cast(callback_func, ctypes.c_void_p) ffi_capsule = jax.ffi.pycapsule(ffi_ccall_address.value) @@ -443,15 +446,16 @@ def __call__(self, *args, output_dims=None, vmap_method=None): module = wp.get_module(self.func.__module__) module.load(device) - # save call data to be retrieved by callback - call_id = FfiCallable.call_id - FfiCallable.call_descriptors[call_id] = FfiCall(self, static_inputs) - FfiCallable.call_id += 1 - - return call(*args, call_id=call_id) + if self.has_static_args: + # save call data to be retrieved by callback + call_id = self.call_id + self.call_descriptors[call_id] = FfiCallDesc(static_inputs) + self.call_id += 1 + return call(*args, call_id=call_id) + else: + return call(*args) - @staticmethod - def ffi_callback(call_frame): + def ffi_callback(self, call_frame): try: # TODO Try-catch around the body and return XLA_FFI_Error on error. extension = call_frame.contents.extension_start @@ -465,16 +469,17 @@ def ffi_callback(call_frame): metadata_ext.contents.metadata.contents.api_version.major_version = 0 metadata_ext.contents.metadata.contents.api_version.minor_version = 1 # Turn on CUDA graphs for this handler. - metadata_ext.contents.metadata.contents.traits = ( - XLA_FFI_Handler_TraitsBits.COMMAND_BUFFER_COMPATIBLE - ) + if self.graph_compatible: + metadata_ext.contents.metadata.contents.traits = ( + XLA_FFI_Handler_TraitsBits.COMMAND_BUFFER_COMPATIBLE + ) return None - attrs = decode_attrs(call_frame.contents.attrs) - - # retrieve call info - call_id = int(attrs["call_id"]) - call_desc = FfiCallable.call_descriptors[call_id] + if self.has_static_args: + # retrieve call info + attrs = decode_attrs(call_frame.contents.attrs) + call_id = int(attrs["call_id"]) + call_desc = self.call_descriptors[call_id] num_inputs = call_frame.contents.args.size inputs = ctypes.cast(call_frame.contents.args.args, ctypes.POINTER(ctypes.POINTER(XLA_FFI_Buffer))) @@ -482,8 +487,8 @@ def ffi_callback(call_frame): num_outputs = call_frame.contents.rets.size outputs = ctypes.cast(call_frame.contents.rets.rets, ctypes.POINTER(ctypes.POINTER(XLA_FFI_Buffer))) - assert num_inputs == call_desc.callable.num_inputs - assert num_outputs == call_desc.callable.num_outputs + assert num_inputs == self.num_inputs + assert num_outputs == self.num_outputs device = wp.device_from_jax(get_jax_device()) cuda_stream = get_stream_from_callframe(call_frame.contents) @@ -494,7 +499,7 @@ def ffi_callback(call_frame): # inputs for i in range(num_inputs): - arg = call_desc.callable.input_args[i] + arg = self.input_args[i] if arg.is_array: buffer = inputs[i].contents shape = buffer.dims[: buffer.rank - arg.dtype_ndim] @@ -507,7 +512,7 @@ def ffi_callback(call_frame): # outputs for i in range(num_outputs): - arg = call_desc.callable.output_args[i] + arg = self.output_args[i] buffer = outputs[i].contents shape = buffer.dims[: buffer.rank - arg.dtype_ndim] arr = wp.array(ptr=buffer.data, dtype=arg.type.dtype, shape=shape, device=device) @@ -515,7 +520,7 @@ def ffi_callback(call_frame): # call the Python function with reconstructed arguments with wp.ScopedStream(stream, sync_enter=False): - call_desc.callable.func(*arg_list) + self.func(*arg_list) except Exception as e: print(traceback.format_exc()) From 26bdfe665f79030966c4896d0e324dc1d8c2f5e9 Mon Sep 17 00:00:00 2001 From: Lukasz Wawrzyniak Date: Fri, 14 Feb 2025 10:59:00 -0500 Subject: [PATCH 2/2] Add wp.bool to value_types --- warp/types.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/warp/types.py b/warp/types.py index 7a92bcf30..b4c55482e 100644 --- a/warp/types.py +++ b/warp/types.py @@ -1384,7 +1384,7 @@ def type_is_transformation(t): return getattr(t, "_wp_generic_type_hint_", None) is Transformation -value_types = (int, float, builtins.bool) + scalar_types +value_types = (int, float, builtins.bool) + scalar_and_bool_types # returns true for all value types (int, float, bool, scalars, vectors, matrices)