Skip to content

Commit 2a3a0b8

Browse files
[oneDPL] Fix performance issue in __serial_merge (#2022)
1 parent 1746d46 commit 2a3a0b8

File tree

1 file changed

+33
-5
lines changed

1 file changed

+33
-5
lines changed

include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_merge.h

+33-5
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,9 @@
1919
#include <limits> // std::numeric_limits
2020
#include <cassert> // assert
2121
#include <cstdint> // std::uint8_t, ...
22-
#include <utility> // std::make_pair, std::forward
22+
#include <utility> // std::make_pair, std::forward, std::declval
2323
#include <algorithm> // std::min, std::lower_bound
24+
#include <type_traits> // std::void_t, std::true_type, std::false_type
2425

2526
#include "sycl_defs.h"
2627
#include "parallel_backend_sycl_utils.h"
@@ -130,6 +131,21 @@ __find_start_point(const _Rng1& __rng1, const _Index __rng1_from, _Index __rng1_
130131
return _split_point_t<_Index>{*__res, __index_sum - *__res + 1};
131132
}
132133

134+
template <typename _Rng1DataType, typename _Rng2DataType, typename = void>
135+
struct __can_use_ternary_op : std::false_type
136+
{
137+
};
138+
139+
template <typename _Rng1DataType, typename _Rng2DataType>
140+
struct __can_use_ternary_op<_Rng1DataType, _Rng2DataType,
141+
std::void_t<decltype(true ? std::declval<_Rng1DataType>() : std::declval<_Rng2DataType>())>>
142+
: std::true_type
143+
{
144+
};
145+
146+
template <typename _Rng1DataType, typename _Rng2DataType>
147+
constexpr static bool __can_use_ternary_op_v = __can_use_ternary_op<_Rng1DataType, _Rng2DataType>::value;
148+
133149
// Do serial merge of the data from rng1 (starting from start1) and rng2 (starting from start2) and writing
134150
// to rng3 (starting from start3) in 'chunk' steps, but do not exceed the total size of the sequences (n1 and n2)
135151
template <typename _Rng1, typename _Rng2, typename _Rng3, typename _Index, typename _Compare>
@@ -156,11 +172,23 @@ __serial_merge(const _Rng1& __rng1, const _Rng2& __rng2, _Rng3& __rng3, const _I
156172
// One of __rng1_idx_less_n1 and __rng2_idx_less_n2 should be true here
157173
// because 1) we should fill output data with elements from one of the input ranges
158174
// 2) we calculate __rng3_idx_end as std::min<_Index>(__rng1_size + __rng2_size, __chunk).
159-
if (__rng1_idx_less_n1 && __rng2_idx_less_n2 && __comp(__rng2[__rng2_idx], __rng1[__rng1_idx]) ||
160-
!__rng1_idx_less_n1)
161-
__rng3[__rng3_idx] = __rng2[__rng2_idx++];
175+
if constexpr (__can_use_ternary_op_v<decltype(__rng1[__rng1_idx]), decltype(__rng2[__rng2_idx])>)
176+
{
177+
// This implementation is required for performance optimization
178+
__rng3[__rng3_idx] = (!__rng1_idx_less_n1 || __rng1_idx_less_n1 && __rng2_idx_less_n2 &&
179+
__comp(__rng2[__rng2_idx], __rng1[__rng1_idx]))
180+
? __rng2[__rng2_idx++]
181+
: __rng1[__rng1_idx++];
182+
}
162183
else
163-
__rng3[__rng3_idx] = __rng1[__rng1_idx++];
184+
{
185+
// TODO required to understand why the usual if-else is slower then ternary operator
186+
if (!__rng1_idx_less_n1 ||
187+
__rng1_idx_less_n1 && __rng2_idx_less_n2 && __comp(__rng2[__rng2_idx], __rng1[__rng1_idx]))
188+
__rng3[__rng3_idx] = __rng2[__rng2_idx++];
189+
else
190+
__rng3[__rng3_idx] = __rng1[__rng1_idx++];
191+
}
164192
}
165193
}
166194

0 commit comments

Comments
 (0)