Skip to content

Commit 83c3741

Browse files
authored
Re-implement SYCL backend parallel_for to improve bandwidth utilization (#1976)
Signed-off-by: Matthew Michel <[email protected]>
1 parent 7bbaf83 commit 83c3741

35 files changed

+1376
-217
lines changed

include/oneapi/dpl/experimental/kt/internal/esimd_radix_sort_submitters.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ namespace oneapi::dpl::experimental::kt::gpu::esimd::__impl
2727
{
2828

2929
//------------------------------------------------------------------------
30-
// Please see the comment for __parallel_for_submitter for optional kernel name explanation
30+
// Please see the comment above __parallel_for_small_submitter for optional kernel name explanation
3131
//------------------------------------------------------------------------
3232

3333
template <bool __is_ascending, ::std::uint8_t __radix_bits, ::std::uint16_t __data_per_work_item,

include/oneapi/dpl/internal/async_impl/async_impl_hetero.h

+12-6
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,9 @@ __pattern_walk1_async(__hetero_tag<_BackendTag>, _ExecutionPolicy&& __exec, _For
4444

4545
auto __future_obj = oneapi::dpl::__par_backend_hetero::__parallel_for(
4646
_BackendTag{}, ::std::forward<_ExecutionPolicy>(__exec),
47-
unseq_backend::walk_n<_ExecutionPolicy, _Function>{__f}, __n, __buf.all_view());
47+
unseq_backend::walk1_vector_or_scalar<_ExecutionPolicy, _Function, decltype(__buf.all_view())>{
48+
__f, static_cast<std::size_t>(__n)},
49+
__n, __buf.all_view());
4850
return __future_obj;
4951
}
5052

@@ -67,7 +69,9 @@ __pattern_walk2_async(__hetero_tag<_BackendTag>, _ExecutionPolicy&& __exec, _For
6769

6870
auto __future = oneapi::dpl::__par_backend_hetero::__parallel_for(
6971
_BackendTag{}, ::std::forward<_ExecutionPolicy>(__exec),
70-
unseq_backend::walk_n<_ExecutionPolicy, _Function>{__f}, __n, __buf1.all_view(), __buf2.all_view());
72+
unseq_backend::walk2_vectors_or_scalars<_ExecutionPolicy, _Function, decltype(__buf1.all_view()),
73+
decltype(__buf2.all_view())>{__f, static_cast<std::size_t>(__n)},
74+
__n, __buf1.all_view(), __buf2.all_view());
7175

7276
return __future.__make_future(__first2 + __n);
7377
}
@@ -91,10 +95,12 @@ __pattern_walk3_async(__hetero_tag<_BackendTag>, _ExecutionPolicy&& __exec, _For
9195
oneapi::dpl::__ranges::__get_sycl_range<__par_backend_hetero::access_mode::write, _ForwardIterator3>();
9296
auto __buf3 = __keep3(__first3, __first3 + __n);
9397

94-
auto __future =
95-
oneapi::dpl::__par_backend_hetero::__parallel_for(_BackendTag{}, ::std::forward<_ExecutionPolicy>(__exec),
96-
unseq_backend::walk_n<_ExecutionPolicy, _Function>{__f}, __n,
97-
__buf1.all_view(), __buf2.all_view(), __buf3.all_view());
98+
auto __future = oneapi::dpl::__par_backend_hetero::__parallel_for(
99+
_BackendTag{}, std::forward<_ExecutionPolicy>(__exec),
100+
unseq_backend::walk3_vectors_or_scalars<_ExecutionPolicy, _Function, decltype(__buf1.all_view()),
101+
decltype(__buf2.all_view()), decltype(__buf3.all_view())>{
102+
__f, static_cast<size_t>(__n)},
103+
__n, __buf1.all_view(), __buf2.all_view(), __buf3.all_view());
98104

99105
return __future.__make_future(__first3 + __n);
100106
}

include/oneapi/dpl/internal/binary_search_impl.h

+26-11
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,19 @@ enum class search_algorithm
3737
binary_search
3838
};
3939

