Skip to content

Commit 7bbaf83

Browse files
Bugfix to decay Policy for __result_and_scratch_storage (#2031)
Signed-off-by: Dan Hoeflinger <[email protected]>
1 parent 4cbf18c commit 7bbaf83

File tree

3 files changed

+43
-36
lines changed

3 files changed

+43
-36
lines changed

include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl.h

+25-22
Original file line numberDiff line numberDiff line change
@@ -615,8 +615,7 @@ __parallel_transform_scan_single_group(oneapi::dpl::__internal::__device_backend
615615

616616
// Although we do not actually need result storage in this case, we need to construct
617617
// a placeholder here to match the return type of the non-single-work-group implementation
618-
using __result_and_scratch_storage_t = __result_and_scratch_storage<_ExecutionPolicy, _ValueType>;
619-
__result_and_scratch_storage_t __dummy_result_and_scratch{__exec, 0, 0};
618+
__result_and_scratch_storage<_ExecutionPolicy, _ValueType> __dummy_result_and_scratch{__exec, 0, 0};
620619

621620
if (__max_wg_size >= __targeted_wg_size)
622621
{
@@ -1093,7 +1092,7 @@ struct __write_to_id_if_else
10931092
template <typename _ExecutionPolicy, typename _Range1, typename _Range2, typename _UnaryOperation, typename _InitType,
10941093
typename _BinaryOperation, typename _Inclusive>
10951094
auto
1096-
__parallel_transform_scan(oneapi::dpl::__internal::__device_backend_tag __backend_tag, const _ExecutionPolicy& __exec,
1095+
__parallel_transform_scan(oneapi::dpl::__internal::__device_backend_tag __backend_tag, _ExecutionPolicy&& __exec,
10971096
_Range1&& __in_rng, _Range2&& __out_rng, std::size_t __n, _UnaryOperation __unary_op,
10981097
_InitType __init, _BinaryOperation __binary_op, _Inclusive)
10991098
{
@@ -1122,9 +1121,9 @@ __parallel_transform_scan(oneapi::dpl::__internal::__device_backend_tag __backen
11221121
std::size_t __single_group_upper_limit = __use_reduce_then_scan ? 2048 : 16384;
11231122
if (__group_scan_fits_in_slm<_Type>(__exec.queue(), __n, __n_uniform, __single_group_upper_limit))
11241123
{
1125-
return __parallel_transform_scan_single_group(__backend_tag, __exec, std::forward<_Range1>(__in_rng),
1126-
std::forward<_Range2>(__out_rng), __n, __unary_op, __init,
1127-
__binary_op, _Inclusive{});
1124+
return __parallel_transform_scan_single_group(
1125+
__backend_tag, std::forward<_ExecutionPolicy>(__exec), std::forward<_Range1>(__in_rng),
1126+
std::forward<_Range2>(__out_rng), __n, __unary_op, __init, __binary_op, _Inclusive{});
11281127
}
11291128
}
11301129
#if _ONEDPL_COMPILE_KERNEL
@@ -1161,7 +1160,8 @@ __parallel_transform_scan(oneapi::dpl::__internal::__device_backend_tag __backen
11611160
_NoOpFunctor __get_data_op;
11621161

11631162
return __parallel_transform_scan_base(
1164-
__backend_tag, __exec, std::forward<_Range1>(__in_rng), std::forward<_Range2>(__out_rng), __init,
1163+
__backend_tag, std::forward<_ExecutionPolicy>(__exec), std::forward<_Range1>(__in_rng),
1164+
std::forward<_Range2>(__out_rng), __init,
11651165
// local scan
11661166
unseq_backend::__scan<_Inclusive, _ExecutionPolicy, _BinaryOperation, _UnaryFunctor, _Assigner, _Assigner,
11671167
_NoOpFunctor, _InitType>{__binary_op, _UnaryFunctor{__unary_op}, __assign_op, __assign_op,
@@ -1283,7 +1283,7 @@ __parallel_scan_copy(oneapi::dpl::__internal::__device_backend_tag __backend_tag
12831283

12841284
template <typename _ExecutionPolicy, typename _Range1, typename _Range2, typename _BinaryPredicate>
12851285
auto
1286-
__parallel_unique_copy(oneapi::dpl::__internal::__device_backend_tag __backend_tag, const _ExecutionPolicy& __exec,
1286+
__parallel_unique_copy(oneapi::dpl::__internal::__device_backend_tag __backend_tag, _ExecutionPolicy&& __exec,
12871287
_Range1&& __rng, _Range2&& __result, _BinaryPredicate __pred)
12881288
{
12891289
using _Assign = oneapi::dpl::__internal::__pstl_assign;
@@ -1316,8 +1316,9 @@ __parallel_unique_copy(oneapi::dpl::__internal::__device_backend_tag __backend_t
13161316
decltype(__n)>;
13171317
using _CopyOp = unseq_backend::__copy_by_mask<_ReduceOp, _Assign, /*inclusive*/ std::true_type, 1>;
13181318

1319-
return __parallel_scan_copy(__backend_tag, __exec, std::forward<_Range1>(__rng), std::forward<_Range2>(__result),
1320-
__n, _CreateOp{oneapi::dpl::__internal::__not_pred<_BinaryPredicate>{__pred}},
1319+
return __parallel_scan_copy(__backend_tag, std::forward<_ExecutionPolicy>(__exec), std::forward<_Range1>(__rng),
1320+
std::forward<_Range2>(__result), __n,
1321+
_CreateOp{oneapi::dpl::__internal::__not_pred<_BinaryPredicate>{__pred}},
13211322
_CopyOp{_ReduceOp{}, _Assign{}});
13221323
}
13231324

@@ -1357,7 +1358,7 @@ __parallel_reduce_by_segment_reduce_then_scan(oneapi::dpl::__internal::__device_
13571358

13581359
template <typename _ExecutionPolicy, typename _Range1, typename _Range2, typename _UnaryPredicate>
13591360
auto
1360-
__parallel_partition_copy(oneapi::dpl::__internal::__device_backend_tag __backend_tag, const _ExecutionPolicy& __exec,
1361+
__parallel_partition_copy(oneapi::dpl::__internal::__device_backend_tag __backend_tag, _ExecutionPolicy&& __exec,
13611362
_Range1&& __rng, _Range2&& __result, _UnaryPredicate __pred)
13621363
{
13631364
oneapi::dpl::__internal::__difference_t<_Range1> __n = __rng.size();
@@ -1383,14 +1384,14 @@ __parallel_partition_copy(oneapi::dpl::__internal::__device_backend_tag __backen
13831384
using _CreateOp = unseq_backend::__create_mask<_UnaryPredicate, decltype(__n)>;
13841385
using _CopyOp = unseq_backend::__partition_by_mask<_ReduceOp, /*inclusive*/ std::true_type>;
13851386

1386-
return __parallel_scan_copy(__backend_tag, __exec, std::forward<_Range1>(__rng), std::forward<_Range2>(__result),
1387-
__n, _CreateOp{__pred}, _CopyOp{_ReduceOp{}});
1387+
return __parallel_scan_copy(__backend_tag, std::forward<_ExecutionPolicy>(__exec), std::forward<_Range1>(__rng),
1388+
std::forward<_Range2>(__result), __n, _CreateOp{__pred}, _CopyOp{_ReduceOp{}});
13881389
}
13891390

13901391
template <typename _ExecutionPolicy, typename _InRng, typename _OutRng, typename _Size, typename _Pred,
13911392
typename _Assign = oneapi::dpl::__internal::__pstl_assign>
13921393
auto
1393-
__parallel_copy_if(oneapi::dpl::__internal::__device_backend_tag __backend_tag, const _ExecutionPolicy& __exec,
1394+
__parallel_copy_if(oneapi::dpl::__internal::__device_backend_tag __backend_tag, _ExecutionPolicy&& __exec,
13941395
_InRng&& __in_rng, _OutRng&& __out_rng, _Size __n, _Pred __pred, _Assign __assign = _Assign{})
13951396
{
13961397
using _SingleGroupInvoker = __invoke_single_group_copy_if<_Size>;
@@ -1440,8 +1441,9 @@ __parallel_copy_if(oneapi::dpl::__internal::__device_backend_tag __backend_tag,
14401441
using _CopyOp = unseq_backend::__copy_by_mask<_ReduceOp, _Assign,
14411442
/*inclusive*/ std::true_type, 1>;
14421443

1443-
return __parallel_scan_copy(__backend_tag, __exec, std::forward<_InRng>(__in_rng), std::forward<_OutRng>(__out_rng),
1444-
__n, _CreateOp{__pred}, _CopyOp{_ReduceOp{}, __assign});
1444+
return __parallel_scan_copy(__backend_tag, std::forward<_ExecutionPolicy>(__exec), std::forward<_InRng>(__in_rng),
1445+
std::forward<_OutRng>(__out_rng), __n, _CreateOp{__pred},
1446+
_CopyOp{_ReduceOp{}, __assign});
14451447
}
14461448

14471449
#if _ONEDPL_COMPILE_KERNEL
@@ -1534,7 +1536,7 @@ __parallel_set_scan(oneapi::dpl::__internal::__device_backend_tag __backend_tag,
15341536
template <typename _ExecutionPolicy, typename _Range1, typename _Range2, typename _Range3, typename _Compare,
15351537
typename _IsOpDifference>
15361538
auto
1537-
__parallel_set_op(oneapi::dpl::__internal::__device_backend_tag __backend_tag, const _ExecutionPolicy& __exec,
1539+
__parallel_set_op(oneapi::dpl::__internal::__device_backend_tag __backend_tag, _ExecutionPolicy&& __exec,
15381540
_Range1&& __rng1, _Range2&& __rng2, _Range3&& __result, _Compare __comp,
15391541
_IsOpDifference __is_op_difference)
15401542
{
@@ -1552,8 +1554,9 @@ __parallel_set_op(oneapi::dpl::__internal::__device_backend_tag __backend_tag, c
15521554
}
15531555
}
15541556
#endif
1555-
return __parallel_set_scan(__backend_tag, __exec, std::forward<_Range1>(__rng1), std::forward<_Range2>(__rng2),
1556-
std::forward<_Range3>(__result), __comp, __is_op_difference);
1557+
return __parallel_set_scan(__backend_tag, std::forward<_ExecutionPolicy>(__exec), std::forward<_Range1>(__rng1),
1558+
std::forward<_Range2>(__rng2), std::forward<_Range3>(__result), __comp,
1559+
__is_op_difference);
15571560
}
15581561

15591562
//------------------------------------------------------------------------
@@ -2467,8 +2470,8 @@ __parallel_reduce_by_segment_fallback(oneapi::dpl::__internal::__device_backend_
24672470
template <typename _ExecutionPolicy, typename _Range1, typename _Range2, typename _Range3, typename _Range4,
24682471
typename _BinaryPredicate, typename _BinaryOperator>
24692472
oneapi::dpl::__internal::__difference_t<_Range3>
2470-
__parallel_reduce_by_segment(oneapi::dpl::__internal::__device_backend_tag, const _ExecutionPolicy& __exec,
2471-
_Range1&& __keys, _Range2&& __values, _Range3&& __out_keys, _Range4&& __out_values,
2473+
__parallel_reduce_by_segment(oneapi::dpl::__internal::__device_backend_tag, _ExecutionPolicy&& __exec, _Range1&& __keys,
2474+
_Range2&& __values, _Range3&& __out_keys, _Range4&& __out_values,
24722475
_BinaryPredicate __binary_pred, _BinaryOperator __binary_op)
24732476
{
24742477
// The algorithm reduces values in __values where the
@@ -2506,7 +2509,7 @@ __parallel_reduce_by_segment(oneapi::dpl::__internal::__device_backend_tag, cons
25062509
}
25072510
#endif
25082511
return __parallel_reduce_by_segment_fallback(
2509-
oneapi::dpl::__internal::__device_backend_tag{}, __exec,
2512+
oneapi::dpl::__internal::__device_backend_tag{}, std::forward<_ExecutionPolicy>(__exec),
25102513
std::forward<_Range1>(__keys), std::forward<_Range2>(__values), std::forward<_Range3>(__out_keys),
25112514
std::forward<_Range4>(__out_values), __binary_pred, __binary_op,
25122515
oneapi::dpl::unseq_backend::__has_known_identity<_BinaryOperator, __val_type>{});

include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_reduce.h

+6-8
Original file line numberDiff line numberDiff line change
@@ -186,13 +186,12 @@ template <typename _Tp, typename _Commutative, std::uint8_t _VecSize, typename..
186186
struct __parallel_transform_reduce_device_kernel_submitter<_Tp, _Commutative, _VecSize,
187187
__internal::__optional_kernel_name<_KernelName...>>
188188
{
189-
template <typename _ExecutionPolicy, typename _Size, typename _ReduceOp, typename _TransformOp,
190-
typename _ExecutionPolicy2, typename... _Ranges>
189+
template <typename _ExecutionPolicy, typename _Size, typename _ReduceOp, typename _TransformOp, typename... _Ranges>
191190
auto
192191
operator()(oneapi::dpl::__internal::__device_backend_tag, _ExecutionPolicy&& __exec, const _Size __n,
193192
const _Size __work_group_size, const _Size __iters_per_work_item, _ReduceOp __reduce_op,
194193
_TransformOp __transform_op,
195-
const __result_and_scratch_storage<_ExecutionPolicy2, _Tp>& __scratch_container,
194+
const __result_and_scratch_storage<_ExecutionPolicy, _Tp>& __scratch_container,
196195
_Ranges&&... __rngs) const
197196
{
198197
auto __transform_pattern =
@@ -215,7 +214,7 @@ struct __parallel_transform_reduce_device_kernel_submitter<_Tp, _Commutative, _V
215214
sycl::nd_range<1>(sycl::range<1>(__n_groups * __work_group_size), sycl::range<1>(__work_group_size)),
216215
[=](sycl::nd_item<1> __item_id) {
217216
auto __temp_ptr =
218-
__result_and_scratch_storage<_ExecutionPolicy2, _Tp>::__get_usm_or_buffer_accessor_ptr(
217+
__result_and_scratch_storage<_ExecutionPolicy, _Tp>::__get_usm_or_buffer_accessor_ptr(
219218
__temp_acc);
220219
__device_reduce_kernel<_Tp>(__item_id, __n, __iters_per_work_item, __is_full, __n_groups,
221220
__transform_pattern, __reduce_pattern, __temp_local, __temp_ptr,
@@ -235,12 +234,11 @@ template <typename _Tp, typename _Commutative, std::uint8_t _VecSize, typename..
235234
struct __parallel_transform_reduce_work_group_kernel_submitter<_Tp, _Commutative, _VecSize,
236235
__internal::__optional_kernel_name<_KernelName...>>
237236
{
238-
template <typename _ExecutionPolicy, typename _Size, typename _ReduceOp, typename _InitType,
239-
typename _ExecutionPolicy2>
237+
template <typename _ExecutionPolicy, typename _Size, typename _ReduceOp, typename _InitType>
240238
auto
241239
operator()(oneapi::dpl::__internal::__device_backend_tag, _ExecutionPolicy&& __exec, sycl::event& __reduce_event,
242240
const _Size __n, const _Size __work_group_size, const _Size __iters_per_work_item, _ReduceOp __reduce_op,
243-
_InitType __init, const __result_and_scratch_storage<_ExecutionPolicy2, _Tp>& __scratch_container) const
241+
_InitType __init, const __result_and_scratch_storage<_ExecutionPolicy, _Tp>& __scratch_container) const
244242
{
245243
using _NoOpFunctor = unseq_backend::walk_n<_ExecutionPolicy, oneapi::dpl::__internal::__no_op>;
246244
auto __transform_pattern =
@@ -250,7 +248,7 @@ struct __parallel_transform_reduce_work_group_kernel_submitter<_Tp, _Commutative
250248

251249
const bool __is_full = __n == __work_group_size * __iters_per_work_item;
252250

253-
using __result_and_scratch_storage_t = __result_and_scratch_storage<_ExecutionPolicy2, _Tp>;
251+
using __result_and_scratch_storage_t = __result_and_scratch_storage<_ExecutionPolicy, _Tp>;
254252

255253
__reduce_event = __exec.queue().submit([&, __n](sycl::handler& __cgh) {
256254
__cgh.depends_on(__reduce_event);

include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_utils.h

+12-6
Original file line numberDiff line numberDiff line change
@@ -530,7 +530,7 @@ struct __result_and_scratch_storage_base
530530
};
531531

532532
template <typename _ExecutionPolicy, typename _T>
533-
struct __result_and_scratch_storage : __result_and_scratch_storage_base
533+
struct __result_and_scratch_storage_impl : __result_and_scratch_storage_base
534534
{
535535
private:
536536
using __sycl_buffer_t = sycl::buffer<_T, 1>;
@@ -578,10 +578,10 @@ struct __result_and_scratch_storage : __result_and_scratch_storage_base
578578
}
579579

580580
public:
581-
__result_and_scratch_storage(const _ExecutionPolicy& __exec_, std::size_t __result_n, std::size_t __scratch_n)
581+
__result_and_scratch_storage_impl(const _ExecutionPolicy& __exec_, std::size_t __result_n, std::size_t __scratch_n)
582582
: __exec{__exec_}, __result_n{__result_n}, __scratch_n{__scratch_n},
583-
__use_USM_host{__use_USM_host_allocations(__exec.queue())}, __supports_USM_device{
584-
__use_USM_allocations(__exec.queue())}
583+
__use_USM_host{__use_USM_host_allocations(__exec.queue())},
584+
__supports_USM_device{__use_USM_allocations(__exec.queue())}
585585
{
586586
const std::size_t __total_n = __scratch_n + __result_n;
587587
// Skip in case this is a dummy container
@@ -724,6 +724,9 @@ struct __result_and_scratch_storage : __result_and_scratch_storage_base
724724
}
725725
};
726726

727+
template <typename _ExecutionPolicy, typename _T>
728+
using __result_and_scratch_storage = __result_and_scratch_storage_impl<std::decay_t<_ExecutionPolicy>, _T>;
729+
727730
// Tag __async_mode describes a pattern call mode which should be executed asynchronously
728731
struct __async_mode
729732
{
@@ -753,9 +756,12 @@ class __future : private std::tuple<_Args...>
753756
return __buf.get_host_access(sycl::read_only)[0];
754757
}
755758

756-
template <typename _ExecutionPolicy, typename _T>
759+
// Here we use __result_and_scratch_storage_impl rather than __result_and_scratch_storage because we need to
760+
// match the type with the overload and are deducing the policy type. If we used __result_and_scratch_storage,
761+
// the std::decay_t in that alias would put the policy in a non-deduced context and template argument deduction would fail.
762+
template <typename _DecayedExecutionPolicy, typename _T>
757763
constexpr auto
758-
__wait_and_get_value(const __result_and_scratch_storage<_ExecutionPolicy, _T>& __storage)
764+
__wait_and_get_value(const __result_and_scratch_storage_impl<_DecayedExecutionPolicy, _T>& __storage)
759765
{
760766
return __storage.__wait_and_get_value(__my_event);
761767
}

0 commit comments

Comments
 (0)