Skip to content

Commit e588f8f

Browse files
[NFCI][SYCL] Less shared_ptr for platform_impl
`GlobalHandler::MPlatformCache` keeps (shared) ownership of `platform_impl` objects, so none of them could be destroyed until SYCL RT library shutdown/unload process. As such, using raw pointers/reference to `platform_impl` throughout the SYCL RT is totally fine and avoids extra costs of `std::shared_ptr` I'm relatively sure `sycl::platform` could avoid using `std::shared_ptr<detail::platform_impl> impl` as well, but that would be a breaking change so not being implemented at this moment.
1 parent c04e5db commit e588f8f

22 files changed

+126
-113
lines changed

sycl/include/sycl/platform.hpp

+7
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,16 @@ inline namespace _V1 {
3737
// Forward declaration
3838
class device;
3939
class context;
40+
class platform;
4041

4142
template <backend BackendName, class SyclObjectT>
4243
auto get_native(const SyclObjectT &Obj)
4344
-> backend_return_t<BackendName, SyclObjectT>;
4445
namespace detail {
4546
class platform_impl;
47+
template <class T>
48+
std::enable_if_t<std::is_same_v<T, platform>, platform>
49+
createSyclObjFromImpl(platform_impl &);
4650

4751
/// Allows to enable/disable "Default Context" extension
4852
///
@@ -231,6 +235,9 @@ class __SYCL_EXPORT platform : public detail::OwnerLessBase<platform> {
231235
template <class Obj>
232236
friend const decltype(Obj::impl) &
233237
detail::getSyclObjImpl(const Obj &SyclObject);
238+
template <class T>
239+
friend std::enable_if_t<std::is_same_v<T, platform>, platform>
240+
detail::createSyclObjFromImpl(detail::platform_impl &);
234241

235242
template <backend BackendName, class SyclObjectT>
236243
friend auto get_native(const SyclObjectT &Obj)

sycl/source/backend.cpp

+5-6
Original file line numberDiff line numberDiff line change
@@ -89,9 +89,9 @@ __SYCL_EXPORT device make_device(ur_native_handle_t NativeHandle,
8989
NativeHandle, Adapter->getUrAdapter(), nullptr, &UrDevice);
9090

9191
// Construct the SYCL device from UR device.
92-
auto Platform = platform_impl::getPlatformFromUrDevice(UrDevice, Adapter);
9392
return detail::createSyclObjFromImpl<device>(
94-
Platform->getOrMakeDeviceImpl(UrDevice, Platform));
93+
platform_impl::getPlatformFromUrDevice(UrDevice, Adapter)
94+
.getOrMakeDeviceImpl(UrDevice));
9595
}
9696

9797
__SYCL_EXPORT context make_context(ur_native_handle_t NativeHandle,
@@ -288,10 +288,9 @@ make_kernel_bundle(ur_native_handle_t NativeHandle,
288288
std::transform(
289289
ProgramDevices.begin(), ProgramDevices.end(), std::back_inserter(Devices),
290290
[&Adapter](const auto &Dev) {
291-
auto Platform =
292-
detail::platform_impl::getPlatformFromUrDevice(Dev, Adapter);
293-
auto DeviceImpl = Platform->getOrMakeDeviceImpl(Dev, Platform);
294-
return createSyclObjFromImpl<device>(DeviceImpl);
291+
return createSyclObjFromImpl<device>(
292+
detail::platform_impl::getPlatformFromUrDevice(Dev, Adapter)
293+
.getOrMakeDeviceImpl(Dev));
295294
});
296295

297296
// Unlike SYCL, other backends, like OpenCL or Level Zero, may not support

sycl/source/backend/level_zero.cpp

+1-2
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,13 @@ using namespace sycl::detail;
2020
__SYCL_EXPORT device make_device(const platform &Platform,
2121
ur_native_handle_t NativeHandle) {
2222
const auto &Adapter = ur::getAdapter<backend::ext_oneapi_level_zero>();
23-
const auto &PlatformImpl = getSyclObjImpl(Platform);
2423
// Create UR device first.
2524
ur_device_handle_t UrDevice;
2625
Adapter->call<UrApiKind::urDeviceCreateWithNativeHandle>(
2726
NativeHandle, Adapter->getUrAdapter(), nullptr, &UrDevice);
2827

2928
return detail::createSyclObjFromImpl<device>(
30-
PlatformImpl->getOrMakeDeviceImpl(UrDevice, PlatformImpl));
29+
getSyclObjImpl(Platform)->getOrMakeDeviceImpl(UrDevice));
3130
}
3231

3332
} // namespace ext::oneapi::level_zero::detail

sycl/source/detail/allowlist.cpp

+4-3
Original file line numberDiff line numberDiff line change
@@ -375,8 +375,9 @@ void applyAllowList(std::vector<ur_device_handle_t> &UrDevices,
375375

376376
// Get platform's backend and put it to DeviceDesc
377377
DeviceDescT DeviceDesc;
378-
auto PlatformImpl = platform_impl::getOrMakePlatformImpl(UrPlatform, Adapter);
379-
backend Backend = PlatformImpl->getBackend();
378+
platform_impl &PlatformImpl =
379+
platform_impl::getOrMakePlatformImpl(UrPlatform, Adapter);
380+
backend Backend = PlatformImpl.getBackend();
380381

381382
for (const auto &SyclBe : getSyclBeMap()) {
382383
if (SyclBe.second == Backend) {
@@ -395,7 +396,7 @@ void applyAllowList(std::vector<ur_device_handle_t> &UrDevices,
395396

396397
int InsertIDx = 0;
397398
for (ur_device_handle_t Device : UrDevices) {
398-
auto DeviceImpl = PlatformImpl->getOrMakeDeviceImpl(Device, PlatformImpl);
399+
auto DeviceImpl = PlatformImpl.getOrMakeDeviceImpl(Device);
399400
// get DeviceType value and put it to DeviceDesc
400401
ur_device_type_t UrDevType = UR_DEVICE_TYPE_ALL;
401402
Adapter->call<UrApiKind::urDeviceGetInfo>(

sycl/source/detail/buffer_impl.cpp

+4-5
Original file line numberDiff line numberDiff line change
@@ -79,12 +79,11 @@ buffer_impl::getNativeVector(backend BackendName) const {
7979
// doesn't have context and platform
8080
if (!Ctx)
8181
continue;
82-
const PlatformImplPtr &Platform = Ctx->getPlatformImpl();
83-
assert(Platform && "Platform must be present for device context");
84-
if (Platform->getBackend() != BackendName)
82+
const platform_impl &Platform = Ctx->getPlatformImpl();
83+
if (Platform.getBackend() != BackendName)
8584
continue;
8685

87-
auto Adapter = Platform->getAdapter();
86+
auto Adapter = Platform.getAdapter();
8887

8988
ur_native_handle_t Handle = 0;
9089
// When doing buffer interop we don't know what device the memory should be
@@ -94,7 +93,7 @@ buffer_impl::getNativeVector(backend BackendName) const {
9493
&Handle);
9594
Handles.push_back(Handle);
9695

97-
if (Platform->getBackend() == backend::opencl) {
96+
if (Platform.getBackend() == backend::opencl) {
9897
__SYCL_OCL_CALL(clRetainMemObject, ur::cast<cl_mem>(Handle));
9998
}
10099
}

sycl/source/detail/context_impl.cpp

+10-10
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ context_impl::context_impl(const device &Device, async_handler AsyncHandler,
3131
const property_list &PropList)
3232
: MOwnedByRuntime(true), MAsyncHandler(AsyncHandler), MDevices(1, Device),
3333
MContext(nullptr),
34-
MPlatform(detail::getSyclObjImpl(Device.get_platform())),
34+
MPlatform(detail::getSyclObjImpl(Device.get_platform()).get()),
3535
MPropList(PropList), MSupportBufferLocationByDevices(NotChecked) {
3636
verifyProps(PropList);
3737
MKernelProgramCache.setContextPtr(this);
@@ -41,10 +41,10 @@ context_impl::context_impl(const std::vector<sycl::device> Devices,
4141
async_handler AsyncHandler,
4242
const property_list &PropList)
4343
: MOwnedByRuntime(true), MAsyncHandler(AsyncHandler), MDevices(Devices),
44-
MContext(nullptr), MPlatform(), MPropList(PropList),
45-
MSupportBufferLocationByDevices(NotChecked) {
44+
MContext(nullptr),
45+
MPlatform(detail::getSyclObjImpl(MDevices[0].get_platform()).get()),
46+
MPropList(PropList), MSupportBufferLocationByDevices(NotChecked) {
4647
verifyProps(PropList);
47-
MPlatform = detail::getSyclObjImpl(MDevices[0].get_platform());
4848
std::vector<ur_device_handle_t> DeviceIds;
4949
for (const auto &D : MDevices) {
5050
if (D.has(aspect::ext_oneapi_is_composite)) {
@@ -77,7 +77,7 @@ context_impl::context_impl(ur_context_handle_t UrContext,
7777
MDevices(DeviceList), MContext(UrContext), MPlatform(),
7878
MSupportBufferLocationByDevices(NotChecked) {
7979
if (!MDevices.empty()) {
80-
MPlatform = detail::getSyclObjImpl(MDevices[0].get_platform());
80+
MPlatform = detail::getSyclObjImpl(MDevices[0].get_platform()).get();
8181
} else {
8282
std::vector<ur_device_handle_t> DeviceIds;
8383
uint32_t DevicesNum = 0;
@@ -96,13 +96,13 @@ context_impl::context_impl(ur_context_handle_t UrContext,
9696
make_error_code(errc::invalid),
9797
"No devices in the provided device list and native context.");
9898

99-
std::shared_ptr<detail::platform_impl> Platform =
99+
platform_impl &Platform =
100100
platform_impl::getPlatformFromUrDevice(DeviceIds[0], Adapter);
101101
for (ur_device_handle_t Dev : DeviceIds) {
102-
MDevices.emplace_back(createSyclObjFromImpl<device>(
103-
Platform->getOrMakeDeviceImpl(Dev, Platform)));
102+
MDevices.emplace_back(
103+
createSyclObjFromImpl<device>(Platform.getOrMakeDeviceImpl(Dev)));
104104
}
105-
MPlatform = Platform;
105+
MPlatform = &Platform;
106106
}
107107
// TODO catch an exception and put it to list of asynchronous exceptions
108108
// getAdapter() will be the same as the Adapter passed. This should be taken
@@ -158,7 +158,7 @@ uint32_t context_impl::get_info<info::context::reference_count>() const {
158158
this->getAdapter());
159159
}
160160
template <> platform context_impl::get_info<info::context::platform>() const {
161-
return createSyclObjFromImpl<platform>(MPlatform);
161+
return createSyclObjFromImpl<platform>(*MPlatform);
162162
}
163163
template <>
164164
std::vector<sycl::device>

sycl/source/detail/context_impl.hpp

+5-3
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ inline namespace _V1 {
2929
// Forward declaration
3030
class device;
3131
namespace detail {
32-
using PlatformImplPtr = std::shared_ptr<detail::platform_impl>;
3332
class context_impl {
3433
public:
3534
/// Constructs a context_impl using a single SYCL devices.
@@ -89,8 +88,10 @@ class context_impl {
8988
/// \return the Adapter associated with the platform of this context.
9089
const AdapterPtr &getAdapter() const { return MPlatform->getAdapter(); }
9190

91+
// TODO: Think more about `const`
9292
/// \return the PlatformImpl associated with this context.
93-
const PlatformImplPtr &getPlatformImpl() const { return MPlatform; }
93+
const platform_impl &getPlatformImpl() const { return *MPlatform; }
94+
platform_impl &getPlatformImpl() { return *MPlatform; }
9495

9596
/// Queries this context for information.
9697
///
@@ -257,7 +258,8 @@ class context_impl {
257258
async_handler MAsyncHandler;
258259
std::vector<device> MDevices;
259260
ur_context_handle_t MContext;
260-
PlatformImplPtr MPlatform;
261+
// TODO: Make it a reference instead, but that needs a bit more refactoring:
262+
platform_impl *MPlatform = nullptr;
261263
property_list MPropList;
262264
CachedLibProgramsT MCachedLibPrograms;
263265
std::mutex MCachedLibProgramsMutex;

sycl/source/detail/device_impl.cpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -21,17 +21,17 @@ namespace detail {
2121

2222
/// Constructs a SYCL device instance using the provided
2323
/// UR device instance.
24-
device_impl::device_impl(ur_device_handle_t Device, PlatformImplPtr Platform)
24+
device_impl::device_impl(ur_device_handle_t Device, platform_impl &Platform)
2525
: MDevice(Device), MPlatform(Platform),
2626
MDeviceHostBaseTime(std::make_pair(0, 0)) {
27-
const AdapterPtr &Adapter = Platform->getAdapter();
27+
const AdapterPtr &Adapter = Platform.getAdapter();
2828

2929
// TODO catch an exception and put it to list of asynchronous exceptions
3030
Adapter->call<UrApiKind::urDeviceGetInfo>(
3131
MDevice, UR_DEVICE_INFO_TYPE, sizeof(ur_device_type_t), &MType, nullptr);
3232

3333
// No need to set MRootDevice when MAlwaysRootDevice is true
34-
if (!Platform->MAlwaysRootDevice) {
34+
if (!Platform.MAlwaysRootDevice) {
3535
// TODO catch an exception and put it to list of asynchronous exceptions
3636
Adapter->call<UrApiKind::urDeviceGetInfo>(
3737
MDevice, UR_DEVICE_INFO_PARENT_DEVICE, sizeof(ur_device_handle_t),
@@ -177,7 +177,7 @@ std::vector<device> device_impl::create_sub_devices(
177177
std::for_each(SubDevices.begin(), SubDevices.end(),
178178
[&res, this](const ur_device_handle_t &a_ur_device) {
179179
device sycl_device = detail::createSyclObjFromImpl<device>(
180-
MPlatform->getOrMakeDeviceImpl(a_ur_device, MPlatform));
180+
MPlatform.getOrMakeDeviceImpl(a_ur_device));
181181
res.push_back(sycl_device);
182182
});
183183
return res;

sycl/source/detail/device_impl.hpp

+6-7
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,13 @@ namespace detail {
3030

3131
// Forward declaration
3232
class platform_impl;
33-
using PlatformImplPtr = std::shared_ptr<platform_impl>;
3433

3534
// TODO: Make code thread-safe
3635
class device_impl {
3736
public:
3837
/// Constructs a SYCL device instance using the provided
3938
/// UR device instance.
40-
explicit device_impl(ur_device_handle_t Device, PlatformImplPtr Platform);
39+
explicit device_impl(ur_device_handle_t Device, platform_impl &Platform);
4140

4241
~device_impl();
4342

@@ -94,7 +93,7 @@ class device_impl {
9493
platform get_platform() const;
9594

9695
/// \return the associated adapter with this device.
97-
const AdapterPtr &getAdapter() const { return MPlatform->getAdapter(); }
96+
const AdapterPtr &getAdapter() const { return MPlatform.getAdapter(); }
9897

9998
/// Check SYCL extension support by device
10099
///
@@ -276,11 +275,11 @@ class device_impl {
276275
bool isGetDeviceAndHostTimerSupported();
277276

278277
/// Get the backend of this device
279-
backend getBackend() const { return MPlatform->getBackend(); }
278+
backend getBackend() const { return MPlatform.getBackend(); }
280279

280+
// TODO: const-correctness
281281
/// @brief Get the platform impl serving this device
282-
/// @return PlatformImplPtr
283-
const PlatformImplPtr &getPlatformImpl() const { return MPlatform; }
282+
platform_impl &getPlatformImpl() const { return MPlatform; }
284283

285284
/// Get device info string
286285
std::string get_device_info_string(ur_device_info_t InfoCode) const;
@@ -292,7 +291,7 @@ class device_impl {
292291
ur_device_handle_t MDevice = 0;
293292
ur_device_type_t MType;
294293
ur_device_handle_t MRootDevice = nullptr;
295-
PlatformImplPtr MPlatform;
294+
platform_impl &MPlatform;
296295
bool MUseNativeAssert = false;
297296
mutable std::string MDeviceName;
298297
mutable std::once_flag MDeviceNameFlag;

sycl/source/detail/device_info.hpp

+5-8
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535
namespace sycl {
3636
inline namespace _V1 {
3737
namespace detail {
38-
3938
inline std::vector<memory_order>
4039
readMemoryOrderBitfield(ur_memory_order_capability_flags_t bits) {
4140
std::vector<memory_order> result;
@@ -1171,9 +1170,8 @@ template <> struct get_device_info_impl<device, info::device::parent_device> {
11711170
throw exception(make_error_code(errc::invalid),
11721171
"No parent for device because it is not a subdevice");
11731172

1174-
const auto &Platform = Dev.getPlatformImpl();
11751173
return createSyclObjFromImpl<device>(
1176-
Platform->getOrMakeDeviceImpl(result, Platform));
1174+
Dev.getPlatformImpl().getOrMakeDeviceImpl(result));
11771175
}
11781176
};
11791177

@@ -1337,10 +1335,10 @@ struct get_device_info_impl<
13371335
ext::oneapi::experimental::info::device::component_devices>::value,
13381336
ResultSize, Devs.data(), nullptr);
13391337
std::vector<sycl::device> Result;
1340-
const auto &Platform = Dev.getPlatformImpl();
1338+
platform_impl &Platform = Dev.getPlatformImpl();
13411339
for (const auto &d : Devs)
1342-
Result.push_back(createSyclObjFromImpl<device>(
1343-
Platform->getOrMakeDeviceImpl(d, Platform)));
1340+
Result.push_back(
1341+
createSyclObjFromImpl<device>(Platform.getOrMakeDeviceImpl(d)));
13441342

13451343
return Result;
13461344
}
@@ -1363,9 +1361,8 @@ struct get_device_info_impl<
13631361
sizeof(Result), &Result, nullptr);
13641362

13651363
if (Result) {
1366-
const auto &Platform = Dev.getPlatformImpl();
13671364
return createSyclObjFromImpl<device>(
1368-
Platform->getOrMakeDeviceImpl(Result, Platform));
1365+
Dev.getPlatformImpl().getOrMakeDeviceImpl(Result));
13691366
}
13701367
throw sycl::exception(make_error_code(errc::invalid),
13711368
"A component with aspect::ext_oneapi_is_component "

sycl/source/detail/global_handler.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ ProgramManager &GlobalHandler::getProgramManager() {
186186
return PM;
187187
}
188188

189-
std::unordered_map<PlatformImplPtr, ContextImplPtr> &
189+
std::unordered_map<platform_impl *, ContextImplPtr> &
190190
GlobalHandler::getPlatformToDefaultContextCache() {
191191
// The optimization with static reference is not done because
192192
// there are public methods of the GlobalHandler
@@ -207,8 +207,8 @@ Sync &GlobalHandler::getSync() {
207207
return sync;
208208
}
209209

210-
std::vector<PlatformImplPtr> &GlobalHandler::getPlatformCache() {
211-
static std::vector<PlatformImplPtr> &PlatformCache =
210+
std::vector<std::shared_ptr<platform_impl>> &GlobalHandler::getPlatformCache() {
211+
static std::vector<std::shared_ptr<platform_impl>> &PlatformCache =
212212
getOrCreate(MPlatformCache);
213213
return PlatformCache;
214214
}

sycl/source/detail/global_handler.hpp

+4-5
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ class ods_target_list;
2727
class XPTIRegistry;
2828
class ThreadPool;
2929

30-
using PlatformImplPtr = std::shared_ptr<platform_impl>;
3130
using ContextImplPtr = std::shared_ptr<context_impl>;
3231
using AdapterPtr = std::shared_ptr<Adapter>;
3332

@@ -60,9 +59,9 @@ class GlobalHandler {
6059
bool isSchedulerAlive() const;
6160
ProgramManager &getProgramManager();
6261
Sync &getSync();
63-
std::vector<PlatformImplPtr> &getPlatformCache();
62+
std::vector<std::shared_ptr<platform_impl>> &getPlatformCache();
6463

65-
std::unordered_map<PlatformImplPtr, ContextImplPtr> &
64+
std::unordered_map<platform_impl *, ContextImplPtr> &
6665
getPlatformToDefaultContextCache();
6766

6867
std::mutex &getPlatformToDefaultContextCacheMutex();
@@ -118,8 +117,8 @@ class GlobalHandler {
118117
InstWithLock<Scheduler> MScheduler;
119118
InstWithLock<ProgramManager> MProgramManager;
120119
InstWithLock<Sync> MSync;
121-
InstWithLock<std::vector<PlatformImplPtr>> MPlatformCache;
122-
InstWithLock<std::unordered_map<PlatformImplPtr, ContextImplPtr>>
120+
InstWithLock<std::vector<std::shared_ptr<platform_impl>>> MPlatformCache;
121+
InstWithLock<std::unordered_map<platform_impl *, ContextImplPtr>>
123122
MPlatformToDefaultContextCache;
124123
InstWithLock<std::mutex> MPlatformToDefaultContextCacheMutex;
125124
InstWithLock<std::mutex> MPlatformMapMutex;

sycl/source/detail/kernel_impl.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ void kernel_impl::checkIfValidForNumArgsInfoQuery() const {
126126
}
127127

128128
void kernel_impl::enableUSMIndirectAccess() const {
129-
if (!MContext->getPlatformImpl()->supports_usm())
129+
if (!MContext->getPlatformImpl().supports_usm())
130130
return;
131131

132132
// Some UR Adapters (like OpenCL) require this call to enable USM

0 commit comments

Comments
 (0)