Skip to content

Commit da4fb4c

Browse files
[SYCL] Fix depends_on usage with barriers (#18139)
Signed-off-by: Tikhomirova, Kseniya <[email protected]>
1 parent e7c54c3 commit da4fb4c

File tree

3 files changed

+170
-0
lines changed

3 files changed

+170
-0
lines changed

sycl/source/detail/scheduler/commands.cpp

+18
Original file line numberDiff line numberDiff line change
@@ -3511,6 +3511,18 @@ ur_result_t ExecCGCommand::enqueueImpQueue() {
35113511
const AdapterPtr &Adapter = MQueue->getAdapter();
35123512
if (MEvent != nullptr)
35133513
MEvent->setHostEnqueueTime();
3514+
// User can specify explicit dependencies via depends_on call that we should
3515+
// honor here. It is very important for cross queue dependencies. We wait
3516+
// for them explicitly since barrier w/o wait list waits for all commands
3517+
// submitted before and we can't add new dependencies to its wait list.
3518+
// Output event for wait operation is not requested since barrier is
3519+
// submitted immediately after and should synchronize it internally.
3520+
if (RawEvents.size()) {
3521+
auto Result = Adapter->call_nocheck<UrApiKind::urEnqueueEventsWait>(
3522+
MQueue->getHandleRef(), RawEvents.size(), &RawEvents[0], nullptr);
3523+
if (Result != UR_RESULT_SUCCESS)
3524+
return Result;
3525+
}
35143526
if (auto Result =
35153527
Adapter->call_nocheck<UrApiKind::urEnqueueEventsWaitWithBarrierExt>(
35163528
MQueue->getHandleRef(), &Properties, 0, nullptr, Event);
@@ -3545,6 +3557,12 @@ ur_result_t ExecCGCommand::enqueueImpQueue() {
35453557
const AdapterPtr &Adapter = MQueue->getAdapter();
35463558
if (MEvent != nullptr)
35473559
MEvent->setHostEnqueueTime();
3560+
// User can specify explicit dependencies via depends_on call that we should
3561+
// honor here. It is very important for cross queue dependencies. We add
3562+
// them to the barrier wait list since barrier w/ wait list waits only for
3563+
// the events provided in wait list and we can just extend the list.
3564+
UrEvents.insert(UrEvents.end(), RawEvents.begin(), RawEvents.end());
3565+
35483566
if (auto Result =
35493567
Adapter->call_nocheck<UrApiKind::urEnqueueEventsWaitWithBarrierExt>(
35503568
MQueue->getHandleRef(), &Properties, UrEvents.size(),
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
//==-------- BarrierDependencies.cpp --- Scheduler unit tests --------------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "SchedulerTest.hpp"
10+
#include "SchedulerTestUtils.hpp"
11+
12+
#include <helpers/TestKernel.hpp>
13+
#include <helpers/UrMock.hpp>
14+
15+
#include <detail/event_impl.hpp>
16+
17+
#include <gtest/gtest.h>
18+
19+
#include <sycl/sycl.hpp>
20+
21+
using namespace sycl;
22+
23+
std::vector<ur_event_handle_t> EventsInWaitList;
24+
bool EventsWaitVisited = false;
25+
static ur_result_t redefinedEventWait(void *pParams) {
26+
EventsWaitVisited = true;
27+
28+
auto params = *static_cast<ur_enqueue_events_wait_params_t *>(pParams);
29+
for (size_t i = 0; i < *params.pnumEventsInWaitList; ++i)
30+
EventsInWaitList.push_back((*params.pphEventWaitList)[i]);
31+
32+
return UR_RESULT_SUCCESS;
33+
}
34+
35+
std::vector<ur_event_handle_t> BarrierEventsInWaitList;
36+
bool BarrierEventsWaitVisited = false;
37+
ur_result_t redefinedEnqueueEventsWaitWithBarrierExt(void *pParams) {
38+
BarrierEventsWaitVisited = true;
39+
40+
auto params =
41+
*static_cast<ur_enqueue_events_wait_with_barrier_ext_params_t *>(pParams);
42+
for (auto i = 0u; i < *params.pnumEventsInWaitList; i++) {
43+
BarrierEventsInWaitList.push_back((*params.pphEventWaitList)[i]);
44+
}
45+
return UR_RESULT_SUCCESS;
46+
}
47+
48+
/// Resets everything recorded by the two mock callbacks (captured wait lists
/// and "visited" flags) so the next submission can be checked in isolation.
///
/// Declared static: the name is generic and this helper is only meant for the
/// tests in this file, so it must not be visible to other translation units.
static void clearGlobals() {
  EventsWaitVisited = false;
  EventsInWaitList.clear();

  BarrierEventsWaitVisited = false;
  BarrierEventsInWaitList.clear();
}
54+
55+
// A barrier submitted without its own wait list must still honor depends_on():
// the dependency is expected to show up in a separate urEnqueueEventsWait call
// while the barrier call itself keeps an empty wait list.
TEST_F(SchedulerTest, BarrierWithDependsOn) {
  clearGlobals();

  sycl::unittest::UrMock<> Mock;
  sycl::platform Plt = sycl::platform();
  mock::getCallbacks().set_after_callback("urEnqueueEventsWait",
                                          &redefinedEventWait);
  mock::getCallbacks().set_after_callback(
      "urEnqueueEventsWaitWithBarrierExt",
      &redefinedEnqueueEventsWaitWithBarrierExt);

  sycl::context Ctx{Plt};
  sycl::queue QueueA{Ctx, sycl::default_selector_v,
                     sycl::property::queue::in_order()};
  sycl::queue QueueB{Ctx, sycl::default_selector_v,
                     sycl::property::queue::in_order()};

  // Step 1: plain barrier on the first queue, no dependencies of any kind.
  auto EventA = QueueA.submit(
      [&](sycl::handler &CGH) { CGH.ext_oneapi_barrier(); });
  std::shared_ptr<sycl::detail::event_impl> EventAImpl =
      sycl::detail::getSyclObjImpl(EventA);
  // A non-null native handle means the command has been enqueued.
  ASSERT_NE(EventAImpl->getHandle(), nullptr);

  ASSERT_FALSE(EventsWaitVisited);
  ASSERT_TRUE(BarrierEventsWaitVisited);
  ASSERT_EQ(BarrierEventsInWaitList.size(), 0u);

  // Step 2: barrier on the second queue that depends on the first barrier.
  clearGlobals();
  auto EventB = QueueB.submit([&](sycl::handler &CGH) {
    CGH.depends_on(EventA);
    CGH.ext_oneapi_barrier();
  });
  std::shared_ptr<sycl::detail::event_impl> EventBImpl =
      sycl::detail::getSyclObjImpl(EventB);
  // A non-null native handle means the command has been enqueued.
  ASSERT_NE(EventBImpl->getHandle(), nullptr);

  // The explicit dependency goes through the events-wait entry point ...
  ASSERT_TRUE(EventsWaitVisited);
  ASSERT_EQ(EventsInWaitList.size(), 1u);
  EXPECT_EQ(EventsInWaitList[0], EventAImpl->getHandle());

  // ... and the barrier is still enqueued with an empty wait list.
  ASSERT_TRUE(BarrierEventsWaitVisited);
  ASSERT_EQ(BarrierEventsInWaitList.size(), 0u);

  QueueA.wait();
  QueueB.wait();
}
101+
102+
// A barrier submitted with its own wait list plus a depends_on() dependency
// must pass both events in the barrier's wait list (barrier events first, then
// the depends_on event) and must not issue a separate urEnqueueEventsWait.
TEST_F(SchedulerTest, BarrierWaitListWithDependsOn) {
  clearGlobals();

  sycl::unittest::UrMock<> Mock;
  sycl::platform Plt = sycl::platform();
  mock::getCallbacks().set_after_callback("urEnqueueEventsWait",
                                          &redefinedEventWait);
  mock::getCallbacks().set_after_callback(
      "urEnqueueEventsWaitWithBarrierExt",
      &redefinedEnqueueEventsWaitWithBarrierExt);

  sycl::context Ctx{Plt};
  sycl::queue QueueA{Ctx, sycl::default_selector_v,
                     sycl::property::queue::in_order()};
  sycl::queue QueueB{Ctx, sycl::default_selector_v,
                     sycl::property::queue::in_order()};

  // Step 1: two plain barriers on the first queue, no dependencies.
  auto EventA = QueueA.submit(
      [&](sycl::handler &CGH) { CGH.ext_oneapi_barrier(); });
  auto EventA2 = QueueA.submit(
      [&](sycl::handler &CGH) { CGH.ext_oneapi_barrier(); });
  std::shared_ptr<sycl::detail::event_impl> EventAImpl =
      sycl::detail::getSyclObjImpl(EventA);
  std::shared_ptr<sycl::detail::event_impl> EventA2Impl =
      sycl::detail::getSyclObjImpl(EventA2);
  // Non-null native handles mean both commands have been enqueued.
  ASSERT_NE(EventAImpl->getHandle(), nullptr);
  ASSERT_NE(EventA2Impl->getHandle(), nullptr);

  ASSERT_FALSE(EventsWaitVisited);
  ASSERT_TRUE(BarrierEventsWaitVisited);
  ASSERT_EQ(BarrierEventsInWaitList.size(), 0u);

  // Step 2: barrier on the second queue waiting on EventA2 via its own wait
  // list and on EventA via depends_on.
  clearGlobals();
  auto EventB = QueueB.submit([&](sycl::handler &CGH) {
    CGH.depends_on(EventA);
    CGH.ext_oneapi_barrier({EventA2});
  });
  std::shared_ptr<sycl::detail::event_impl> EventBImpl =
      sycl::detail::getSyclObjImpl(EventB);
  // A non-null native handle means the command has been enqueued.
  ASSERT_NE(EventBImpl->getHandle(), nullptr);

  // No standalone events-wait: both events travel in the barrier wait list.
  ASSERT_FALSE(EventsWaitVisited);
  ASSERT_TRUE(BarrierEventsWaitVisited);
  ASSERT_EQ(BarrierEventsInWaitList.size(), 2u);
  EXPECT_EQ(BarrierEventsInWaitList[0], EventA2Impl->getHandle());
  EXPECT_EQ(BarrierEventsInWaitList[1], EventAImpl->getHandle());

  QueueA.wait();
  QueueB.wait();
}

sycl/unittests/scheduler/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -21,4 +21,5 @@ add_sycl_unittest(SchedulerTests OBJECT
2121
EnqueueWithDependsOnDeps.cpp
2222
AccessorDefaultCtor.cpp
2323
HostTaskAndBarrier.cpp
24+
BarrierDependencies.cpp
2425
)

0 commit comments

Comments
 (0)