Skip to content

Commit 1be61e5

Browse files
authored
[SYCL][Graph] Bugfix: Keep handle to last recorded queue after cleanup (#20831)
`MRecordingQueues` is cleared after recording is completed, which prevents the optimization of reusing an available queue for an executable graph object instead of creating a placeholder queue.
1 parent 2c32865 commit 1be61e5

File tree

4 files changed

+51
-8
lines changed

4 files changed

+51
-8
lines changed

sycl/source/detail/graph/graph_impl.cpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -545,15 +545,14 @@ graph_impl::add(std::shared_ptr<dynamic_command_group_impl> &DynCGImpl,
545545
return NodeImpl;
546546
}
547547

548-
std::shared_ptr<sycl::detail::queue_impl> graph_impl::getQueue() const {
549-
std::shared_ptr<sycl::detail::queue_impl> Return{};
550-
if (!MRecordingQueues.empty())
551-
Return = MRecordingQueues.begin()->lock();
552-
return Return;
548+
std::shared_ptr<sycl::detail::queue_impl>
549+
graph_impl::getLastRecordedQueue() const {
550+
return MLastRecordedQueue.lock();
553551
}
554552

555553
void graph_impl::addQueue(sycl::detail::queue_impl &RecordingQueue) {
556-
MRecordingQueues.insert(RecordingQueue.weak_from_this());
554+
MLastRecordedQueue = RecordingQueue.weak_from_this();
555+
MRecordingQueues.insert(MLastRecordedQueue);
557556
}
558557

559558
void graph_impl::removeQueue(sycl::detail::queue_impl &RecordingQueue) {
@@ -932,7 +931,7 @@ exec_graph_impl::exec_graph_impl(sycl::context Context,
932931
// Copy nodes from GraphImpl and merge any subgraph nodes into this graph.
933932
duplicateNodes();
934933

935-
if (auto PlaceholderQueuePtr = GraphImpl->getQueue()) {
934+
if (auto PlaceholderQueuePtr = GraphImpl->getLastRecordedQueue()) {
936935
MQueueImpl = std::move(PlaceholderQueuePtr);
937936
} else {
938937
MQueueImpl = sycl::detail::queue_impl::create(

sycl/source/detail/graph/graph_impl.hpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,9 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
176176
node_impl &add(std::shared_ptr<dynamic_command_group_impl> &DynCGImpl,
177177
nodes_range Deps);
178178

179-
std::shared_ptr<sycl::detail::queue_impl> getQueue() const;
179+
/// Get queue that was last recorded from.
180+
/// @return Queue that started the last recording into the associated graph.
181+
std::shared_ptr<sycl::detail::queue_impl> getLastRecordedQueue() const;
180182

181183
/// Add a queue to the set of queues which are currently recording to this
182184
/// graph.
@@ -558,6 +560,8 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
558560
std::owner_less<std::weak_ptr<sycl::detail::queue_impl>>>;
559561
/// Unique set of queues which are currently recording to this graph.
560562
RecQueuesStorage MRecordingQueues;
563+
/// Queue that was last recorded from.
564+
std::weak_ptr<sycl::detail::queue_impl> MLastRecordedQueue;
561565
/// Map of events to their associated recorded nodes.
562566
std::unordered_map<std::shared_ptr<sycl::detail::event_impl>, node_impl *>
563567
MEventsMap;

sycl/unittests/Extensions/CommandGraph/Common.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,10 @@ class GraphImplTest {
4646
static int NumSyncPoints(const exec_graph_impl &Impl) {
4747
return Impl.MSyncPoints.size();
4848
}
49+
static std::shared_ptr<sycl::detail::queue_impl>
50+
GetQueueImpl(const exec_graph_impl &Impl) {
51+
return Impl.MQueueImpl;
52+
}
4953
};
5054

5155
// Common Test fixture

sycl/unittests/Extensions/CommandGraph/Regressions.cpp

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,3 +94,39 @@ TEST_F(CommandGraphTest, QueueRecordBarrierMultipleGraph) {
9494
Queue.ext_oneapi_submit_barrier();
9595
GraphC.end_recording(Queue);
9696
}
97+
98+
// Test that the last recorded queue is preserved after cleanup.
99+
// This is a regression test for a bug where getLastRecordedQueue() would
100+
// return nullptr after the recording queues were cleaned up, because the
101+
// previous implementation (getQueue()) looked in the MRecordingQueues set
102+
// which gets cleared on end_recording(). The fix introduces MLastRecordedQueue
103+
// which persists even after cleanup, allowing the executable graph to retrieve
104+
// the queue that was used for recording.
105+
TEST_F(CommandGraphTest, LastRecordedQueueAfterCleanup) {
106+
// Record some work to the graph
107+
Graph.begin_recording(Queue);
108+
Queue.submit(
109+
[&](sycl::handler &cgh) { cgh.single_task<TestKernel>([]() {}); });
110+
Graph.end_recording(Queue);
111+
112+
// Get the graph implementation to check internal state
113+
auto GraphImpl = getSyclObjImpl(Graph);
114+
115+
// getLastRecordedQueue() should return the queue that was used for recording
116+
// even after end_recording() has cleared the recording queues
117+
auto LastQueue = GraphImpl->getLastRecordedQueue();
118+
EXPECT_NE(LastQueue, nullptr);
119+
EXPECT_EQ(LastQueue, getSyclObjImpl(Queue));
120+
121+
// Finalize the graph - this uses getLastRecordedQueue() internally
122+
// to set up the executable graph's queue. Before the fix, this could fail
123+
// if getLastRecordedQueue() returned nullptr.
124+
auto GraphExec = Graph.finalize();
125+
experimental::detail::exec_graph_impl &ExecGraphImpl =
126+
*getSyclObjImpl(GraphExec);
127+
128+
// The executable graph should have the queue from recording
129+
auto ExecQueueImpl = GraphImplTest::GetQueueImpl(ExecGraphImpl);
130+
EXPECT_NE(ExecQueueImpl, nullptr);
131+
EXPECT_EQ(ExecQueueImpl, getSyclObjImpl(Queue));
132+
}

0 commit comments

Comments
 (0)