Skip to content

Commit e9c8b1e

Browse files
Improve the implementation of __enumerable_thread_local_storage (#2049)
Implement with CRTP. Store construction arguments by value. Replace unique_ptr with optional. Co-authored-by: Dan Hoeflinger <[email protected]>
1 parent b21e7ca commit e9c8b1e

File tree

3 files changed

+69
-53
lines changed

3 files changed

+69
-53
lines changed

include/oneapi/dpl/pstl/omp/util.h

+27-17
Original file line numberDiff line numberDiff line change
@@ -152,36 +152,46 @@ __process_chunk(const __chunk_metrics& __metrics, _Iterator __base, _Index __chu
152152

153153
namespace __detail
154154
{
155-
struct __get_num_threads
155+
156+
// Workaround for VS 2017: declare an alias to the CRTP base template
157+
template <typename _ValueType, typename... _Args>
158+
struct __enumerable_thread_local_storage;
159+
160+
template <typename... _Ts>
161+
using __etls_base = __utils::__enumerable_thread_local_storage_base<__enumerable_thread_local_storage, _Ts...>;
162+
163+
template <typename _ValueType, typename... _Args>
164+
struct __enumerable_thread_local_storage : public __etls_base<_ValueType, _Args...>
156165
{
157-
std::size_t
158-
operator()() const
166+
167+
template <typename... _LocalArgs>
168+
__enumerable_thread_local_storage(_LocalArgs&&... __args)
169+
: __etls_base<_ValueType, _Args...>({std::forward<_LocalArgs>(__args)...})
170+
{
171+
}
172+
173+
static std::size_t
174+
get_num_threads()
159175
{
160176
return omp_in_parallel() ? omp_get_num_threads() : omp_get_max_threads();
161177
}
162-
};
163178

164-
struct __get_thread_num
165-
{
166-
std::size_t
167-
operator()() const
179+
static std::size_t
180+
get_thread_num()
168181
{
169182
return omp_get_thread_num();
170183
}
171184
};
172185

173186
} // namespace __detail
174187

175-
// enumerable thread local storage should only be created from make function
176-
template <typename _ValueType, typename... Args>
177-
oneapi::dpl::__utils::__detail::__enumerable_thread_local_storage<
178-
_ValueType, oneapi::dpl::__omp_backend::__detail::__get_num_threads,
179-
oneapi::dpl::__omp_backend::__detail::__get_thread_num, Args...>
180-
__make_enumerable_tls(Args&&... __args)
188+
// enumerable thread local storage should only be created with this make function
189+
template <typename _ValueType, typename... _Args>
190+
__detail::__enumerable_thread_local_storage<_ValueType, std::remove_reference_t<_Args>...>
191+
__make_enumerable_tls(_Args&&... __args)
181192
{
182-
return oneapi::dpl::__utils::__detail::__enumerable_thread_local_storage<
183-
_ValueType, oneapi::dpl::__omp_backend::__detail::__get_num_threads,
184-
oneapi::dpl::__omp_backend::__detail::__get_thread_num, Args...>(std::forward<Args>(__args)...);
193+
return __detail::__enumerable_thread_local_storage<_ValueType, std::remove_reference_t<_Args>...>(
194+
std::forward<_Args>(__args)...);
185195
}
186196

187197
} // namespace __omp_backend

include/oneapi/dpl/pstl/parallel_backend_tbb.h

+30-19
Original file line numberDiff line numberDiff line change
@@ -1308,35 +1308,46 @@ __parallel_for_each(oneapi::dpl::__internal::__tbb_backend_tag, _ExecutionPolicy
13081308

13091309
namespace __detail
13101310
{
1311-
struct __get_num_threads
1311+
1312+
// Workaround for VS 2017: declare an alias to the CRTP base template
1313+
template <typename _ValueType, typename... _Args>
1314+
struct __enumerable_thread_local_storage;
1315+
1316+
template <typename... _Ts>
1317+
using __etls_base = __utils::__enumerable_thread_local_storage_base<__enumerable_thread_local_storage, _Ts...>;
1318+
1319+
template <typename _ValueType, typename... _Args>
1320+
struct __enumerable_thread_local_storage : public __etls_base<_ValueType, _Args...>
13121321
{
1313-
std::size_t
1314-
operator()() const
1322+
1323+
template <typename... _LocalArgs>
1324+
__enumerable_thread_local_storage(_LocalArgs&&... __args)
1325+
: __etls_base<_ValueType, _Args...>({std::forward<_LocalArgs>(__args)...})
1326+
{
1327+
}
1328+
1329+
static std::size_t
1330+
get_num_threads()
13151331
{
13161332
return tbb::this_task_arena::max_concurrency();
13171333
}
1318-
};
13191334

1320-
struct __get_thread_num
1321-
{
1322-
std::size_t
1323-
operator()() const
1335+
static std::size_t
1336+
get_thread_num()
13241337
{
13251338
return tbb::this_task_arena::current_thread_index();
13261339
}
13271340
};
1328-
} //namespace __detail
1329-
1330-
// enumerable thread local storage should only be created from make function
1331-
template <typename _ValueType, typename... Args>
1332-
oneapi::dpl::__utils::__detail::__enumerable_thread_local_storage<
1333-
_ValueType, oneapi::dpl::__tbb_backend::__detail::__get_num_threads,
1334-
oneapi::dpl::__tbb_backend::__detail::__get_thread_num, Args...>
1335-
__make_enumerable_tls(Args&&... __args)
1341+
1342+
} // namespace __detail
1343+
1344+
// enumerable thread local storage should only be created with this make function
1345+
template <typename _ValueType, typename... _Args>
1346+
__detail::__enumerable_thread_local_storage<_ValueType, std::remove_reference_t<_Args>...>
1347+
__make_enumerable_tls(_Args&&... __args)
13361348
{
1337-
return oneapi::dpl::__utils::__detail::__enumerable_thread_local_storage<
1338-
_ValueType, oneapi::dpl::__tbb_backend::__detail::__get_num_threads,
1339-
oneapi::dpl::__tbb_backend::__detail::__get_thread_num, Args...>(std::forward<Args>(__args)...);
1349+
return __detail::__enumerable_thread_local_storage<_ValueType, std::remove_reference_t<_Args>...>(
1350+
std::forward<_Args>(__args)...);
13401351
}
13411352

13421353
} // namespace __tbb_backend

