@@ -225,13 +225,13 @@ class __scan_copy_single_wg_kernel;
 // ------------------------------------------------------------------------
 
 // Please see the comment above __parallel_for_small_submitter for optional kernel name explanation
-template <typename _PropagateScanName>
+template <typename _ExecutionPolicy, typename _PropagateScanName>
 struct __parallel_scan_submitter;
 
 // Even if this class submits three kernel optional name is allowed to be only for one of them
 // because for two others we have to provide the name to get the reliable work group size
-template <typename... _PropagateScanName>
-struct __parallel_scan_submitter<__internal::__optional_kernel_name<_PropagateScanName...>>;
+template <typename _ExecutionPolicy, typename... _PropagateScanName>
+struct __parallel_scan_submitter<_ExecutionPolicy, __internal::__optional_kernel_name<_PropagateScanName...>>;
 
 struct __parallel_scan_submitter_factory
 {
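The first hunk rethreads the execution policy through the submitter's type: both the primary template declaration and its partial specialization over `__internal::__optional_kernel_name` gain a leading `_ExecutionPolicy` parameter. A minimal sketch of this declare-then-specialize pattern, using simplified stand-ins rather than oneDPL's real internals:

```cpp
// Illustrative sketch only; __optional_kernel_name and the submitter below
// are simplified stand-ins, not oneDPL's actual definitions.
template <typename... _Name>
struct __optional_kernel_name; // empty pack => compiler-generated kernel name

template <typename _ExecutionPolicy, typename _PropagateScanName>
struct __parallel_scan_submitter; // primary: declared, never defined

// Sole definition: matches only when the second argument is an
// __optional_kernel_name<...> instantiation, exposing the name pack.
template <typename _ExecutionPolicy, typename... _PropagateScanName>
struct __parallel_scan_submitter<_ExecutionPolicy, __optional_kernel_name<_PropagateScanName...>>
{
};

// Usage: well-formed, selects the partial specialization.
// __parallel_scan_submitter<MyPolicy, __optional_kernel_name<class MyKernel>> __s{};
```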
@@ -249,15 +249,21 @@ struct __parallel_scan_submitter_factory
 // Even if this class submits three kernel optional name is allowed to be only for one of them
 // because for two others we have to provide the name to get the reliable work group size
 template <typename _ExecutionPolicy, typename... _PropagateScanName>
-struct __parallel_scan_submitter<__internal::__optional_kernel_name<_PropagateScanName...>>
-    : protected __sycl_submitter_base<_ExecutionPolicy>
+struct __parallel_scan_submitter<_ExecutionPolicy, __internal::__optional_kernel_name<_PropagateScanName...>>
+    : protected internal::__sycl_submitter_base<_ExecutionPolicy>
 {
+    friend __parallel_scan_submitter_factory;
+
+    using _submitter_base = internal::__sycl_submitter_base<_ExecutionPolicy>;
+
+  protected:
     template <typename _ExecutionPolicyCtor>
     __parallel_scan_submitter(_ExecutionPolicyCtor&& __exec)
-        : __sycl_submitter_base<_ExecutionPolicy>(std::forward<_ExecutionPolicyCtor>(__exec))
+        : internal::__sycl_submitter_base<_ExecutionPolicy>(std::forward<_ExecutionPolicyCtor>(__exec))
     {
     }
 
+  public:
     template <typename _Range1, typename _Range2, typename _InitType,
               typename _LocalScan, typename _GroupScan, typename _GlobalScan>
     auto
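The second hunk does the structural work: the specialization now derives (protected) from `internal::__sycl_submitter_base<_ExecutionPolicy>`, which owns the policy object, the constructor moves into a `protected:` section, and `__parallel_scan_submitter_factory` is befriended. The net effect is that callers can no longer construct a submitter directly; only the factory can. A sketch of that base-plus-friend-factory shape, assuming `__sycl_submitter_base` merely stores the forwarded policy by value (the real oneDPL class may hold more):

```cpp
#include <type_traits>
#include <utility>

// Assumed shape of the base: it just owns the forwarded policy.
template <typename _ExecutionPolicy>
struct __sycl_submitter_base
{
  protected:
    _ExecutionPolicy __exec;

    template <typename _ExecutionPolicyCtor>
    explicit __sycl_submitter_base(_ExecutionPolicyCtor&& __e)
        : __exec(std::forward<_ExecutionPolicyCtor>(__e))
    {
    }
};

struct __submitter_factory; // declared so it can be named as a friend

template <typename _ExecutionPolicy>
struct __submitter : protected __sycl_submitter_base<_ExecutionPolicy>
{
    friend __submitter_factory; // the factory is the only external constructor

  protected:
    template <typename _ExecutionPolicyCtor>
    explicit __submitter(_ExecutionPolicyCtor&& __e)
        : __sycl_submitter_base<_ExecutionPolicy>(std::forward<_ExecutionPolicyCtor>(__e))
    {
    }
};

struct __submitter_factory
{
    template <typename _ExecutionPolicy>
    static auto
    __create(_ExecutionPolicy&& __e)
    {
        return __submitter<std::decay_t<_ExecutionPolicy>>(std::forward<_ExecutionPolicy>(__e));
    }
};
```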
@@ -273,21 +279,21 @@ struct __parallel_scan_submitter<__internal::__optional_kernel_name<_PropagateSc
         auto __n = __rng1.size();
         assert(__n > 0);
 
-        auto __max_cu = oneapi::dpl::__internal::__max_compute_units(__exec);
+        auto __max_cu = oneapi::dpl::__internal::__max_compute_units(_submitter_base::__exec);
         // get the work group size adjusted to the local memory limit
         // TODO: find a way to generalize getting of reliable work-group sizes
-        ::std::size_t __wgroup_size = oneapi::dpl::__internal::__slm_adjusted_work_group_size(__exec, sizeof(_Type));
+        ::std::size_t __wgroup_size = oneapi::dpl::__internal::__slm_adjusted_work_group_size(_submitter_base::__exec, sizeof(_Type));
         // Limit the work-group size to prevent large sizes on CPUs. Empirically found value.
         // This value matches the current practical limit for GPUs, but may need to be re-evaluated in the future.
         __wgroup_size = std::min(__wgroup_size, (std::size_t)1024);
 
 #if _ONEDPL_COMPILE_KERNEL
         // Actually there is one kernel_bundle for the all kernels of the pattern.
-        auto __kernels = __internal::__kernel_compiler<_LocalScanKernel, _GroupScanKernel>::__compile(__exec);
+        auto __kernels = __internal::__kernel_compiler<_LocalScanKernel, _GroupScanKernel>::__compile(_submitter_base::__exec);
         auto __kernel_1 = __kernels[0];
         auto __kernel_2 = __kernels[1];
-        auto __wgroup_size_kernel_1 = oneapi::dpl::__internal::__kernel_work_group_size(__exec, __kernel_1);
-        auto __wgroup_size_kernel_2 = oneapi::dpl::__internal::__kernel_work_group_size(__exec, __kernel_2);
+        auto __wgroup_size_kernel_1 = oneapi::dpl::__internal::__kernel_work_group_size(_submitter_base::__exec, __kernel_1);
+        auto __wgroup_size_kernel_2 = oneapi::dpl::__internal::__kernel_work_group_size(_submitter_base::__exec, __kernel_2);
         __wgroup_size = ::std::min({__wgroup_size, __wgroup_size_kernel_1, __wgroup_size_kernel_2});
 #endif
 
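From this hunk on, every use of `__exec` in the method body is spelled `_submitter_base::__exec`. That is not stylistic: `__exec` moved into a dependent base class, and under two-phase name lookup an unqualified name is never looked up in a base that depends on a template parameter. A toy reproduction of the rule (illustrative names only):

```cpp
template <typename _T>
struct _Base
{
  protected:
    _T __exec; // the member that used to be found unqualified
};

template <typename _T>
struct _Derived : _Base<_T>
{
    using _submitter_base = _Base<_T>; // same alias trick as in the diff

    _T&
    __get()
    {
        // return __exec;               // ill-formed: unqualified lookup
        //                              // ignores the dependent base _Base<_T>
        // return this->__exec;         // OK: lookup deferred to instantiation
        return _submitter_base::__exec; // OK: explicitly names the dependent base
    }
};
```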
@@ -298,12 +304,12 @@ struct __parallel_scan_submitter<__internal::__optional_kernel_name<_PropagateSc
         // Storage for the results of scan for each workgroup
 
         using __result_and_scratch_storage_t = __result_and_scratch_storage<_ExecutionPolicy, _Type>;
-        __result_and_scratch_storage_t __result_and_scratch{__exec, 1, __n_groups + 1};
+        __result_and_scratch_storage_t __result_and_scratch{_submitter_base::__exec, 1, __n_groups + 1};
 
-        _PRINT_INFO_IN_DEBUG_MODE(__exec, __wgroup_size, __max_cu);
+        _PRINT_INFO_IN_DEBUG_MODE(_submitter_base::__exec, __wgroup_size, __max_cu);
 
         // 1. Local scan on each workgroup
-        auto __submit_event = __exec.queue().submit([&](sycl::handler& __cgh) {
+        auto __submit_event = _submitter_base::__exec.queue().submit([&](sycl::handler& __cgh) {
             oneapi::dpl::__ranges::__require_access(__cgh, __rng1, __rng2); // get an access to data under SYCL buffer
             auto __temp_acc = __result_and_scratch.template __get_scratch_acc<sycl::access_mode::write>(
                 __cgh, __dpl_sycl::__no_init{});
@@ -325,7 +331,7 @@ struct __parallel_scan_submitter<__internal::__optional_kernel_name<_PropagateSc
         if (__n_groups > 1)
         {
             auto __iters_per_single_wg = oneapi::dpl::__internal::__dpl_ceiling_div(__n_groups, __wgroup_size);
-            __submit_event = __exec.queue().submit([&](sycl::handler& __cgh) {
+            __submit_event = _submitter_base::__exec.queue().submit([&](sycl::handler& __cgh) {
                 __cgh.depends_on(__submit_event);
                 auto __temp_acc = __result_and_scratch.template __get_scratch_acc<sycl::access_mode::read_write>(__cgh);
                 __dpl_sycl::__local_accessor<_Type> __local_acc(__wgroup_size, __cgh);
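In this second kernel a single work-group scans all `__n_groups` partial sums, looping `__iters_per_single_wg` times whenever there are more sums than work-items. `__dpl_ceiling_div` is oneDPL's integer ceiling division; a plausible equivalent, shown only to make the arithmetic concrete:

```cpp
#include <cstddef>

// Presumed behavior of __dpl_ceiling_div: integer ceil(__a / __b).
constexpr std::size_t
ceiling_div(std::size_t __a, std::size_t __b)
{
    return (__a + __b - 1) / __b; // assumes __a + __b - 1 does not overflow
}

// Example: 10'000 per-group sums scanned by one work-group of 1024 items
// take ceiling_div(10'000, 1024) == 10 passes over local memory.
static_assert(ceiling_div(10'000, 1024) == 10, "");
```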
@@ -346,7 +352,7 @@ struct __parallel_scan_submitter<__internal::__optional_kernel_name<_PropagateSc
         }
 
         // 3. Final scan for whole range
-        auto __final_event = __exec.queue().submit([&](sycl::handler& __cgh) {
+        auto __final_event = _submitter_base::__exec.queue().submit([&](sycl::handler& __cgh) {
             __cgh.depends_on(__submit_event);
             oneapi::dpl::__ranges::__require_access(__cgh, __rng1, __rng2); // get an access to data under SYCL buffer
             auto __temp_acc = __result_and_scratch.template __get_scratch_acc<sycl::access_mode::read>(__cgh);
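Taken together, the numbered comments in the last three hunks are the classic three-kernel device scan: per-group local scans, a scan of the group totals, then a propagation pass over the whole range. A host-side sketch of the same decomposition, useful for checking the kernels' division of labor (plain C++ standing in for the SYCL code, with `int` and `operator+` in place of the generic `_Type` and binary operation):

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

std::vector<int>
three_phase_inclusive_scan(const std::vector<int>& __in, std::size_t __wg)
{
    const std::size_t __n = __in.size();
    const std::size_t __n_groups = (__n + __wg - 1) / __wg;
    std::vector<int> __out(__n);
    std::vector<int> __sums(__n_groups); // models the scratch storage above

    // 1. Local scan on each "workgroup": independent inclusive scans.
    for (std::size_t __g = 0; __g < __n_groups; ++__g)
    {
        int __acc = 0;
        for (std::size_t __i = __g * __wg; __i < std::min(__n, (__g + 1) * __wg); ++__i)
            __out[__i] = __acc += __in[__i];
        __sums[__g] = __acc; // per-group total
    }

    // 2. Exclusive scan of the per-group totals (one work-group on device).
    int __running = 0;
    for (std::size_t __g = 0; __g < __n_groups; ++__g)
    {
        const int __s = __sums[__g];
        __sums[__g] = __running;
        __running += __s;
    }

    // 3. Final pass: shift every group by the scanned total of its predecessors.
    for (std::size_t __g = 1; __g < __n_groups; ++__g)
        for (std::size_t __i = __g * __wg; __i < std::min(__n, (__g + 1) * __wg); ++__i)
            __out[__i] += __sums[__g];

    return __out;
}
```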