Skip to content

Commit 3017074

Browse files
committed
Add support for raw_kernel_arg extension
1 parent 579d2f8 commit 3017074

17 files changed

+705
-2
lines changed

dpctl/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
from ._sycl_platform import SyclPlatform, get_platforms, lsplatform
5050
from ._sycl_queue import (
5151
LocalAccessor,
52+
RawKernelArg,
5253
SyclKernelInvalidRangeError,
5354
SyclKernelSubmitError,
5455
SyclQueue,
@@ -104,6 +105,7 @@
104105
"SyclQueueCreationError",
105106
"WorkGroupMemory",
106107
"LocalAccessor",
108+
"RawKernelArg",
107109
]
108110
__all__ += [
109111
"get_device_cached_queue",

dpctl/_backend.pxd

+15-1
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,8 @@ cdef extern from "syclinterface/dpctl_sycl_enum_types.h":
7070
_DOUBLE 'DPCTL_FLOAT64_T',
7171
_VOID_PTR 'DPCTL_VOID_PTR',
7272
_LOCAL_ACCESSOR 'DPCTL_LOCAL_ACCESSOR',
73-
_WORK_GROUP_MEMORY 'DPCTL_WORK_GROUP_MEMORY'
73+
_WORK_GROUP_MEMORY 'DPCTL_WORK_GROUP_MEMORY',
74+
_RAW_KERNEL_ARG 'DPCTL_RAW_KERNEL_ARG'
7475

7576
ctypedef enum _queue_property_type 'DPCTLQueuePropertyType':
7677
_DEFAULT_PROPERTY 'DPCTL_DEFAULT_PROPERTY'
@@ -491,3 +492,16 @@ cdef extern from "syclinterface/dpctl_sycl_extension_interface.h":
491492
DPCTLSyclWorkGroupMemoryRef Ref);
492493

493494
cdef bint DPCTLWorkGroupMemory_Available();
495+
496+
cdef struct RawKernelArgDataTy
497+
ctypedef RawKernelArgDataTy RawKernelArgData
498+
499+
cdef struct DPCTLOpaqueRawKernelArg
500+
ctypedef DPCTLOpaqueRawKernelArg *DPCTLSyclRawKernelArgRef;
501+
502+
cdef DPCTLSyclRawKernelArgRef DPCTLRawKernelArg_Create(void* bytes, size_t count);
503+
504+
cdef void DPCTLRawKernelArg_Delete(
505+
DPCTLSyclRawKernelArgRef Ref);
506+
507+
cdef bint DPCTLRawKernelArg_Available();

dpctl/_sycl_queue.pxd

+14
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,13 @@
2020
""" This file declares the SyclQueue extension type.
2121
"""
2222

23+
from cpython.buffer cimport Py_buffer
2324
from libcpp cimport bool as cpp_bool
2425

2526
from ._backend cimport (
2627
DPCTLSyclDeviceRef,
2728
DPCTLSyclQueueRef,
29+
DPCTLSyclRawKernelArgRef,
2830
DPCTLSyclWorkGroupMemoryRef,
2931
_arg_data_type,
3032
)
@@ -113,3 +115,15 @@ cdef public api class WorkGroupMemory(_WorkGroupMemory) [
113115
object PyWorkGroupMemoryObject, type PyWorkGroupMemoryType
114116
]:
115117
pass
118+
119+
cdef public api class _RawKernelArg [
120+
object Py_RawKernelArgObject, type Py_RawKernelArgType
121+
]:
122+
cdef DPCTLSyclRawKernelArgRef _arg_ref
123+
cdef Py_buffer _buf
124+
cdef bint _is_buf
125+
126+
cdef public api class RawKernelArg(_RawKernelArg) [
127+
object PyRawKernelArgObject, type PyRawKernelArgType
128+
]:
129+
pass

dpctl/_sycl_queue.pyx

+105
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,9 @@ from ._backend cimport ( # noqa: E211
5151
DPCTLQueue_SubmitNDRange,
5252
DPCTLQueue_SubmitRange,
5353
DPCTLQueue_Wait,
54+
DPCTLRawKernelArg_Available,
55+
DPCTLRawKernelArg_Create,
56+
DPCTLRawKernelArg_Delete,
5457
DPCTLSyclContextRef,
5558
DPCTLSyclDeviceSelectorRef,
5659
DPCTLSyclEventRef,
@@ -353,6 +356,15 @@ cdef class _kernel_arg_type:
353356
_arg_data_type._WORK_GROUP_MEMORY
354357
)
355358