include/oneapi/dpl/pstl/parallel_backend_utils.h

+12-17
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
#include <atomic>
2020
#include <cstddef>
2121
#include <iterator>
22-
#include <memory>
22+
#include <optional>
2323
#include <tuple>
2424
#include <utility>
2525
#include <vector>
@@ -306,16 +306,13 @@ __set_symmetric_difference_construct(_ForwardIterator1 __first1, _ForwardIterato
306306
return __cc_range(__first2, __last2, __result);
307307
}
308308

309-
namespace __detail
309+
template <template <typename, typename...> typename _Concrete, typename _ValueType, typename... _Args>
310+
struct __enumerable_thread_local_storage_base
310311
{
312+
using _Derived = _Concrete<_ValueType, _Args...>;
311313

312-
template <typename _ValueType, typename _GetNumThreads, typename _GetThreadNum, typename... _Args>
313-
struct __enumerable_thread_local_storage
314-
{
315-
316-
template <typename... _LocalArgs>
317-
__enumerable_thread_local_storage(_LocalArgs&&... __args)
318-
: __thread_specific_storage(_GetNumThreads{}()), __num_elements(0), __args(std::forward<_LocalArgs>(__args)...)
314+
__enumerable_thread_local_storage_base(std::tuple<_Args...> __tp)
315+
: __thread_specific_storage(_Derived::get_num_threads()), __num_elements(0), __args(__tp)
319316
{
320317
}
321318

@@ -359,24 +356,22 @@ struct __enumerable_thread_local_storage
359356
_ValueType&
360357
get_for_current_thread()
361358
{
362-
const std::size_t __i = _GetThreadNum{}();
363-
std::unique_ptr<_ValueType>& __thread_local_storage = __thread_specific_storage[__i];
364-
if (!__thread_local_storage)
359+
const std::size_t __i = _Derived::get_thread_num();
360+
std::optional<_ValueType>& __local = __thread_specific_storage[__i];
361+
if (!__local)
365362
{
366363
// create temporary storage on first usage to avoid extra parallel region and unnecessary instantiation
367-
__thread_local_storage =
368-
std::apply([](_Args... __arg_pack) { return std::make_unique<_ValueType>(__arg_pack...); }, __args);
364+
std::apply([&__local](_Args... __arg_pack) { __local.emplace(__arg_pack...); }, __args);
369365
__num_elements.fetch_add(1, std::memory_order_relaxed);
370366
}
371-
return *__thread_local_storage;
367+
return *__local;
372368
}
373369

374-
std::vector<std::unique_ptr<_ValueType>> __thread_specific_storage;
370+
std::vector<std::optional<_ValueType>> __thread_specific_storage;
375371
std::atomic_size_t __num_elements;
376372
const std::tuple<_Args...> __args;
377373
};
378374

379-
} // namespace __detail
380375
} // namespace __utils
381376
} // namespace dpl
382377
} // namespace oneapi

0 commit comments

Comments
 (0)