40-
template <typename Comp, typename T, search_algorithm func>
41-
struct custom_brick
40+
#if _ONEDPL_BACKEND_SYCL
41+
template <typename Comp, typename T, typename _Range, search_algorithm func>
42+
struct __custom_brick : oneapi::dpl::unseq_backend::walk_scalar_base<_Range>
4243
{
4344
Comp comp;
4445
T size;
4546
bool use_32bit_indexing;
4647

48+
__custom_brick(Comp comp, T size, bool use_32bit_indexing)
49+
: comp(std::move(comp)), size(size), use_32bit_indexing(use_32bit_indexing)
50+
{
51+
}
52+
4753
template <typename _Size, typename _ItemId, typename _Acc>
4854
void
4955
search_impl(_ItemId idx, _Acc acc) const
@@ -68,17 +74,23 @@ struct custom_brick
6874
get<2>(acc[idx]) = (value != end_orig) && (get<1>(acc[idx]) == get<0>(acc[value]));
6975
}
7076
}
71-
72-
template <typename _ItemId, typename _Acc>
77+
template <typename _IsFull, typename _ItemId, typename _Acc>
7378
void
74-
operator()(_ItemId idx, _Acc acc) const
79+
__scalar_path_impl(_IsFull, _ItemId idx, _Acc acc) const
7580
{
7681
if (use_32bit_indexing)
7782
search_impl<std::uint32_t>(idx, acc);
7883
else
7984
search_impl<std::uint64_t>(idx, acc);
8085
}
86+
template <typename _IsFull, typename _ItemId, typename _Acc>
87+
void
88+
operator()(_IsFull __is_full, _ItemId idx, _Acc acc) const
89+
{
90+
__scalar_path_impl(__is_full, idx, acc);
91+
}
8192
};
93+
#endif
8294

8395
template <class _Tag, typename Policy, typename InputIterator1, typename InputIterator2, typename OutputIterator,
8496
typename StrictWeakOrdering>
@@ -155,7 +167,8 @@ lower_bound_impl(__internal::__hetero_tag<_BackendTag>, Policy&& policy, InputIt
155167
const bool use_32bit_indexing = size <= std::numeric_limits<std::uint32_t>::max();
156168
__bknd::__parallel_for(
157169
_BackendTag{}, ::std::forward<decltype(policy)>(policy),
158-
custom_brick<StrictWeakOrdering, decltype(size), search_algorithm::lower_bound>{comp, size, use_32bit_indexing},
170+
__custom_brick<StrictWeakOrdering, decltype(size), decltype(zip_vw), search_algorithm::lower_bound>{
171+
comp, size, use_32bit_indexing},
159172
value_size, zip_vw)
160173
.__deferrable_wait();
161174
return result + value_size;
@@ -187,7 +200,8 @@ upper_bound_impl(__internal::__hetero_tag<_BackendTag>, Policy&& policy, InputIt
187200
const bool use_32bit_indexing = size <= std::numeric_limits<std::uint32_t>::max();
188201
__bknd::__parallel_for(
189202
_BackendTag{}, std::forward<decltype(policy)>(policy),
190-
custom_brick<StrictWeakOrdering, decltype(size), search_algorithm::upper_bound>{comp, size, use_32bit_indexing},
203+
__custom_brick<StrictWeakOrdering, decltype(size), decltype(zip_vw), search_algorithm::upper_bound>{
204+
comp, size, use_32bit_indexing},
191205
value_size, zip_vw)
192206
.__deferrable_wait();
193207
return result + value_size;
@@ -217,10 +231,11 @@ binary_search_impl(__internal::__hetero_tag<_BackendTag>, Policy&& policy, Input
217231
auto result_buf = keep_result(result, result + value_size);
218232
auto zip_vw = make_zip_view(input_buf.all_view(), value_buf.all_view(), result_buf.all_view());
219233
const bool use_32bit_indexing = size <= std::numeric_limits<std::uint32_t>::max();
220-
__bknd::__parallel_for(_BackendTag{}, std::forward<decltype(policy)>(policy),
221-
custom_brick<StrictWeakOrdering, decltype(size), search_algorithm::binary_search>{
222-
comp, size, use_32bit_indexing},
223-
value_size, zip_vw)
234+
__bknd::__parallel_for(
235+
_BackendTag{}, std::forward<decltype(policy)>(policy),
236+
__custom_brick<StrictWeakOrdering, decltype(size), decltype(zip_vw), search_algorithm::binary_search>{
237+
comp, size, use_32bit_indexing},
238+
value_size, zip_vw)
224239
.__deferrable_wait();
225240
return result + value_size;
226241
}

0 commit comments

Comments
 (0)