Skip to content

Commit c3a8c90

Browse files
include/oneapi/dpl/internal/scan_by_segment_impl.h - __sycl_scan_by_segment_submitter_factory for __sycl_scan_by_segment_submitter (__sycl_scan_by_segment_impl)
1 parent a142f1d commit c3a8c90

File tree

1 file changed

+45
-10
lines changed

1 file changed

+45
-10
lines changed

include/oneapi/dpl/internal/scan_by_segment_impl.h

+45-10
Original file line numberDiff line numberDiff line change
@@ -96,27 +96,61 @@ class __seg_scan_wg_kernel;
9696
template <bool __is_inclusive, typename... Name>
9797
class __seg_scan_prefix_kernel;
9898

99-
template <bool __is_inclusive>
100-
struct __sycl_scan_by_segment_impl
99+
template <bool __is_inclusive, typename _ExecutionPolicy>
100+
struct __sycl_scan_by_segment_submitter;
101+
102+
struct __sycl_scan_by_segment_submitter_factory
103+
{
104+
template <bool __is_inclusive, typename _ExecutionPolicy>
105+
static auto
106+
create(_ExecutionPolicy&& __exec)
107+
{
108+
//using _ExecutionPolicyCtor = std::remove_cv_t<std::remove_reference_t<std::decay_t<_ExecutionPolicy>>>;
109+
using _ExecutionPolicyCtor = std::decay_t<_ExecutionPolicy>;
110+
static_assert(std::is_same_v<_ExecutionPolicyCtor, std::remove_cv_t<std::remove_reference_t<std::decay_t<_ExecutionPolicy>>>>);
111+
112+
return __sycl_scan_by_segment_submitter<__is_inclusive, _ExecutionPolicyCtor>{std::forward<_ExecutionPolicy>(__exec)};
113+
}
114+
};
115+
116+
template <bool __is_inclusive, typename _ExecutionPolicy>
117+
struct __sycl_scan_by_segment_submitter
101118
{
119+
// We should instantiate this submitter only for cleared _ExecutionPolicy type
120+
static_assert(std::is_same_v<_ExecutionPolicy, std::decay_t<_ExecutionPolicy>>);
121+
122+
friend __sycl_scan_by_segment_submitter_factory;
123+
124+
protected:
125+
126+
_ExecutionPolicy __exec;
127+
128+
template <typename _ExecutionPolicyCtor>
129+
__sycl_scan_by_segment_submitter(_ExecutionPolicyCtor&& __exec)
130+
{
131+
__exec = std::forward<_ExecutionPolicyCtor>(__exec);
132+
}
133+
134+
public:
135+
102136
template <typename... _Name>
103137
using _SegScanWgPhase = __seg_scan_wg_kernel<__is_inclusive, _Name...>;
104138

105139
template <typename... _Name>
106140
using _SegScanPrefixPhase = __seg_scan_prefix_kernel<__is_inclusive, _Name...>;
107141

108-
template <typename _BackendTag, typename _ExecutionPolicy, typename _Range1, typename _Range2, typename _Range3,
109-
typename _BinaryPredicate, typename _BinaryOperator, typename _T>
142+
template <typename _BackendTag, typename _Range1, typename _Range2, typename _Range3, typename _BinaryPredicate,
143+
typename _BinaryOperator, typename _T>
110144
void
111-
operator()(_BackendTag, _ExecutionPolicy&& __exec, _Range1&& __keys, _Range2&& __values, _Range3&& __out_values,
145+
operator()(_BackendTag, _Range1&& __keys, _Range2&& __values, _Range3&& __out_values,
112146
_BinaryPredicate __binary_pred, _BinaryOperator __binary_op, _T __init, _T __identity)
113147
{
114148
using _CustomName = oneapi::dpl::__internal::__policy_kernel_name<_ExecutionPolicy>;
115149

116-
using _SegScanWgKernel = oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_generator< // KSATODO: __kernel_name_generator with _ExecutionPolicy - __sycl_scan_by_segment_impl
150+
using _SegScanWgKernel = oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_generator< // KSATODO: __kernel_name_generator with _ExecutionPolicy - __sycl_scan_by_segment_impl, __sycl_scan_by_segment_submitter
117151
_SegScanWgPhase, _CustomName, _ExecutionPolicy, _Range1, _Range2, _Range3, _BinaryPredicate,
118152
_BinaryOperator>;
119-
using _SegScanPrefixKernel = oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_generator< // KSATODO: __kernel_name_generator with _ExecutionPolicy - __sycl_scan_by_segment_impl
153+
using _SegScanPrefixKernel = oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_generator< // KSATODO: __kernel_name_generator with _ExecutionPolicy - __sycl_scan_by_segment_impl, __sycl_scan_by_segment_submitter
120154
_SegScanPrefixPhase, _CustomName, _ExecutionPolicy, _Range1, _Range2, _Range3, _BinaryPredicate,
121155
_BinaryOperator>;
122156

@@ -394,9 +428,10 @@ __scan_by_segment_impl_common(__internal::__hetero_tag<_BackendTag>, Policy&& po
394428

395429
constexpr iter_value_t identity = unseq_backend::__known_identity<Operator, iter_value_t>;
396430

397-
__sycl_scan_by_segment_impl<Inclusive::value>()(_BackendTag{}, ::std::forward<Policy>(policy), key_buf.all_view(),
398-
value_buf.all_view(), value_output_buf.all_view(), binary_pred,
399-
binary_op, init, identity);
431+
__sycl_scan_by_segment_submitter_factory::create<Inclusive::value, Policy>(policy)(
432+
_BackendTag{}, key_buf.all_view(), value_buf.all_view(), value_output_buf.all_view(), binary_pred, binary_op,
433+
init, identity);
434+
400435
return result + n;
401436
}
402437

0 commit comments

Comments
 (0)