Skip to content

Add support for raw_kernel_arg extension #2038

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions dpctl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
from ._sycl_platform import SyclPlatform, get_platforms, lsplatform
from ._sycl_queue import (
LocalAccessor,
RawKernelArg,
SyclKernelInvalidRangeError,
SyclKernelSubmitError,
SyclQueue,
Expand Down Expand Up @@ -106,6 +107,7 @@
"SyclQueueCreationError",
"WorkGroupMemory",
"LocalAccessor",
"RawKernelArg",
]
__all__ += [
"get_device_cached_queue",
Expand Down
12 changes: 12 additions & 0 deletions dpctl/_backend.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ cdef extern from "syclinterface/dpctl_sycl_enum_types.h":
_VOID_PTR "DPCTL_VOID_PTR",
_LOCAL_ACCESSOR "DPCTL_LOCAL_ACCESSOR",
_WORK_GROUP_MEMORY "DPCTL_WORK_GROUP_MEMORY"
_RAW_KERNEL_ARG "DPCTL_RAW_KERNEL_ARG"

ctypedef enum _queue_property_type "DPCTLQueuePropertyType":
_DEFAULT_PROPERTY "DPCTL_DEFAULT_PROPERTY"
Expand Down Expand Up @@ -571,3 +572,14 @@ cdef extern from "syclinterface/dpctl_sycl_extension_interface.h":
DPCTLSyclWorkGroupMemoryRef Ref)

cdef bint DPCTLWorkGroupMemory_Available()

cdef struct DPCTLOpaqueRawKernelArg
ctypedef DPCTLOpaqueRawKernelArg *DPCTLSyclRawKernelArgRef

cdef DPCTLSyclRawKernelArgRef DPCTLRawKernelArg_Create(void* bytes,
size_t count)

cdef void DPCTLRawKernelArg_Delete(
DPCTLSyclRawKernelArgRef Ref)

cdef bint DPCTLRawKernelArg_Available()
11 changes: 11 additions & 0 deletions dpctl/_sycl_queue.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ from libcpp cimport bool as cpp_bool
from ._backend cimport (
DPCTLSyclDeviceRef,
DPCTLSyclQueueRef,
DPCTLSyclRawKernelArgRef,
DPCTLSyclWorkGroupMemoryRef,
_arg_data_type,
)
Expand Down Expand Up @@ -115,3 +116,13 @@ cdef public api class WorkGroupMemory(_WorkGroupMemory) [
object PyWorkGroupMemoryObject, type PyWorkGroupMemoryType
]:
pass

# Base extension type that stores the opaque handle of a raw kernel argument.
# Declared ``public api`` with the C object struct name
# ``Py_RawKernelArgObject`` and the type object name ``Py_RawKernelArgType``.
cdef public api class _RawKernelArg [
object Py_RawKernelArgObject, type Py_RawKernelArgType
]:
# Opaque reference to the native raw kernel argument object.
cdef DPCTLSyclRawKernelArgRef _arg_ref

# User-facing extension type derived from ``_RawKernelArg``; it declares no
# additional C-level attributes. Declared ``public api`` with the C object
# struct name ``PyRawKernelArgObject`` and the type object name
# ``PyRawKernelArgType``.
cdef public api class RawKernelArg(_RawKernelArg) [
object PyRawKernelArgObject, type PyRawKernelArgType
]:
pass
110 changes: 110 additions & 0 deletions dpctl/_sycl_queue.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ from ._backend cimport ( # noqa: E211
DPCTLQueue_SubmitNDRange,
DPCTLQueue_SubmitRange,
DPCTLQueue_Wait,
DPCTLRawKernelArg_Available,
DPCTLRawKernelArg_Create,
DPCTLRawKernelArg_Delete,
DPCTLSyclContextRef,
DPCTLSyclDeviceSelectorRef,
DPCTLSyclEventRef,
Expand Down Expand Up @@ -364,6 +367,15 @@ cdef class _kernel_arg_type:
_arg_data_type._WORK_GROUP_MEMORY
)