359+
@property
360+
def dpctl_raw_kernel_arg(self):
361+
cdef str p_name = "dpctl_raw_kernel_arg"
362+
return kernel_arg_type_attribute(
363+
self._name,
364+
p_name,
365+
_arg_data_type._RAW_KERNEL_ARG
366+
)
367+
356368

357369
kernel_arg_type = _kernel_arg_type()
358370

@@ -958,6 +970,9 @@ cdef class SyclQueue(_SyclQueue):
958970
elif isinstance(arg, LocalAccessor):
959971
kargs[idx] = <void*>((<LocalAccessor>arg).addressof())
960972
kargty[idx] = _arg_data_type._LOCAL_ACCESSOR
973+
elif isinstance(arg, RawKernelArg):
974+
kargs[idx] = <void*>(<size_t>arg._ref)
975+
kargty[idx] = _arg_data_type._RAW_KERNEL_ARG
961976
else:
962977
ret = -1
963978
return ret
@@ -1719,3 +1734,93 @@ cdef class WorkGroupMemory:
17191734
"""
17201735
def __get__(self):
17211736
return <size_t>self._mem_ref
1737+
1738+
1739+
cdef class _RawKernelArg:
1740+
def __dealloc(self):
1741+
if(self._arg_ref):
1742+
DPCTLRawKernelArg_Delete(self._arg_ref)
1743+
if(self._is_buf):
1744+
PyBuffer_Release(&(self._buf))
1745+
1746+
cdef class RawKernelArg:
1747+
"""
1748+
RawKernelArg(*args)
1749+
Python class representing the ``raw_kernel_arg`` class from the Raw Kernel
1750+
Argument oneAPI SYCL extension for passing binary data as data to kernels.
1751+
1752+
This class is intended to be used as kernel argument when launching kernels.
1753+
1754+
This is based on a DPC++ SYCL extension and only available in newer
1755+
versions. Use ``is_available()`` to check availability in your build.
1756+
1757+
There are multiple ways to create a ``RawKernelArg``.
1758+
1759+
- If the constructor is invoked with just a single argument, this argument
1760+
is expected to expose the Python buffer interface. The raw kernel arg will
1761+
be constructed from the data in that buffer.
1762+
1763+
- If the constructor is invoked with two arguments, the first argument is
1764+
interpreted as the number of bytes in the binary argument, while the
1765+
second argument is interpreted as a pointer to the data. Note that the
1766+
raw kernel arg does not own or copy the data, so the pointed-to object
1767+
must be kept alive by the user until kernel launch.
1768+
1769+
Args:
1770+
args:
1771+
Variadic argument, see class documentation.
1772+
1773+
Raises:
1774+
TypeError: In case of incorrect arguments given to constructurs,
1775+
unexpected types of input arguments.
1776+
"""
1777+
def __cinit__(self, *args):
1778+
cdef void* ptr = NULL
1779+
cdef size_t count
1780+
cdef int ret_code = 0
1781+
1782+
if not DPCTLRawKernelArg_Available():
1783+
raise RuntimeError("Raw kernel arg extension not available")
1784+
1785+
if not (0 < len(args) < 3):
1786+
raise TypeError("RawKernelArg constructor takes 1 or 2 "
1787+
f"arguments, but {len(args)} were given")
1788+
1789+
if len(args) == 1:
1790+
if not _is_buffer(args[0]):
1791+
raise TypeError("RawKernelArg single argument constructor"
1792+
"expects argument to be buffer",
1793+
f"but got {type(args[0])}")
1794+
1795+
ret_code = PyObject_GetBuffer(args[0], &(self._buf), PyBUF_SIMPLE | PyBUF_ANY_CONTIGUOUS)
1796+
if ret_code != 0: # pragma: no cover
1797+
raise RuntimeError("Could not access buffer")
1798+
1799+
ptr = self._buf.buf
1800+
count = self._buf.len
1801+
self._is_buf = True
1802+
else:
1803+
if not isinstance(args[0], numbers.Integral):
1804+
raise TypeError("RawKernelArg constructor expects first"
1805+
"argument to be `int`, but got {type(args[0])}")
1806+
if not isinstance(args[1], numbers.Integral):
1807+
raise TypeError("RawKernelArg constructor expects second"
1808+
"argument to be `int`, but got {type(args[1])}")
1809+
1810+
self._is_buf = False
1811+
count = args[0]
1812+
ptr = <void*>(<unsigned long long>args[1])
1813+
1814+
self._arg_ref = DPCTLRawKernelArg_Create(ptr, count)
1815+
1816+
"""Check whether the raw_kernel_arg extension is available"""
1817+
@staticmethod
1818+
def is_available():
1819+
return DPCTLRawKernelArg_Available();
1820+
1821+
property _ref:
1822+
"""Returns the address of the C API ``DPCTLRawKernelArgRef`` pointer
1823+
as a ``size_t``.
1824+
"""
1825+
def __get__(self):
1826+
return <size_t>self._arg_ref

dpctl/sycl.pxd

+11
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ cdef extern from "sycl/sycl.hpp" namespace "sycl":
4545
cdef extern from "syclinterface/dpctl_sycl_extension_interface.h":
4646
cdef struct RawWorkGroupMemoryTy
4747
ctypedef RawWorkGroupMemoryTy RawWorkGroupMemory
48+
cdef struct RawKernelArgDataTy
49+
ctypedef RawKernelArgDataTy RawKernelArgData
4850

4951
cdef extern from "syclinterface/dpctl_sycl_type_casters.hpp" \
5052
namespace "dpctl::syclinterface":
@@ -80,3 +82,12 @@ cdef extern from "syclinterface/dpctl_sycl_type_casters.hpp" \
8082
cdef RawWorkGroupMemory * unwrap_work_group_memory \
8183
"dpctl::syclinterface::unwrap<RawWorkGroupMemory>" (
8284
dpctl_backend.DPCTLSyclWorkGroupMemoryRef)
85+
86+
# raw kernel arg extension
87+
cdef dpctl_backend.DPCTLSyclRawKernelArgRef wrap_raw_kernel_arg \
88+
"dpctl::syclinterface::wrap<RawKernelArgData>" \
89+
(const RawKernelArgData *)
90+
91+
cdef RawKernelArgData * unwrap_raw_kernel_arg \
92+
"dpctl::syclinterface::unwrap<RawKernelArgData>" (
93+
dpctl_backend.DPCTLSyclRawKernelArgRef)
1.6 KB
Binary file not shown.

dpctl/tests/test_raw_kernel_arg.py

+112
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
# Data Parallel Control (dpctl)
2+
#
3+
# Copyright 2020-2025 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
"""Defines unit test cases for the work_group_memory in a SYCL kernel"""
18+
19+
import ctypes
20+
import os
21+
22+
import pytest
23+
24+
import dpctl
25+
import dpctl.tensor
26+
27+
28+
def get_spirv_abspath(fn):
29+
curr_dir = os.path.dirname(os.path.abspath(__file__))
30+
spirv_file = os.path.join(curr_dir, "input_files", fn)
31+
return spirv_file
32+
33+
34+
# The kernel in the SPIR-V file used in this test was generated from the
35+
# following SYCL source code:
36+
# #include <sycl/sycl.hpp>
37+
#
38+
# using namespace sycl;
39+
#
40+
# namespace syclexp = sycl::ext::oneapi::experimental;
41+
# namespace syclext = sycl::ext::oneapi;
42+
#
43+
# using data_t = int32_t;
44+
#
45+
# struct Params { data_t mul; data_t add; };
46+
#
47+
# extern "C" SYCL_EXTERNAL
48+
# SYCL_EXT_ONEAPI_FUNCTION_PROPERTY((syclexp::nd_range_kernel<1>))
49+
# void raw_arg_kernel(data_t* in, data_t* out, Params p){
50+
# auto item = syclext::this_work_item::get_nd_item<1>();
51+
# size_t global_id = item.get_global_linear_id();
52+
# out[global_id] = (in[global_id] * p.mul) + p.add;
53+
# }
54+
55+
56+
class Params(ctypes.Structure):
57+
_fields_ = [("mul", ctypes.c_int32), ("add", ctypes.c_int32)]
58+
59+
60+
def launch_raw_arg_kernel(raw):
61+
if not dpctl.RawKernelArg.is_available():
62+
pytest.skip("Raw kernel arg extension not supported")
63+
64+
try:
65+
q = dpctl.SyclQueue("level_zero")
66+
except dpctl.SyclQueueCreationError:
67+
pytest.skip("LevelZero queue could not be created")
68+
spirv_file = get_spirv_abspath("raw-arg-kernel.spv")
69+
with open(spirv_file, "br") as spv:
70+
spv_bytes = spv.read()
71+
prog = dpctl.program.create_program_from_spirv(q, spv_bytes)
72+
kernel = prog.get_sycl_kernel("__sycl_kernel_raw_arg_kernel")
73+
local_size = 16
74+
global_size = local_size * 8
75+
76+
x = dpctl.tensor.ones(global_size, dtype="int32")
77+
y = dpctl.tensor.zeros(global_size, dtype="int32")
78+
x.sycl_queue.wait()
79+
y.sycl_queue.wait()
80+
81+
try:
82+
q.submit(
83+
kernel,
84+
[
85+
x.usm_data,
86+
y.usm_data,
87+
raw,
88+
],
89+
[global_size],
90+
[local_size],
91+
)
92+
q.wait()
93+
except dpctl._sycl_queue.SyclKernelSubmitError:
94+
pytest.skip(f"Kernel submission to {q.sycl_device} failed")
95+
96+
assert dpctl.tensor.all(y == 9)
97+
98+
99+
def test_submit_raw_kernel_arg_pointer():
100+
paramStruct = Params(4, 5)
101+
raw = dpctl.RawKernelArg(
102+
ctypes.sizeof(paramStruct), ctypes.addressof(paramStruct)
103+
)
104+
launch_raw_arg_kernel(raw)
105+
106+
107+
def test_submit_raw_kernel_arg_buffer():
108+
paramStruct = Params(4, 5)
109+
byteArr = bytearray(paramStruct)
110+
raw = dpctl.RawKernelArg(byteArr)
111+
del byteArr
112+
launch_raw_arg_kernel(raw)

dpctl/tests/test_sycl_kernel_submit.py

+1
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,7 @@ def test_kernel_arg_type():
280280
_check_kernel_arg_type_instance(kernel_arg_type.dpctl_void_ptr)
281281
_check_kernel_arg_type_instance(kernel_arg_type.dpctl_local_accessor)
282282
_check_kernel_arg_type_instance(kernel_arg_type.dpctl_work_group_memory)
283+
_check_kernel_arg_type_instance(kernel_arg_type.dpctl_raw_kernel_arg)
283284

284285

285286
def get_spirv_abspath(fn):

libsyclinterface/include/syclinterface/dpctl_sycl_enum_types.h

+1
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ typedef enum
101101
DPCTL_VOID_PTR,
102102
DPCTL_LOCAL_ACCESSOR,
103103
DPCTL_WORK_GROUP_MEMORY,
104+
DPCTL_RAW_KERNEL_ARG,
104105
DPCTL_UNSUPPORTED_KERNEL_ARG
105106
} DPCTLKernelArgType;
106107

libsyclinterface/include/syclinterface/dpctl_sycl_extension_interface.h

+18
Original file line numberDiff line numberDiff line change
@@ -53,4 +53,22 @@ void DPCTLWorkGroupMemory_Delete(__dpctl_take DPCTLSyclWorkGroupMemoryRef Ref);
5353
DPCTL_API
5454
bool DPCTLWorkGroupMemory_Available();
5555

56+
typedef struct RawKernelArgDataTy
57+
{
58+
void *bytes;
59+
size_t count;
60+
} RawKernelArgData;
61+
62+
typedef struct DPCTLOpaqueSyclRawKernelArg *DPCTLSyclRawKernelArgRef;
63+
64+
DPCTL_API
65+
__dpctl_give DPCTLSyclRawKernelArgRef DPCTLRawKernelArg_Create(void *bytes,
66+
size_t count);
67+
68+
DPCTL_API
69+
void DPCTLRawKernelArg_Delete(__dpctl_take DPCTLSyclRawKernelArgRef Ref);
70+
71+
DPCTL_API
72+
bool DPCTLRawKernelArg_Available();
73+
5674
DPCTL_C_EXTERN_C_END

libsyclinterface/include/syclinterface/dpctl_sycl_type_casters.hpp

+2
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,8 @@ DEFINE_SIMPLE_CONVERSION_FUNCTIONS(std::vector<DPCTLSyclEventRef>,
8484
DEFINE_SIMPLE_CONVERSION_FUNCTIONS(RawWorkGroupMemory,
8585
DPCTLSyclWorkGroupMemoryRef)
8686

87+
DEFINE_SIMPLE_CONVERSION_FUNCTIONS(RawKernelArgData, DPCTLSyclRawKernelArgRef)
88+
8789
#endif
8890

8991
} // namespace dpctl::syclinterface

0 commit comments

Comments
 (0)