diff --git a/include/pybind11/critical_section.h b/include/pybind11/critical_section.h index e94ca765cb..2d26412047 100644 --- a/include/pybind11/critical_section.h +++ b/include/pybind11/critical_section.h @@ -13,24 +13,30 @@ PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE) class scoped_critical_section { public: #ifdef Py_GIL_DISABLED - explicit scoped_critical_section(handle obj) : has2(false) { - PyCriticalSection_Begin(&section, obj.ptr()); - } - - scoped_critical_section(handle obj1, handle obj2) : has2(true) { - PyCriticalSection2_Begin(&section2, obj1.ptr(), obj2.ptr()); + explicit scoped_critical_section(handle obj1, handle obj2 = handle{}) { + if (obj1) { + if (obj2) { + PyCriticalSection2_Begin(&section2, obj1.ptr(), obj2.ptr()); + rank = 2; + } else { + PyCriticalSection_Begin(&section, obj1.ptr()); + rank = 1; + } + } else if (obj2) { + PyCriticalSection_Begin(&section, obj2.ptr()); + rank = 1; + } } ~scoped_critical_section() { - if (has2) { - PyCriticalSection2_End(&section2); - } else { + if (rank == 1) { PyCriticalSection_End(&section); + } else if (rank == 2) { + PyCriticalSection2_End(&section2); } } #else - explicit scoped_critical_section(handle) {}; - scoped_critical_section(handle, handle) {}; + explicit scoped_critical_section(handle, handle = handle{}) {}; ~scoped_critical_section() = default; #endif @@ -39,7 +45,7 @@ class scoped_critical_section { private: #ifdef Py_GIL_DISABLED - bool has2; + int rank{0}; union { PyCriticalSection section; PyCriticalSection2 section2; diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 0e76e68786..2cf18c3547 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -166,6 +166,7 @@ set(PYBIND11_TEST_FILES test_potentially_slicing_weak_ptr test_python_multiple_inheritance test_pytypes + test_scoped_critical_section test_sequences_and_iterators test_smart_ptr test_stl diff --git a/tests/test_scoped_critical_section.cpp b/tests/test_scoped_critical_section.cpp new file mode 100644 index 0000000000..dc9a69e039 --- /dev/null 
+++ b/tests/test_scoped_critical_section.cpp @@ -0,0 +1,275 @@ +#include <pybind11/critical_section.h> + +#include "pybind11_tests.h" + +#include <atomic> +#include <chrono> +#include <thread> +#include <utility> + +#if defined(PYBIND11_CPP20) && defined(__has_include) && __has_include(<barrier>) +# define PYBIND11_HAS_BARRIER 1 +# include <barrier> +#endif + +namespace test_scoped_critical_section_ns { + +void test_one_nullptr() { py::scoped_critical_section lock{py::handle{}}; } + +void test_two_nullptrs() { py::scoped_critical_section lock{py::handle{}, py::handle{}}; } + +void test_first_nullptr() { + py::dict d; + py::scoped_critical_section lock{py::handle{}, d}; +} + +void test_second_nullptr() { + py::dict d; + py::scoped_critical_section lock{d, py::handle{}}; +} + +// Referenced test implementation: https://github.com/PyO3/pyo3/blob/v0.25.0/src/sync.rs#L874 +class BoolWrapper { +public: + explicit BoolWrapper(bool value) : value_{value} {} + bool get() const { return value_.load(std::memory_order_acquire); } + void set(bool value) { value_.store(value, std::memory_order_release); } + +private: + std::atomic_bool value_{false}; +}; + +#if defined(PYBIND11_HAS_BARRIER) + +// Modifying the C/C++ members of a Python object from multiple threads requires a critical section +// to ensure thread safety and data integrity. +// These tests use a scoped critical section to ensure that the Python object is accessed in a +// thread-safe manner. + +void test_scoped_critical_section(const py::handle &cls) { + auto barrier = std::barrier(2); + auto bool_wrapper = cls(false); + bool output = false; + + { + // Release the GIL to allow run threads in parallel. + py::gil_scoped_release gil_release{}; + + std::thread t1([&]() { + // Use gil_scoped_acquire to ensure we have a valid Python thread state + // before entering the critical section. Otherwise, the critical section + // will cause a segmentation fault. + py::gil_scoped_acquire ensure_tstate{}; + // Enter the critical section with the same object as the second thread. 
+ py::scoped_critical_section lock{bool_wrapper}; + // At this point, the object is locked by this thread via the scoped_critical_section. + // This barrier will ensure that the second thread waits until this thread has released + // the critical section before proceeding. + barrier.arrive_and_wait(); + // Sleep for a short time to simulate some work in the critical section. + // This sleep is necessary to test the locking mechanism properly. + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + auto *bw = bool_wrapper.cast<BoolWrapper *>(); + bw->set(true); + }); + + std::thread t2([&]() { + // This thread will wait until the first thread has entered the critical section due to + // the barrier. + barrier.arrive_and_wait(); + { + // Use gil_scoped_acquire to ensure we have a valid Python thread state + // before entering the critical section. Otherwise, the critical section + // will cause a segmentation fault. + py::gil_scoped_acquire ensure_tstate{}; + // Enter the critical section with the same object as the first thread. + py::scoped_critical_section lock{bool_wrapper}; + // At this point, the critical section is released by the first thread, the value + // is set to true. + auto *bw = bool_wrapper.cast<BoolWrapper *>(); + output = bw->get(); + } + }); + + t1.join(); + t2.join(); + } + + if (!output) { + throw std::runtime_error("Scoped critical section test failed: output is false"); + } +} + +void test_scoped_critical_section2(const py::handle &cls) { + auto barrier = std::barrier(3); + auto bool_wrapper1 = cls(false); + auto bool_wrapper2 = cls(false); + std::pair<bool, bool> output{false, false}; + + { + // Release the GIL to allow run threads in parallel. + py::gil_scoped_release gil_release{}; + + std::thread t1([&]() { + // Use gil_scoped_acquire to ensure we have a valid Python thread state + // before entering the critical section. Otherwise, the critical section + // will cause a segmentation fault. 
+ py::gil_scoped_acquire ensure_tstate{}; + // Enter the critical section with two different objects. + // This will ensure that the critical section is locked for both objects. + py::scoped_critical_section lock{bool_wrapper1, bool_wrapper2}; + // At this point, objects are locked by this thread via the scoped_critical_section. + // This barrier will ensure that other threads wait until this thread has released + // the critical section before proceeding. + barrier.arrive_and_wait(); + // Sleep for a short time to simulate some work in the critical section. + // This sleep is necessary to test the locking mechanism properly. + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + auto *bw1 = bool_wrapper1.cast<BoolWrapper *>(); + auto *bw2 = bool_wrapper2.cast<BoolWrapper *>(); + bw1->set(true); + bw2->set(true); + }); + + std::thread t2([&]() { + // This thread will wait until the first thread has entered the critical section due to + // the barrier. + barrier.arrive_and_wait(); + { + // Use gil_scoped_acquire to ensure we have a valid Python thread state + // before entering the critical section. Otherwise, the critical section + // will cause a segmentation fault. + py::gil_scoped_acquire ensure_tstate{}; + // Enter the critical section with the same object as the first thread. + py::scoped_critical_section lock{bool_wrapper1}; + // At this point, the critical section is released by the first thread, the value + // is set to true. + auto *bw1 = bool_wrapper1.cast<BoolWrapper *>(); + output.first = bw1->get(); + } + }); + + std::thread t3([&]() { + // This thread will wait until the first thread has entered the critical section due to + // the barrier. + barrier.arrive_and_wait(); + { + // Use gil_scoped_acquire to ensure we have a valid Python thread state + // before entering the critical section. Otherwise, the critical section + // will cause a segmentation fault. + py::gil_scoped_acquire ensure_tstate{}; + // Enter the critical section with the same object as the first thread. 
+ py::scoped_critical_section lock{bool_wrapper2}; + // At this point, the critical section is released by the first thread, the value + // is set to true. + auto *bw2 = bool_wrapper2.cast<BoolWrapper *>(); + output.second = bw2->get(); + } + }); + + t1.join(); + t2.join(); + t3.join(); + } + + if (!output.first || !output.second) { + throw std::runtime_error( + "Scoped critical section test with two objects failed: output is false"); + } +} + +void test_scoped_critical_section2_same_object_no_deadlock(const py::handle &cls) { + auto barrier = std::barrier(2); + auto bool_wrapper = cls(false); + bool output = false; + + { + // Release the GIL to allow run threads in parallel. + py::gil_scoped_release gil_release{}; + + std::thread t1([&]() { + // Use gil_scoped_acquire to ensure we have a valid Python thread state + // before entering the critical section. Otherwise, the critical section + // will cause a segmentation fault. + py::gil_scoped_acquire ensure_tstate{}; + // Enter the critical section with the same object as the second thread. + py::scoped_critical_section lock{bool_wrapper, bool_wrapper}; // same object used here + // At this point, the object is locked by this thread via the scoped_critical_section. + // This barrier will ensure that the second thread waits until this thread has released + // the critical section before proceeding. + barrier.arrive_and_wait(); + // Sleep for a short time to simulate some work in the critical section. + // This sleep is necessary to test the locking mechanism properly. + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + auto *bw = bool_wrapper.cast<BoolWrapper *>(); + bw->set(true); + }); + + std::thread t2([&]() { + // This thread will wait until the first thread has entered the critical section due to + // the barrier. + barrier.arrive_and_wait(); + { + // Use gil_scoped_acquire to ensure we have a valid Python thread state + // before entering the critical section. 
Otherwise, the critical section + // will cause a segmentation fault. + py::gil_scoped_acquire ensure_tstate{}; + // Enter the critical section with the same object as the first thread. + py::scoped_critical_section lock{bool_wrapper}; + // At this point, the critical section is released by the first thread, the value + // is set to true. + auto *bw = bool_wrapper.cast<BoolWrapper *>(); + output = bw->get(); + } + }); + + t1.join(); + t2.join(); + } + + if (!output) { + throw std::runtime_error( + "Scoped critical section test with same object failed: output is false"); + } +} + +#else + +void test_scoped_critical_section(const py::handle &) {} +void test_scoped_critical_section2(const py::handle &) {} +void test_scoped_critical_section2_same_object_no_deadlock(const py::handle &) {} + +#endif + +} // namespace test_scoped_critical_section_ns + +TEST_SUBMODULE(scoped_critical_section, m) { + using namespace test_scoped_critical_section_ns; + + m.def("test_one_nullptr", test_one_nullptr); + m.def("test_two_nullptrs", test_two_nullptrs); + m.def("test_first_nullptr", test_first_nullptr); + m.def("test_second_nullptr", test_second_nullptr); + + auto BoolWrapperClass = py::class_<BoolWrapper>(m, "BoolWrapper") + .def(py::init<bool>()) + .def("get", &BoolWrapper::get) + .def("set", &BoolWrapper::set); + auto BoolWrapperHandle = py::handle(BoolWrapperClass); + (void) BoolWrapperHandle.ptr(); // suppress unused variable warning + + m.attr("has_barrier") = +#ifdef PYBIND11_HAS_BARRIER + true; +#else + false; +#endif + + m.def("test_scoped_critical_section", + [BoolWrapperHandle]() -> void { test_scoped_critical_section(BoolWrapperHandle); }); + m.def("test_scoped_critical_section2", + [BoolWrapperHandle]() -> void { test_scoped_critical_section2(BoolWrapperHandle); }); + m.def("test_scoped_critical_section2_same_object_no_deadlock", [BoolWrapperHandle]() -> void { + test_scoped_critical_section2_same_object_no_deadlock(BoolWrapperHandle); + }); +} diff --git a/tests/test_scoped_critical_section.py 
b/tests/test_scoped_critical_section.py new file mode 100644 index 0000000000..5703e3d51c --- /dev/null +++ b/tests/test_scoped_critical_section.py @@ -0,0 +1,30 @@ +from __future__ import annotations + +import pytest + +from pybind11_tests import scoped_critical_section as m + + +def test_nullptr_combinations(): + m.test_one_nullptr() + m.test_two_nullptrs() + m.test_first_nullptr() + m.test_second_nullptr() + + +@pytest.mark.skipif(not m.has_barrier, reason="no <barrier>") +def test_scoped_critical_section() -> None: + for _ in range(64): + m.test_scoped_critical_section() + + +@pytest.mark.skipif(not m.has_barrier, reason="no <barrier>") +def test_scoped_critical_section2() -> None: + for _ in range(64): + m.test_scoped_critical_section2() + + +@pytest.mark.skipif(not m.has_barrier, reason="no <barrier>") +def test_scoped_critical_section2_same_object_no_deadlock() -> None: + for _ in range(64): + m.test_scoped_critical_section2_same_object_no_deadlock()