@property
def dpctl_raw_kernel_arg(self):
    """Kernel argument type attribute associated with the
    ``_RAW_KERNEL_ARG`` enumerator, as built by
    ``kernel_arg_type_attribute`` for the name ``"dpctl_raw_kernel_arg"``.
    """
    cdef str attr_name = "dpctl_raw_kernel_arg"
    return kernel_arg_type_attribute(
        self._name, attr_name, _arg_data_type._RAW_KERNEL_ARG
    )


kernel_arg_type = _kernel_arg_type()

Expand Down Expand Up @@ -973,6 +985,9 @@ cdef class SyclQueue(_SyclQueue):
elif isinstance(arg, LocalAccessor):
kargs[idx] = <void*>((<LocalAccessor>arg).addressof())
kargty[idx] = _arg_data_type._LOCAL_ACCESSOR
elif isinstance(arg, RawKernelArg):
kargs[idx] = <void*>(<size_t>arg._ref)
kargty[idx] = _arg_data_type._RAW_KERNEL_ARG
else:
ret = -1
return ret
Expand Down Expand Up @@ -1738,3 +1753,98 @@ cdef class WorkGroupMemory:
"""
def __get__(self):
return <size_t>self._mem_ref


cdef class _RawKernelArg:
    """Base class owning the native raw kernel argument handle
    (``_arg_ref``) and releasing it when the Python object is destroyed.
    """
    # The special method must be spelled ``__dealloc__`` (with trailing
    # underscores); otherwise Cython treats it as an ordinary method that is
    # never invoked on destruction and the native object is leaked.
    def __dealloc__(self):
        if self._arg_ref:
            DPCTLRawKernelArg_Delete(self._arg_ref)


cdef class RawKernelArg:
    """
    RawKernelArg(*args)
    Python class representing the ``raw_kernel_arg`` class from the Raw Kernel
    Argument oneAPI SYCL extension for passing binary data as data to kernels.

    This class is intended to be used as kernel argument when launching kernels.

    This is based on a DPC++ SYCL extension and only available in newer
    versions. Use ``is_available()`` to check availability in your build.

    There are multiple ways to create a ``RawKernelArg``.

    - If the constructor is invoked with just a single argument, this argument
      is expected to expose the Python buffer interface. The raw kernel arg will
      be constructed from the data in that buffer.

    - If the constructor is invoked with two arguments, the first argument is
      interpreted as the number of bytes in the binary argument, while the
      second argument is interpreted as a pointer to the data.

    Note that construction of the ``RawKernelArg`` copies the bytes, so
    modifications made after construction of the ``RawKernelArg`` will not be
    reflected in the kernel launch.

    Args:
        args:
            Variadic argument, see class documentation.

    Raises:
        TypeError: In case of incorrect arguments given to constructors,
            unexpected types of input arguments.
        RuntimeError: If the raw kernel arg extension is not available, the
            buffer could not be accessed, or the native raw kernel argument
            could not be created.
    """
    def __cinit__(self, *args):
        cdef void* ptr = NULL
        cdef size_t count
        cdef int ret_code = 0
        cdef Py_buffer _buffer
        cdef bint _is_buf

        if not DPCTLRawKernelArg_Available():
            raise RuntimeError("Raw kernel arg extension not available")

        if not (0 < len(args) < 3):
            raise TypeError("RawKernelArg constructor takes 1 or 2 "
                            f"arguments, but {len(args)} were given")

        if len(args) == 1:
            if not _is_buffer(args[0]):
                # Single message string: the pieces are joined by implicit
                # literal concatenation, so each piece carries its own
                # separating space.
                raise TypeError("RawKernelArg single argument constructor "
                                "expects argument to be buffer, "
                                f"but got {type(args[0])}")

            ret_code = PyObject_GetBuffer(args[0], &(_buffer),
                                          PyBUF_SIMPLE | PyBUF_ANY_CONTIGUOUS)
            if ret_code != 0:  # pragma: no cover
                raise RuntimeError("Could not access buffer")

            ptr = _buffer.buf
            count = _buffer.len
            _is_buf = True
        else:
            if not isinstance(args[0], numbers.Integral):
                raise TypeError("RawKernelArg constructor expects first "
                                f"argument to be `int`, but got {type(args[0])}")
            if not isinstance(args[1], numbers.Integral):
                raise TypeError("RawKernelArg constructor expects second "
                                f"argument to be `int`, but got {type(args[1])}")

            _is_buf = False
            count = args[0]
            # The second argument is an integer address of the data to copy.
            ptr = <void*>(<unsigned long long>args[1])

        self._arg_ref = DPCTLRawKernelArg_Create(ptr, count)
        # The bytes have been handed to the native constructor at this point,
        # so the buffer view is no longer needed.
        if _is_buf:
            PyBuffer_Release(&(_buffer))

        # Do not hand out an object wrapping a NULL handle.
        if self._arg_ref is NULL:
            raise RuntimeError("Could not create RawKernelArg")

    @staticmethod
    def is_available():
        """Returns the result of ``DPCTLRawKernelArg_Available()``, i.e.
        whether the raw kernel arg extension is available in this build.
        """
        return DPCTLRawKernelArg_Available()

    property _ref:
        """Returns the address of the C API ``DPCTLRawKernelArgRef`` pointer
        as a ``size_t``.
        """
        def __get__(self):
            return <size_t>self._arg_ref
Binary file added dpctl/tests/input_files/raw-arg-kernel.spv
Binary file not shown.
112 changes: 112 additions & 0 deletions dpctl/tests/test_raw_kernel_arg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# Data Parallel Control (dpctl)
#
# Copyright 2020-2025 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Defines unit test cases for the work_group_memory in a SYCL kernel"""

