
Commit bf990be

1 parent aa2e5ef commit bf990be

3 files changed: +51 -31 lines changed

include/oneapi/dpl/internal/scan_by_segment_impl.h (+21 -15)
@@ -119,18 +119,24 @@ struct __sycl_scan_by_segment_submitter : protected __sycl_submitter_base<_Execu
 {
     friend __sycl_scan_by_segment_submitter_factory;

-    template <typename... _Name>
-    using _SegScanWgPhase = __seg_scan_wg_kernel<__is_inclusive, _Name...>;
-
-    template <typename... _Name>
-    using _SegScanPrefixPhase = __seg_scan_prefix_kernel<__is_inclusive, _Name...>;
+    using _submitter_base = __sycl_submitter_base<_ExecutionPolicy>;

+  protected:
     template <typename _ExecutionPolicyCtor>
     __sycl_scan_by_segment_submitter(_ExecutionPolicyCtor&& __exec)
         : __sycl_submitter_base<_ExecutionPolicy>(std::forward<_ExecutionPolicyCtor>(__exec))
     {
     }

+  public:
+
+    template <typename... _Name>
+    using _SegScanWgPhase = __seg_scan_wg_kernel<__is_inclusive, _Name...>;
+
+    template <typename... _Name>
+    using _SegScanPrefixPhase = __seg_scan_prefix_kernel<__is_inclusive, _Name...>;
+
+
     template <typename _BackendTag, typename _Range1, typename _Range2, typename _Range3, typename _BinaryPredicate,
               typename _BinaryOperator, typename _T>
     void

@@ -153,34 +159,34 @@ struct __sycl_scan_by_segment_submitter : protected __sycl_submitter_base<_Execu

         // Limit the work-group size to prevent large sizes on CPUs. Empirically found value.
         // This value exceeds the current practical limit for GPUs, but may need to be re-evaluated in the future.
-        std::size_t __wgroup_size = oneapi::dpl::__internal::__max_work_group_size(__exec, (std::size_t)2048);
+        std::size_t __wgroup_size = oneapi::dpl::__internal::__max_work_group_size(_submitter_base::__exec, (std::size_t)2048);

         // We require 2 * sizeof(__val_type) * __wgroup_size of SLM for the work group segmented scan. We add
         // an additional sizeof(__val_type) * __wgroup_size requirement to ensure sufficient SLM for the group algorithms.
         __wgroup_size =
-            oneapi::dpl::__internal::__slm_adjusted_work_group_size(__exec, 3 * sizeof(__val_type), __wgroup_size);
+            oneapi::dpl::__internal::__slm_adjusted_work_group_size(_submitter_base::__exec, 3 * sizeof(__val_type), __wgroup_size);

 #if _ONEDPL_COMPILE_KERNEL
         auto __seg_scan_wg_kernel =
-            __par_backend_hetero::__internal::__kernel_compiler<_SegScanWgKernel>::__compile(__exec);
+            __par_backend_hetero::__internal::__kernel_compiler<_SegScanWgKernel>::__compile(_submitter_base::__exec);
         auto __seg_scan_prefix_kernel =
-            __par_backend_hetero::__internal::__kernel_compiler<_SegScanPrefixKernel>::__compile(__exec);
+            __par_backend_hetero::__internal::__kernel_compiler<_SegScanPrefixKernel>::__compile(_submitter_base::__exec);
         __wgroup_size =
-            ::std::min({__wgroup_size, oneapi::dpl::__internal::__kernel_work_group_size(__exec, __seg_scan_wg_kernel),
-                        oneapi::dpl::__internal::__kernel_work_group_size(__exec, __seg_scan_prefix_kernel)});
+            ::std::min({__wgroup_size, oneapi::dpl::__internal::__kernel_work_group_size(_submitter_base::__exec, __seg_scan_wg_kernel),
+                        oneapi::dpl::__internal::__kernel_work_group_size(_submitter_base::__exec, __seg_scan_prefix_kernel)});
 #endif

         ::std::size_t __n_groups = __internal::__dpl_ceiling_div(__n, __wgroup_size * __vals_per_item);

         auto __partials =
-            oneapi::dpl::__par_backend_hetero::__buffer<_ExecutionPolicy, __val_type>(__exec, __n_groups).get_buffer();
+            oneapi::dpl::__par_backend_hetero::__buffer<_ExecutionPolicy, __val_type>(_submitter_base::__exec, __n_groups).get_buffer();

         // the number of segment ends found in each work group
         auto __seg_ends =
-            oneapi::dpl::__par_backend_hetero::__buffer<_ExecutionPolicy, bool>(__exec, __n_groups).get_buffer();
+            oneapi::dpl::__par_backend_hetero::__buffer<_ExecutionPolicy, bool>(_submitter_base::__exec, __n_groups).get_buffer();

         // 1. Work group reduction
-        auto __wg_scan = __exec.queue().submit([&](sycl::handler& __cgh) {
+        auto __wg_scan = _submitter_base::__exec.queue().submit([&](sycl::handler& __cgh) {
             auto __partials_acc = __partials.template get_access<sycl::access_mode::write>(__cgh);
             auto __seg_ends_acc = __seg_ends.template get_access<sycl::access_mode::write>(__cgh);
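
The unchanged context above computes the number of work groups with __dpl_ceiling_div. As a worked sketch of that arithmetic, assuming __dpl_ceiling_div is the usual (n + d - 1) / d ceiling division (the helper below is illustrative, not the oneDPL one):

#include <cstddef>

// Illustrative ceiling division; each work item processes __vals_per_item
// elements, so one work group covers __wgroup_size * __vals_per_item of them.
constexpr std::size_t
__ceiling_div(std::size_t __n, std::size_t __d)
{
    return (__n + __d - 1) / __d;
}

// Example: 100000 elements, work-group size 512, 4 values per item
// -> 2048 elements per group -> 49 groups (the last group is partially filled).
static_assert(__ceiling_div(100000, 512 * 4) == 49, "");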

@@ -280,7 +286,7 @@ struct __sycl_scan_by_segment_submitter : protected __sycl_submitter_base<_Execu
         });

         // 2. Apply work group carry outs, calculate output indices, and load results into correct indices.
-        __exec.queue()
+        _submitter_base::__exec.queue()
             .submit([&](sycl::handler& __cgh) {
                 oneapi::dpl::__ranges::__require_access(__cgh, __keys, __out_values);

include/oneapi/dpl/internal/sycl_submitter_base_impl.h (+8)
@@ -60,6 +60,14 @@ struct __sycl_submitter_base
         : __exec(std::forward<_ExecutionPolicyCtor>(__exec))
     {
     }
+
+  public:
+
+    inline const _ExecutionPolicy&
+    get_execution_policy() const
+    {
+        return __exec;
+    }
 };

 } // namespace internal
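
Taken together with the submitter changes in the other two files, the shape of the refactoring is: the base class stores the execution policy, the submitter derives from it with a protected constructor, and only a friend factory can create the submitter. Below is a hypothetical, simplified mirror of that pattern (illustrative names, assumes C++14 or later), not the actual oneDPL classes.

#include <type_traits>
#include <utility>

template <typename _Policy>
struct __policy_base
{
  protected:
    _Policy __exec;

    template <typename _PolicyCtor>
    explicit __policy_base(_PolicyCtor&& __exec_arg) : __exec(std::forward<_PolicyCtor>(__exec_arg))
    {
    }

  public:
    const _Policy&
    get_execution_policy() const
    {
        return __exec;
    }
};

struct __my_submitter_factory;

template <typename _Policy>
struct __my_submitter : protected __policy_base<_Policy>
{
    friend __my_submitter_factory;

    using _base = __policy_base<_Policy>;

  protected:
    template <typename _PolicyCtor>
    explicit __my_submitter(_PolicyCtor&& __exec) : _base(std::forward<_PolicyCtor>(__exec))
    {
    }

  public:
    void
    operator()() const
    {
        // Derived code reads the stored policy through the base accessor
        // (or directly via _base::__exec, as in the diffs above).
        const _Policy& __policy = _base::get_execution_policy();
        (void)__policy;
    }
};

struct __my_submitter_factory
{
    template <typename _Policy>
    static auto
    __create(_Policy&& __exec)
    {
        // Only the factory (declared a friend) can reach the protected constructor.
        return __my_submitter<std::decay_t<_Policy>>(std::forward<_Policy>(__exec));
    }
};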

include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl.h (+22 -16)
@@ -225,13 +225,13 @@ class __scan_copy_single_wg_kernel;
 //------------------------------------------------------------------------

 // Please see the comment above __parallel_for_small_submitter for optional kernel name explanation
-template <typename _PropagateScanName>
+template <typename _ExecutionPolicy, typename _PropagateScanName>
 struct __parallel_scan_submitter;

 // Even if this class submits three kernel optional name is allowed to be only for one of them
 // because for two others we have to provide the name to get the reliable work group size
-template <typename... _PropagateScanName>
-struct __parallel_scan_submitter<__internal::__optional_kernel_name<_PropagateScanName...>>;
+template <typename _ExecutionPolicy, typename... _PropagateScanName>
+struct __parallel_scan_submitter<_ExecutionPolicy, __internal::__optional_kernel_name<_PropagateScanName...>>;

 struct __parallel_scan_submitter_factory
 {

@@ -249,15 +249,21 @@ struct __parallel_scan_submitter_factory
 // Even if this class submits three kernel optional name is allowed to be only for one of them
 // because for two others we have to provide the name to get the reliable work group size
 template <typename _ExecutionPolicy, typename... _PropagateScanName>
-struct __parallel_scan_submitter<__internal::__optional_kernel_name<_PropagateScanName...>>
-    : protected __sycl_submitter_base<_ExecutionPolicy>
+struct __parallel_scan_submitter<_ExecutionPolicy, __internal::__optional_kernel_name<_PropagateScanName...>>
+    : protected internal::__sycl_submitter_base<_ExecutionPolicy>
 {
+    friend __parallel_scan_submitter_factory;
+
+    using _submitter_base = internal::__sycl_submitter_base<_ExecutionPolicy>;
+
+  protected:
     template <typename _ExecutionPolicyCtor>
     __parallel_scan_submitter(_ExecutionPolicyCtor&& __exec)
-        : __sycl_submitter_base<_ExecutionPolicy>(std::forward<_ExecutionPolicyCtor>(__exec))
+        : internal::__sycl_submitter_base<_ExecutionPolicy>(std::forward<_ExecutionPolicyCtor>(__exec))
     {
     }

+  public:
     template <typename _Range1, typename _Range2, typename _InitType,
               typename _LocalScan, typename _GroupScan, typename _GlobalScan>
     auto

@@ -273,21 +279,21 @@ struct __parallel_scan_submitter<__internal::__optional_kernel_name<_PropagateSc
         auto __n = __rng1.size();
         assert(__n > 0);

-        auto __max_cu = oneapi::dpl::__internal::__max_compute_units(__exec);
+        auto __max_cu = oneapi::dpl::__internal::__max_compute_units(_submitter_base::__exec);
         // get the work group size adjusted to the local memory limit
         // TODO: find a way to generalize getting of reliable work-group sizes
-        ::std::size_t __wgroup_size = oneapi::dpl::__internal::__slm_adjusted_work_group_size(__exec, sizeof(_Type));
+        ::std::size_t __wgroup_size = oneapi::dpl::__internal::__slm_adjusted_work_group_size(_submitter_base::__exec, sizeof(_Type));
         // Limit the work-group size to prevent large sizes on CPUs. Empirically found value.
         // This value matches the current practical limit for GPUs, but may need to be re-evaluated in the future.
         __wgroup_size = std::min(__wgroup_size, (std::size_t)1024);

 #if _ONEDPL_COMPILE_KERNEL
         //Actually there is one kernel_bundle for the all kernels of the pattern.
-        auto __kernels = __internal::__kernel_compiler<_LocalScanKernel, _GroupScanKernel>::__compile(__exec);
+        auto __kernels = __internal::__kernel_compiler<_LocalScanKernel, _GroupScanKernel>::__compile(_submitter_base::__exec);
         auto __kernel_1 = __kernels[0];
         auto __kernel_2 = __kernels[1];
-        auto __wgroup_size_kernel_1 = oneapi::dpl::__internal::__kernel_work_group_size(__exec, __kernel_1);
-        auto __wgroup_size_kernel_2 = oneapi::dpl::__internal::__kernel_work_group_size(__exec, __kernel_2);
+        auto __wgroup_size_kernel_1 = oneapi::dpl::__internal::__kernel_work_group_size(_submitter_base::__exec, __kernel_1);
+        auto __wgroup_size_kernel_2 = oneapi::dpl::__internal::__kernel_work_group_size(_submitter_base::__exec, __kernel_2);
         __wgroup_size = ::std::min({__wgroup_size, __wgroup_size_kernel_1, __wgroup_size_kernel_2});
 #endif

@@ -298,12 +304,12 @@ struct __parallel_scan_submitter<__internal::__optional_kernel_name<_PropagateSc
         // Storage for the results of scan for each workgroup

         using __result_and_scratch_storage_t = __result_and_scratch_storage<_ExecutionPolicy, _Type>;
-        __result_and_scratch_storage_t __result_and_scratch{__exec, 1, __n_groups + 1};
+        __result_and_scratch_storage_t __result_and_scratch{_submitter_base::__exec, 1, __n_groups + 1};

-        _PRINT_INFO_IN_DEBUG_MODE(__exec, __wgroup_size, __max_cu);
+        _PRINT_INFO_IN_DEBUG_MODE(_submitter_base::__exec, __wgroup_size, __max_cu);

         // 1. Local scan on each workgroup
-        auto __submit_event = __exec.queue().submit([&](sycl::handler& __cgh) {
+        auto __submit_event = _submitter_base::__exec.queue().submit([&](sycl::handler& __cgh) {
             oneapi::dpl::__ranges::__require_access(__cgh, __rng1, __rng2); //get an access to data under SYCL buffer
             auto __temp_acc = __result_and_scratch.template __get_scratch_acc<sycl::access_mode::write>(
                 __cgh, __dpl_sycl::__no_init{});

@@ -325,7 +331,7 @@ struct __parallel_scan_submitter<__internal::__optional_kernel_name<_PropagateSc
         if (__n_groups > 1)
         {
             auto __iters_per_single_wg = oneapi::dpl::__internal::__dpl_ceiling_div(__n_groups, __wgroup_size);
-            __submit_event = __exec.queue().submit([&](sycl::handler& __cgh) {
+            __submit_event = _submitter_base::__exec.queue().submit([&](sycl::handler& __cgh) {
                 __cgh.depends_on(__submit_event);
                 auto __temp_acc = __result_and_scratch.template __get_scratch_acc<sycl::access_mode::read_write>(__cgh);
                 __dpl_sycl::__local_accessor<_Type> __local_acc(__wgroup_size, __cgh);

@@ -346,7 +352,7 @@ struct __parallel_scan_submitter<__internal::__optional_kernel_name<_PropagateSc
         }

         // 3. Final scan for whole range
-        auto __final_event = __exec.queue().submit([&](sycl::handler& __cgh) {
+        auto __final_event = _submitter_base::__exec.queue().submit([&](sycl::handler& __cgh) {
            __cgh.depends_on(__submit_event);
            oneapi::dpl::__ranges::__require_access(__cgh, __rng1, __rng2); //get an access to data under SYCL buffer
            auto __temp_acc = __result_and_scratch.template __get_scratch_acc<sycl::access_mode::read>(__cgh);
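
The declaration changes above thread the execution policy type through __parallel_scan_submitter's template parameters, so the policy can be stored in the base class rather than passed to every call. A minimal sketch of the declaration plus partial-specialization pattern with hypothetical names (not the oneDPL classes):

#include <cstddef>

// The primary template is only declared; the usable definition is a partial
// specialization keyed on both the policy type and the optional kernel-name pack.
template <typename... _Name>
struct __optional_name
{
};

template <typename _Policy, typename _NameBundle>
struct __scan_submitter; // declaration only

template <typename _Policy, typename... _Name>
struct __scan_submitter<_Policy, __optional_name<_Name...>>
{
    _Policy __exec; // the policy type is now part of the submitter type

    static constexpr std::size_t __num_names = sizeof...(_Name);
};

// Usage sketch: one optional kernel name propagated through the bundle.
struct __policy_stub
{
};
class __my_kernel_name;

using __submitter_t = __scan_submitter<__policy_stub, __optional_name<__my_kernel_name>>;
static_assert(__submitter_t::__num_names == 1, "exactly one optional kernel name");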
