Skip to content

Commit 1ea6b9c

Browse files
committed
Address PR feedback
1 parent 3017074 commit 1ea6b9c

File tree

3 files changed

+22
-15
lines changed

3 files changed

+22
-15
lines changed

dpctl/_sycl_queue.pxd

-3
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
""" This file declares the SyclQueue extension type.
2121
"""
2222

23-
from cpython.buffer cimport Py_buffer
2423
from libcpp cimport bool as cpp_bool
2524

2625
from ._backend cimport (
@@ -120,8 +119,6 @@ cdef public api class _RawKernelArg [
120119
object Py_RawKernelArgObject, type Py_RawKernelArgType
121120
]:
122121
cdef DPCTLSyclRawKernelArgRef _arg_ref
123-
cdef Py_buffer _buf
124-
cdef bint _is_buf
125122

126123
cdef public api class RawKernelArg(_RawKernelArg) [
127124
object PyRawKernelArgObject, type PyRawKernelArgType

dpctl/_sycl_queue.pyx

+16-10
Original file line numberDiff line numberDiff line change
@@ -1740,8 +1740,7 @@ cdef class _RawKernelArg:
17401740
def __dealloc(self):
17411741
if(self._arg_ref):
17421742
DPCTLRawKernelArg_Delete(self._arg_ref)
1743-
if(self._is_buf):
1744-
PyBuffer_Release(&(self._buf))
1743+
17451744

17461745
cdef class RawKernelArg:
17471746
"""
@@ -1762,9 +1761,11 @@ cdef class RawKernelArg:
17621761
17631762
- If the constructor is invoked with two arguments, the first argument is
17641763
interpreted as the number of bytes in the binary argument, while the
1765-
second argument is interpreted as a pointer to the data. Note that the
1766-
raw kernel arg does not own or copy the data, so the pointed-to object
1767-
must be kept alive by the user until kernel launch.
1764+
second argument is interpreted as a pointer to the data.
1765+
1766+
Note that construction of the ``RawKernelArg`` copies the bytes, so
1767+
modifications made after construction of the ``RawKernelArg`` will not be
1768+
reflected in the kernel launch.
17681769
17691770
Args:
17701771
args:
@@ -1778,6 +1779,8 @@ cdef class RawKernelArg:
17781779
cdef void* ptr = NULL
17791780
cdef size_t count
17801781
cdef int ret_code = 0
1782+
cdef Py_buffer _buffer
1783+
cdef bint _is_buf
17811784

17821785
if not DPCTLRawKernelArg_Available():
17831786
raise RuntimeError("Raw kernel arg extension not available")
@@ -1792,13 +1795,13 @@ cdef class RawKernelArg:
17921795
"expects argument to be buffer",
17931796
f"but got {type(args[0])}")
17941797

1795-
ret_code = PyObject_GetBuffer(args[0], &(self._buf), PyBUF_SIMPLE | PyBUF_ANY_CONTIGUOUS)
1798+
ret_code = PyObject_GetBuffer(args[0], &(_buffer), PyBUF_SIMPLE | PyBUF_ANY_CONTIGUOUS)
17961799
if ret_code != 0: # pragma: no cover
17971800
raise RuntimeError("Could not access buffer")
17981801

1799-
ptr = self._buf.buf
1800-
count = self._buf.len
1801-
self._is_buf = True
1802+
ptr = _buffer.buf
1803+
count = _buffer.len
1804+
_is_buf = True
18021805
else:
18031806
if not isinstance(args[0], numbers.Integral):
18041807
raise TypeError("RawKernelArg constructor expects first"
@@ -1807,11 +1810,14 @@ cdef class RawKernelArg:
18071810
raise TypeError("RawKernelArg constructor expects second"
18081811
"argument to be `int`, but got {type(args[1])}")
18091812

1810-
self._is_buf = False
1813+
_is_buf = False
18111814
count = args[0]
18121815
ptr = <void*>(<unsigned long long>args[1])
18131816

18141817
self._arg_ref = DPCTLRawKernelArg_Create(ptr, count)
1818+
if(_is_buf):
1819+
PyBuffer_Release(&(_buffer))
1820+
18151821

18161822
"""Check whether the raw_kernel_arg extension is available"""
18171823
@staticmethod

libsyclinterface/source/dpctl_sycl_extension_interface.cpp

+6-2
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,12 @@ __dpctl_give DPCTLSyclRawKernelArgRef DPCTLRawKernelArg_Create(void *bytes,
6969
{
7070
DPCTLSyclRawKernelArgRef rka = nullptr;
7171
try {
72-
auto RawKernelArg = new RawKernelArgData{bytes, count};
73-
rka = wrap<RawKernelArgData>(RawKernelArg);
72+
auto rawData =
73+
std::unique_ptr<unsigned char[]>(new unsigned char[count]);
74+
std::memcpy(rawData.get(), bytes, count);
75+
auto RawKernelArg = std::unique_ptr<RawKernelArgData>(
76+
new RawKernelArgData{rawData.release(), count});
77+
rka = wrap<RawKernelArgData>(RawKernelArg.release());
7478
} catch (std::exception const &e) {
7579
error_handler(e, __FILE__, __func__, __LINE__);
7680
}

0 commit comments

Comments
 (0)