import ctypes
import os

import pytest

import dpctl
import dpctl.tensor


def get_spirv_abspath(fn):
    """Return the path of file ``fn`` inside the ``input_files`` directory
    that sits next to this test module."""
    tests_dir = os.path.dirname(os.path.abspath(__file__))
    return os.path.join(tests_dir, "input_files", fn)


# The kernel in the SPIR-V file used in this test was generated from the
# following SYCL source code:
# #include <sycl/sycl.hpp>
#
# using namespace sycl;
#
# namespace syclexp = sycl::ext::oneapi::experimental;
# namespace syclext = sycl::ext::oneapi;
#
# using data_t = int32_t;
#
# struct Params { data_t mul; data_t add; };
#
# extern "C" SYCL_EXTERNAL
# SYCL_EXT_ONEAPI_FUNCTION_PROPERTY((syclexp::nd_range_kernel<1>))
# void raw_arg_kernel(data_t* in, data_t* out, Params p){
# auto item = syclext::this_work_item::get_nd_item<1>();
# size_t global_id = item.get_global_linear_id();
# out[global_id] = (in[global_id] * p.mul) + p.add;
# }


class Params(ctypes.Structure):
    """ctypes counterpart of the ``Params`` struct taken by the test kernel:
    two 32-bit signed integers, ``mul`` followed by ``add``."""

    _fields_ = [
        ("mul", ctypes.c_int32),
        ("add", ctypes.c_int32),
    ]


