@@ -96,27 +96,61 @@ class __seg_scan_wg_kernel;
96
96
template <bool __is_inclusive, typename ... Name>
97
97
class __seg_scan_prefix_kernel ;
98
98
99
- template <bool __is_inclusive>
100
- struct __sycl_scan_by_segment_impl
99
+ template <bool __is_inclusive, typename _ExecutionPolicy>
100
+ struct __sycl_scan_by_segment_submitter ;
101
+
102
+ struct __sycl_scan_by_segment_submitter_factory
103
+ {
104
+ template <bool __is_inclusive, typename _ExecutionPolicy>
105
+ static auto
106
+ create (_ExecutionPolicy&& __exec)
107
+ {
108
+ // using _ExecutionPolicyCtor = std::remove_cv_t<std::remove_reference_t<std::decay_t<_ExecutionPolicy>>>;
109
+ using _ExecutionPolicyCtor = std::decay_t <_ExecutionPolicy>;
110
+ static_assert (std::is_same_v<_ExecutionPolicyCtor, std::remove_cv_t <std::remove_reference_t <std::decay_t <_ExecutionPolicy>>>>);
111
+
112
+ return __sycl_scan_by_segment_submitter<__is_inclusive, _ExecutionPolicyCtor>{std::forward<_ExecutionPolicy>(__exec)};
113
+ }
114
+ };
115
+
116
+ template <bool __is_inclusive, typename _ExecutionPolicy>
117
+ struct __sycl_scan_by_segment_submitter
101
118
{
119
+ // We should instantiate this submitter only for cleared _ExecutionPolicy type
120
+ static_assert (std::is_same_v<_ExecutionPolicy, std::decay_t <_ExecutionPolicy>>);
121
+
122
+ friend __sycl_scan_by_segment_submitter_factory;
123
+
124
+ protected:
125
+
126
+ _ExecutionPolicy __exec;
127
+
128
+ template <typename _ExecutionPolicyCtor>
129
+ __sycl_scan_by_segment_submitter (_ExecutionPolicyCtor&& __exec)
130
+ {
131
+ __exec = std::forward<_ExecutionPolicyCtor>(__exec);
132
+ }
133
+
134
+ public:
135
+
102
136
template <typename ... _Name>
103
137
using _SegScanWgPhase = __seg_scan_wg_kernel<__is_inclusive, _Name...>;
104
138
105
139
template <typename ... _Name>
106
140
using _SegScanPrefixPhase = __seg_scan_prefix_kernel<__is_inclusive, _Name...>;
107
141
108
- template <typename _BackendTag, typename _ExecutionPolicy , typename _Range1 , typename _Range2 , typename _Range3 ,
109
- typename _BinaryPredicate, typename _BinaryOperator, typename _T>
142
+ template <typename _BackendTag, typename _Range1 , typename _Range2 , typename _Range3 , typename _BinaryPredicate ,
143
+ typename _BinaryOperator, typename _T>
110
144
void
111
- operator ()(_BackendTag, _ExecutionPolicy&& __exec, _Range1&& __keys, _Range2&& __values, _Range3&& __out_values,
145
+ operator ()(_BackendTag, _Range1&& __keys, _Range2&& __values, _Range3&& __out_values,
112
146
_BinaryPredicate __binary_pred, _BinaryOperator __binary_op, _T __init, _T __identity)
113
147
{
114
148
using _CustomName = oneapi::dpl::__internal::__policy_kernel_name<_ExecutionPolicy>;
115
149
116
- using _SegScanWgKernel = oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_generator< // KSATODO: __kernel_name_generator with _ExecutionPolicy - __sycl_scan_by_segment_impl
150
+ using _SegScanWgKernel = oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_generator< // KSATODO: __kernel_name_generator with _ExecutionPolicy - __sycl_scan_by_segment_impl, __sycl_scan_by_segment_submitter
117
151
_SegScanWgPhase, _CustomName, _ExecutionPolicy, _Range1, _Range2, _Range3, _BinaryPredicate,
118
152
_BinaryOperator>;
119
- using _SegScanPrefixKernel = oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_generator< // KSATODO: __kernel_name_generator with _ExecutionPolicy - __sycl_scan_by_segment_impl
153
+ using _SegScanPrefixKernel = oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_generator< // KSATODO: __kernel_name_generator with _ExecutionPolicy - __sycl_scan_by_segment_impl, __sycl_scan_by_segment_submitter
120
154
_SegScanPrefixPhase, _CustomName, _ExecutionPolicy, _Range1, _Range2, _Range3, _BinaryPredicate,
121
155
_BinaryOperator>;
122
156
@@ -394,9 +428,11 @@ __scan_by_segment_impl_common(__internal::__hetero_tag<_BackendTag>, Policy&& po
394
428
395
429
constexpr iter_value_t identity = unseq_backend::__known_identity<Operator, iter_value_t >;
396
430
397
- __sycl_scan_by_segment_impl<Inclusive::value>()(_BackendTag{}, ::std::forward<Policy>(policy), key_buf.all_view (),
398
- value_buf.all_view (), value_output_buf.all_view (), binary_pred,
399
- binary_op, init, identity);
431
+ // TODO spezialisation of Policy type in template params is not required here
432
+ __sycl_scan_by_segment_submitter_factory::create<Inclusive::value/* , Policy*/ >(policy)(
433
+ _BackendTag{}, key_buf.all_view (), value_buf.all_view (), value_output_buf.all_view (), binary_pred, binary_op,
434
+ init, identity);
435
+
400
436
return result + n;
401
437
}
402
438
0 commit comments