Skip to content

Commit c534114

Browse files
committed
Implement Dynamic Local Accessors
1 parent 9d085d7 commit c534114

File tree

9 files changed

+502
-2
lines changed

9 files changed

+502
-2
lines changed

sycl/include/sycl/ext/oneapi/experimental/graph.hpp

+50
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,8 @@ class node_impl;
100100
class graph_impl;
101101
class exec_graph_impl;
102102
class dynamic_parameter_impl;
103+
//template <typename DataT, int Dimensions>
104+
//class dynamic_local_accessor_impl;
103105
class dynamic_command_group_impl;
104106
} // namespace detail
105107

@@ -484,6 +486,11 @@ class command_graph<graph_state::executable>
484486
namespace detail {
485487
class __SYCL_EXPORT dynamic_parameter_base {
486488
public:
489+
490+
dynamic_parameter_base(
491+
sycl::ext::oneapi::experimental::command_graph<graph_state::modifiable>
492+
Graph);
493+
487494
dynamic_parameter_base(
488495
sycl::ext::oneapi::experimental::command_graph<graph_state::modifiable>
489496
Graph,
@@ -498,6 +505,13 @@ class __SYCL_EXPORT dynamic_parameter_base {
498505
void updateValue(const raw_kernel_arg *NewRawValue, size_t Size);
499506

500507
void updateAccessor(const sycl::detail::AccessorBaseHost *Acc);
508+
509+
sycl::detail::LocalAccessorImplPtr getLocalAccessor(handler* Handler);
510+
511+
void registerLocalAccessor(sycl::detail::LocalAccessorBaseHost* LocalAccBaseHost, handler* Handler);
512+
513+
void updateLocalAccessor(range<3> NewAllocationSize);
514+
501515
std::shared_ptr<dynamic_parameter_impl> impl;
502516

503517
template <class Obj>
@@ -535,6 +549,42 @@ class dynamic_parameter : public detail::dynamic_parameter_base {
535549
}
536550
};
537551

552+
template <typename DataT, int Dimensions = 1>
553+
class dynamic_local_accessor : public detail::dynamic_parameter_base {
554+
public:
555+
template <int Dims = Dimensions, typename = std::enable_if_t<(Dims > 0)>>
556+
dynamic_local_accessor(command_graph<graph_state::modifiable> Graph,
557+
range<Dimensions> AllocationSize,
558+
const property_list &PropList = {})
559+
: detail::dynamic_parameter_base(Graph), AllocationSize(AllocationSize) {
560+
(void)PropList;
561+
}
562+
563+
void update(range<Dimensions> NewAllocationSize) {
564+
detail::dynamic_parameter_base::updateLocalAccessor(
565+
::sycl::detail::convertToArrayOfN<3, 1>(NewAllocationSize));
566+
};
567+
568+
local_accessor<DataT, Dimensions> get(handler &CGH) {
569+
#ifndef __SYCL_DEVICE_ONLY__
570+
::sycl::detail::LocalAccessorImplPtr BaseLocalAcc = getLocalAccessor(&CGH);
571+
if (BaseLocalAcc) {
572+
return sycl::detail::createSyclObjFromImpl<local_accessor<DataT, Dimensions>>(BaseLocalAcc);
573+
} else {
574+
local_accessor<DataT, Dimensions> LocalAccessor(AllocationSize, CGH);
575+
registerLocalAccessor(
576+
static_cast<sycl::detail::LocalAccessorBaseHost *>(&LocalAccessor), &CGH);
577+
return LocalAccessor;
578+
}
579+
#else
580+
return local_accessor<DataT, Dimensions>();
581+
#endif
582+
};
583+
584+
private:
585+
range<Dimensions> AllocationSize;
586+
};
587+
538588
/// Additional CTAD deduction guides.
539589
template <typename ValueT>
540590
dynamic_parameter(experimental::command_graph<graph_state::modifiable> Graph,

sycl/include/sycl/handler.hpp

+23
Original file line numberDiff line numberDiff line change
@@ -647,6 +647,22 @@ class __SYCL_EXPORT handler {
647647
registerDynamicParameter(DynamicParam, ArgIndex);
648648
}
649649

650+
// setArgHelper for graph dynamic_local_accessors.
651+
template <typename DataT, int Dims>
652+
void
653+
setArgHelper(int ArgIndex,
654+
ext::oneapi::experimental::dynamic_local_accessor<DataT, Dims>
655+
&DynamicLocalAccessor) {
656+
#ifndef __SYCL_DEVICE_ONLY__
657+
auto LocalAccessor = DynamicLocalAccessor.get(*this);
658+
setArgHelper(ArgIndex, LocalAccessor);
659+
registerDynamicParameter(DynamicLocalAccessor, ArgIndex);
660+
#else
661+
(void)ArgIndex;
662+
(void)DynamicLocalAccessor;
663+
#endif
664+
}
665+
650666
// setArgHelper for the raw_kernel_arg extension type.
651667
void setArgHelper(int ArgIndex,
652668
sycl::ext::oneapi::experimental::raw_kernel_arg &&Arg) {
@@ -1838,6 +1854,13 @@ class __SYCL_EXPORT handler {
18381854
setArgHelper(argIndex, dynamicParam);
18391855
}
18401856

1857+
template <typename DataT, int Dims>
1858+
void set_arg(int argIndex,
1859+
ext::oneapi::experimental::dynamic_local_accessor<DataT, Dims>
1860+
&DynamicLocalAccessor) {
1861+
setArgHelper(argIndex, DynamicLocalAccessor);
1862+
}
1863+
18411864
// set_arg for the raw_kernel_arg extension type.
18421865
void set_arg(int argIndex, ext::oneapi::experimental::raw_kernel_arg &&Arg) {
18431866
setArgHelper(argIndex, std::move(Arg));

sycl/source/detail/graph_impl.cpp

+89
Original file line numberDiff line numberDiff line change
@@ -1829,6 +1829,11 @@ dynamic_parameter_base::dynamic_parameter_base(
18291829
: impl(std::make_shared<dynamic_parameter_impl>(
18301830
sycl::detail::getSyclObjImpl(Graph), ParamSize, Data)) {}
18311831

1832+
dynamic_parameter_base::dynamic_parameter_base(
1833+
command_graph<graph_state::modifiable> Graph)
1834+
: impl(std::make_shared<dynamic_parameter_impl>(
1835+
sycl::detail::getSyclObjImpl(Graph))) {}
1836+
18321837
void dynamic_parameter_base::updateValue(const void *NewValue, size_t Size) {
18331838
impl->updateValue(NewValue, Size);
18341839
}
@@ -1843,6 +1848,20 @@ void dynamic_parameter_base::updateAccessor(
18431848
impl->updateAccessor(Acc);
18441849
}
18451850

1851+
/// Forwards to dynamic_parameter_impl::getLocalAccessor.
/// @param Handler The handler to look up a local accessor for.
/// @return Whatever the impl returns for \p Handler.
sycl::detail::LocalAccessorImplPtr
dynamic_parameter_base::getLocalAccessor(handler *Handler) {
  dynamic_parameter_impl &Impl = *impl;
  return Impl.getLocalAccessor(Handler);
}
1855+
1856+
/// Forwards to dynamic_parameter_impl::registerLocalAccessor.
/// @param LocalAccBaseHost The local accessor to register.
/// @param Handler The handler the local accessor is associated with.
void dynamic_parameter_base::registerLocalAccessor(
    sycl::detail::LocalAccessorBaseHost *LocalAccBaseHost, handler *Handler) {
  dynamic_parameter_impl &Impl = *impl;
  Impl.registerLocalAccessor(LocalAccBaseHost, Handler);
}
1860+
1861+
/// Forwards to dynamic_parameter_impl::updateLocalAccessor.
/// @param NewAllocationSize The new allocation size, as a range<3>.
void dynamic_parameter_base::updateLocalAccessor(range<3> NewAllocationSize) {
  dynamic_parameter_impl &Impl = *impl;
  Impl.updateLocalAccessor(NewAllocationSize);
}
1864+
18461865
void dynamic_parameter_impl::updateValue(const raw_kernel_arg *NewRawValue,
18471866
size_t Size) {
18481867
// Number of bytes is taken from member of raw_kernel_arg object rather
@@ -1898,6 +1917,53 @@ void dynamic_parameter_impl::updateAccessor(
18981917
sizeof(sycl::detail::AccessorBaseHost));
18991918
}
19001919

1920+
/// Looks up the local accessor impl registered for \p Handler.
/// @param Handler The handler whose impl object is used as the lookup key.
/// @return The registered local accessor impl, or nullptr when nothing has
/// been registered for this handler.
sycl::detail::LocalAccessorImplPtr
dynamic_parameter_impl::getLocalAccessor(handler *Handler) {
  const auto Entry =
      MHandlerToLocalAccMap.find(sycl::detail::getSyclObjImpl(*Handler));

  if (Entry == MHandlerToLocalAccMap.end())
    return nullptr;

  return Entry->second;
}
1931+
1932+
/// Records the impl of \p LocalAccBaseHost under the impl of \p Handler.
/// An entry that already exists for the handler is left untouched.
/// @param LocalAccBaseHost The local accessor to register.
/// @param Handler The handler the local accessor is associated with.
void dynamic_parameter_impl::registerLocalAccessor(
    sycl::detail::LocalAccessorBaseHost *LocalAccBaseHost, handler *Handler) {
  MHandlerToLocalAccMap.emplace(sycl::detail::getSyclObjImpl(*Handler),
                                sycl::detail::getSyclObjImpl(*LocalAccBaseHost));
}
1940+
1941+
/// Applies a new allocation size to the local accessor argument of every
/// registered node and dynamic command-group that is still alive.
/// @param NewAllocationSize The new allocation size, as a range<3>.
void dynamic_parameter_impl::updateLocalAccessor(range<3> NewAllocationSize) {
  // Without a registered local accessor there is no dimension or element size
  // to compute a byte size from, and begin() must not be dereferenced.
  if (MHandlerToLocalAccMap.empty())
    return;

  // We can use the first local accessor in the map since the dimensions
  // and element type should be identical. Looked up once, as it does not
  // change while the nodes and command-groups are being updated.
  const auto &LocalAccessor = MHandlerToLocalAccMap.begin()->second;
  const auto Dims = LocalAccessor->MDims;
  const auto ElemSize = LocalAccessor->MElemSize;

  for (auto &[NodeWeak, ArgIndex] : MNodes) {
    // Nodes that have been destroyed in the meantime are skipped.
    if (auto NodeShared = NodeWeak.lock()) {
      dynamic_parameter_impl::updateCGLocalAccessor(
          NodeShared->MCommandGroup, ArgIndex, NewAllocationSize, Dims,
          ElemSize);
    }
  }

  for (auto &DynCGInfo : MDynCGs) {
    // Dynamic command-groups that no longer exist are skipped.
    if (auto DynCG = DynCGInfo.DynCG.lock()) {
      auto &CG = DynCG->MKernels[DynCGInfo.CGIndex];
      dynamic_parameter_impl::updateCGLocalAccessor(
          CG, DynCGInfo.ArgIndex, NewAllocationSize, Dims, ElemSize);
    }
  }
}
1966+
19011967
void dynamic_parameter_impl::updateCGArgValue(
19021968
std::shared_ptr<sycl::detail::CG> CG, int ArgIndex, const void *NewValue,
19031969
size_t Size) {
@@ -1963,6 +2029,28 @@ void dynamic_parameter_impl::updateCGAccessor(
19632029
}
19642030
}
19652031

2032+
void dynamic_parameter_impl::updateCGLocalAccessor(
2033+
std::shared_ptr<sycl::detail::CG> CG, int ArgIndex,
2034+
range<3> NewAllocationSize,
2035+
int Dims, int ElemSize) {
2036+
auto &Args = static_cast<sycl::detail::CGExecKernel *>(CG.get())->MArgs;
2037+
2038+
for (auto &Arg : Args) {
2039+
if (Arg.MIndex != ArgIndex) {
2040+
continue;
2041+
}
2042+
assert(Arg.MType == sycl::detail::kernel_param_kind_t::kind_std_layout);
2043+
2044+
int SizeInBytes = ElemSize;
2045+
for (int I = 0; I < Dims; ++I)
2046+
SizeInBytes *= NewAllocationSize[I];
2047+
SizeInBytes = std::max(SizeInBytes, 1);
2048+
2049+
Arg.MSize = SizeInBytes;
2050+
break;
2051+
}
2052+
}
2053+
19662054
dynamic_command_group_impl::dynamic_command_group_impl(
19672055
const command_graph<graph_state::modifiable> &Graph)
19682056
: MGraph{sycl::detail::getSyclObjImpl(Graph)}, MActiveCGF(0) {}
@@ -2084,6 +2172,7 @@ size_t dynamic_command_group::get_active_index() const {
20842172
void dynamic_command_group::set_active_index(size_t Index) {
20852173
return impl->setActiveIndex(Index);
20862174
}
2175+
20872176
} // namespace experimental
20882177
} // namespace oneapi
20892178
} // namespace ext

sycl/source/detail/graph_impl.hpp

+40
Original file line numberDiff line numberDiff line change
@@ -1431,6 +1431,10 @@ class exec_graph_impl {
14311431

14321432
class dynamic_parameter_impl {
14331433
public:
1434+
/// Used for parameters that don't have data such as local_accessors.
1435+
dynamic_parameter_impl(std::shared_ptr<graph_impl> GraphImpl)
1436+
: MGraph(GraphImpl) {}
1437+
14341438
dynamic_parameter_impl(std::shared_ptr<graph_impl> GraphImpl,
14351439
size_t ParamSize, const void *Data)
14361440
: MGraph(GraphImpl), MValueStorage(ParamSize) {
@@ -1496,6 +1500,26 @@ class dynamic_parameter_impl {
14961500
/// @param Acc The new accessor value
14971501
void updateAccessor(const sycl::detail::AccessorBaseHost *Acc);
14981502

1503+
/// Updates the value of all local accessors in registered nodes and dynamic
1504+
/// CGs.
1505+
/// @param NewAllocationSize The new size for the updated local accessors.
1506+
void updateLocalAccessor(range<3> NewAllocationSize);
1507+
1508+
/// Gets the implementation for the local accessor that is associated with
1509+
/// a specific handler.
1510+
/// @param Handler The handler that the local accessor is associated with.
1511+
/// @return returns the impl object for the local accessor that is associated
1512+
/// with this handler. Or nullptr if no local accessor has been registered
1513+
/// for this handler.
1514+
sycl::detail::LocalAccessorImplPtr getLocalAccessor(handler *Handler);
1515+
1516+
/// Associates a local accessor with this dynamic local accessor for a
1517+
/// specific handler.
1518+
/// @param LocalAccBase the local accessor that needs to be registered.
1519+
/// @param Handler the handler that the LocalAccessor is associated with.
1520+
void registerLocalAccessor(sycl::detail::LocalAccessorBaseHost *LocalAccBase,
1521+
handler *Handler);
1522+
14991523
/// Static helper function for updating command-group value arguments.
15001524
/// @param CG The command-group to update the argument information for.
15011525
/// @param ArgIndex The argument index to update.
@@ -1512,13 +1536,29 @@ class dynamic_parameter_impl {
15121536
int ArgIndex,
15131537
const sycl::detail::AccessorBaseHost *Acc);
15141538

1539+
/// Static helper function for updating command-group local accessor
1540+
/// arguments.
1541+
/// @param CG The command-group to update the argument information for.
1542+
/// @param ArgIndex The argument index to update.
1543+
/// @param NewAllocationSize The new allocation size for the local accessor
1544+
/// argument.
1545+
/// @param Dims The dimensions of the local accessor argument.
1546+
/// @param ElemSize The size of each element in the local accessor.
1547+
static void updateCGLocalAccessor(std::shared_ptr<sycl::detail::CG> CG,
1548+
int ArgIndex, range<3> NewAllocationSize,
1549+
int Dims, int ElemSize);
1550+
15151551
// Weak ptrs to node_impls which will be updated
15161552
std::vector<std::pair<std::weak_ptr<node_impl>, int>> MNodes;
15171553
// Dynamic command-groups which will be updated
15181554
std::vector<DynamicCGInfo> MDynCGs;
15191555

15201556
std::shared_ptr<graph_impl> MGraph;
15211557
std::vector<std::byte> MValueStorage;
1558+
1559+
std::unordered_map<std::shared_ptr<sycl::detail::handler_impl>,
1560+
sycl::detail::LocalAccessorImplPtr>
1561+
MHandlerToLocalAccMap;
15221562
};
15231563

15241564
class dynamic_command_group_impl

sycl/test-e2e/CMakeLists.txt

-2
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,5 @@ add_custom_target(check-sycl-e2e
9191
USES_TERMINAL
9292
)
9393

94-
add_executable("local_memory_test" Graph/Explicit/compile_time_local_memory.cpp)
95-
9694
add_subdirectory(External)
9795
add_subdirectory(ExtraTests)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
// RUN: %{build} -o %t.out
2+
// RUN: %{run} %t.out
3+
// Extra run to check for leaks in Level Zero using UR_L0_LEAKS_DEBUG
4+
// RUN: %if level_zero %{env SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=0 %{l0_leak_check} %{run} %t.out 2>&1 | FileCheck %s --implicit-check-not=LEAK %}
5+
// Extra run to check for immediate-command-list in Level Zero
6+
// RUN: %if level_zero %{env SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1 %{l0_leak_check} %{run} %t.out 2>&1 | FileCheck %s --implicit-check-not=LEAK %}
7+
8+
// Tests updating local accessor parameters.
9+
#include "../graph_common.hpp"
10+
11+
int main() {
12+
using T = int;
13+
14+
const size_t LocalMemSize = 128;
15+
16+
queue Queue{};
17+
18+
std::vector<T> HostDataBeforeUpdate(Size);
19+
std::vector<T> HostDataAfterUpdate(Size);
20+
std::iota(HostDataBeforeUpdate.begin(), HostDataBeforeUpdate.end(), 10);
21+
22+
T *PtrA = malloc_device<T>(Size, Queue);
23+
Queue.copy(HostDataBeforeUpdate.data(), PtrA, Size);
24+
Queue.wait_and_throw();
25+
26+
exp_ext::command_graph Graph{Queue.get_context(), Queue.get_device()};
27+
28+
exp_ext::dynamic_local_accessor<T, 1> DynLocalAccessor{Graph, LocalMemSize};
29+
30+
auto Node = Graph.add([&](handler &CGH) {
31+
CGH.set_arg(0, DynLocalAccessor);
32+
auto LocalMem = DynLocalAccessor.get(CGH);
33+
34+
CGH.parallel_for(nd_range({Size}, {LocalMemSize}), [=](nd_item<1> Item) {
35+
LocalMem[Item.get_local_linear_id()] = Item.get_local_linear_id();
36+
PtrA[Item.get_global_linear_id()] = LocalMem[Item.get_local_linear_id()];
37+
});
38+
});
39+
40+
auto GraphExec = Graph.finalize(exp_ext::property::graph::updatable{});
41+
42+
// Submit the graph before the update and save the results.
43+
Queue.submit([&](handler &CGH) { CGH.ext_oneapi_graph(GraphExec); });
44+
Queue.wait_and_throw();
45+
Queue.copy(PtrA, HostDataBeforeUpdate.data(), Size);
46+
Queue.wait_and_throw();
47+
48+
DynLocalAccessor.update(LocalMemSize * 2);
49+
Node.update_nd_range(nd_range({Size}, {LocalMemSize * 2}));
50+
GraphExec.update(Node);
51+
52+
// Submit the graph after the update and save the results.
53+
Queue.submit([&](handler &CGH) { CGH.ext_oneapi_graph(GraphExec); });
54+
Queue.wait_and_throw();
55+
Queue.copy(PtrA, HostDataAfterUpdate.data(), Size);
56+
Queue.wait_and_throw();
57+
58+
for (size_t i = 0; i < Size; i++) {
59+
T Ref = i % LocalMemSize;
60+
assert(check_value(i, Ref, HostDataBeforeUpdate[i], "PtrA Before Update"));
61+
}
62+
63+
for (size_t i = 0; i < Size; i++) {
64+
T Ref = i % (LocalMemSize * 2);
65+
assert(check_value(i, Ref, HostDataAfterUpdate[i], "PtrA After Update"));
66+
}
67+
68+
free(PtrA, Queue);
69+
70+
return 0;
71+
}

0 commit comments

Comments
 (0)