def launch_raw_arg_kernel(raw):
    """Submit the ``raw_arg_kernel`` SPIR-V kernel with ``raw`` as its third
    argument and check that every output element equals 9.

    The input array is all ones, so the expected value of 9 corresponds to
    ``raw`` carrying ``mul=4`` and ``add=5``. The test is skipped when the
    raw kernel arg extension is unavailable, when a ``level_zero`` queue
    cannot be created, or when kernel submission fails.
    """
    if not dpctl.RawKernelArg.is_available():
        pytest.skip("Raw kernel arg extension not supported")

    try:
        q = dpctl.SyclQueue("level_zero")
    except dpctl.SyclQueueCreationError:
        pytest.skip("LevelZero queue could not be created")

    # Build the program from the pre-compiled SPIR-V binary and look up the
    # kernel by its exported name.
    with open(get_spirv_abspath("raw-arg-kernel.spv"), "br") as spv:
        spv_bytes = spv.read()
    program = dpctl.program.create_program_from_spirv(q, spv_bytes)
    kernel = program.get_sycl_kernel("__sycl_kernel_raw_arg_kernel")

    local_size = 16
    global_size = local_size * 8

    x = dpctl.tensor.ones(global_size, dtype="int32")
    y = dpctl.tensor.zeros(global_size, dtype="int32")
    # Make sure the arrays are populated before the kernel is submitted.
    x.sycl_queue.wait()
    y.sycl_queue.wait()

    kernel_args = [x.usm_data, y.usm_data, raw]
    try:
        q.submit(kernel, kernel_args, [global_size], [local_size])
        q.wait()
    except dpctl._sycl_queue.SyclKernelSubmitError:
        pytest.skip(f"Kernel submission to {q.sycl_device} failed")

    assert dpctl.tensor.all(y == 9)


def test_submit_raw_kernel_arg_pointer():
    """Create a ``RawKernelArg`` from a byte count and a raw address of a
    ctypes structure, then launch the kernel with it."""
    params = Params(4, 5)
    n_bytes = ctypes.sizeof(params)
    address = ctypes.addressof(params)
    launch_raw_arg_kernel(dpctl.RawKernelArg(n_bytes, address))


def test_submit_raw_kernel_arg_buffer():
    """Create a ``RawKernelArg`` from an object exposing the buffer
    interface, then launch the kernel with it."""
    payload = bytearray(Params(4, 5))
    raw = dpctl.RawKernelArg(payload)
    # Drop the local name of the source buffer before the launch, so the
    # kernel runs with only ``raw`` holding the argument data.
    del payload
    launch_raw_arg_kernel(raw)
1 change: 1 addition & 0 deletions dpctl/tests/test_sycl_kernel_submit.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,7 @@ def test_kernel_arg_type():
_check_kernel_arg_type_instance(kernel_arg_type.dpctl_void_ptr)
_check_kernel_arg_type_instance(kernel_arg_type.dpctl_local_accessor)
_check_kernel_arg_type_instance(kernel_arg_type.dpctl_work_group_memory)
_check_kernel_arg_type_instance(kernel_arg_type.dpctl_raw_kernel_arg)


def get_spirv_abspath(fn):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ typedef enum
DPCTL_VOID_PTR,
DPCTL_LOCAL_ACCESSOR,
DPCTL_WORK_GROUP_MEMORY,
DPCTL_RAW_KERNEL_ARG,
DPCTL_UNSUPPORTED_KERNEL_ARG
} DPCTLKernelArgType;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,4 +53,16 @@ void DPCTLWorkGroupMemory_Delete(__dpctl_take DPCTLSyclWorkGroupMemoryRef Ref);
DPCTL_API
bool DPCTLWorkGroupMemory_Available();

/*! Opaque pointer type used to refer to a raw kernel argument object. */
typedef struct DPCTLOpaqueSyclRawKernelArg *DPCTLSyclRawKernelArgRef;

/*!
 * @brief Creates a raw kernel argument holding a copy of ``count`` bytes
 * read from ``bytes``.
 *
 * @param bytes Pointer to the data to copy.
 * @param count Number of bytes to copy.
 * @return A new reference to be released with DPCTLRawKernelArg_Delete, or
 * a null pointer if an exception was raised during creation.
 */
DPCTL_API
__dpctl_give DPCTLSyclRawKernelArgRef DPCTLRawKernelArg_Create(void *bytes,
size_t count);

/*!
 * @brief Deletes the raw kernel argument object referenced by ``Ref``.
 *
 * @param Ref Reference previously returned by DPCTLRawKernelArg_Create.
 */
DPCTL_API
void DPCTLRawKernelArg_Delete(__dpctl_take DPCTLSyclRawKernelArgRef Ref);

/*!
 * @brief Reports whether the library was compiled with the
 * ``SYCL_EXT_ONEAPI_RAW_KERNEL_ARG`` feature macro defined.
 *
 * @return true if the macro was defined at compile time, false otherwise.
 */
DPCTL_API
bool DPCTLRawKernelArg_Available();

DPCTL_C_EXTERN_C_END
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@ DEFINE_SIMPLE_CONVERSION_FUNCTIONS(std::vector<DPCTLSyclEventRef>,
DEFINE_SIMPLE_CONVERSION_FUNCTIONS(RawWorkGroupMemory,
DPCTLSyclWorkGroupMemoryRef)

DEFINE_SIMPLE_CONVERSION_FUNCTIONS(std::vector<unsigned char>,
DPCTLSyclRawKernelArgRef)

#endif

} // namespace dpctl::syclinterface
35 changes: 35 additions & 0 deletions libsyclinterface/source/dpctl_sycl_extension_interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,38 @@ bool DPCTLWorkGroupMemory_Available()
return false;
#endif
}

// Storage type behind a DPCTLSyclRawKernelArgRef: an owned byte vector
// holding a copy of the argument data.
using raw_kernel_arg_t = std::vector<unsigned char>;

/*!
 * @brief Creates a raw kernel argument that owns a copy of ``count`` bytes
 * read from ``bytes``.
 *
 * @param bytes Pointer to the data to copy; may be null only if ``count``
 * is zero.
 * @param count Number of bytes to copy.
 * @return A new reference that must be released with
 * DPCTLRawKernelArg_Delete, or nullptr if ``bytes`` is null while ``count``
 * is non-zero, or if an exception was raised during creation.
 */
DPCTL_API
__dpctl_give DPCTLSyclRawKernelArgRef DPCTLRawKernelArg_Create(void *bytes,
                                                               size_t count)
{
    DPCTLSyclRawKernelArgRef rka = nullptr;

    // Reading from a null source is undefined behavior; report it instead.
    if (bytes == nullptr && count != 0) {
        error_handler("Cannot create raw kernel arg from a nullptr.",
                      __FILE__, __func__, __LINE__);
        return rka;
    }

    try {
        auto RawKernelArg = std::make_unique<raw_kernel_arg_t>(count);
        // memcpy requires valid pointers even for a zero-sized copy, and an
        // empty vector's data() may be null, so skip the call in that case.
        if (count != 0) {
            std::memcpy(RawKernelArg->data(), bytes, count);
        }
        // Ownership is transferred to the returned opaque reference.
        rka = wrap<raw_kernel_arg_t>(RawKernelArg.release());
    } catch (std::exception const &e) {
        error_handler(e, __FILE__, __func__, __LINE__);
    }
    return rka;
}

/*!
 * @brief Deletes the byte vector behind the given raw kernel argument
 * reference.
 *
 * @param Ref Reference to the raw kernel argument to delete.
 */
DPCTL_API
void DPCTLRawKernelArg_Delete(__dpctl_take DPCTLSyclRawKernelArgRef Ref)
{
    raw_kernel_arg_t *arg = unwrap<raw_kernel_arg_t>(Ref);
    delete arg;
}

/*!
 * @brief Reports whether this translation unit was compiled with the
 * ``SYCL_EXT_ONEAPI_RAW_KERNEL_ARG`` feature macro defined.
 *
 * @return true if the macro was defined at compile time, false otherwise.
 */
DPCTL_API
bool DPCTLRawKernelArg_Available()
{
#ifdef SYCL_EXT_ONEAPI_RAW_KERNEL_ARG
    constexpr bool has_raw_kernel_arg = true;
#else
    constexpr bool has_raw_kernel_arg = false;
#endif
    return has_raw_kernel_arg;
}
Loading
Loading