19
19
#include < limits> // std::numeric_limits
20
20
#include < cassert> // assert
21
21
#include < cstdint> // std::uint8_t, ...
22
- #include < utility> // std::make_pair, std::forward
22
+ #include < utility> // std::make_pair, std::forward, std::declval
23
23
#include < algorithm> // std::min, std::lower_bound
24
+ #include < type_traits> // std::void_t, std::true_type, std::false_type
24
25
25
26
#include " sycl_defs.h"
26
27
#include " parallel_backend_sycl_utils.h"
@@ -130,6 +131,21 @@ __find_start_point(const _Rng1& __rng1, const _Index __rng1_from, _Index __rng1_
130
131
return _split_point_t <_Index>{*__res, __index_sum - *__res + 1 };
131
132
}
132
133
134
+ template <typename _Rng1DataType, typename _Rng2DataType, typename = void >
135
+ struct __can_use_ternary_op : std::false_type
136
+ {
137
+ };
138
+
139
+ template <typename _Rng1DataType, typename _Rng2DataType>
140
+ struct __can_use_ternary_op <_Rng1DataType, _Rng2DataType,
141
+ std::void_t <decltype(true ? std::declval<_Rng1DataType>() : std::declval<_Rng2DataType>())>>
142
+ : std::true_type
143
+ {
144
+ };
145
+
146
+ template <typename _Rng1DataType, typename _Rng2DataType>
147
+ constexpr static bool __can_use_ternary_op_v = __can_use_ternary_op<_Rng1DataType, _Rng2DataType>::value;
148
+
133
149
// Do serial merge of the data from rng1 (starting from start1) and rng2 (starting from start2) and writing
134
150
// to rng3 (starting from start3) in 'chunk' steps, but do not exceed the total size of the sequences (n1 and n2)
135
151
template <typename _Rng1, typename _Rng2, typename _Rng3, typename _Index, typename _Compare>
@@ -156,11 +172,23 @@ __serial_merge(const _Rng1& __rng1, const _Rng2& __rng2, _Rng3& __rng3, const _I
156
172
// One of __rng1_idx_less_n1 and __rng2_idx_less_n2 should be true here
157
173
// because 1) we should fill output data with elements from one of the input ranges
158
174
// 2) we calculate __rng3_idx_end as std::min<_Index>(__rng1_size + __rng2_size, __chunk).
159
- if (__rng1_idx_less_n1 && __rng2_idx_less_n2 && __comp (__rng2[__rng2_idx], __rng1[__rng1_idx]) ||
160
- !__rng1_idx_less_n1)
161
- __rng3[__rng3_idx] = __rng2[__rng2_idx++];
175
+ if constexpr (__can_use_ternary_op_v<decltype (__rng1[__rng1_idx]), decltype (__rng2[__rng2_idx])>)
176
+ {
177
+ // This implementation is required for performance optimization
178
+ __rng3[__rng3_idx] = (!__rng1_idx_less_n1 || __rng1_idx_less_n1 && __rng2_idx_less_n2 &&
179
+ __comp (__rng2[__rng2_idx], __rng1[__rng1_idx]))
180
+ ? __rng2[__rng2_idx++]
181
+ : __rng1[__rng1_idx++];
182
+ }
162
183
else
163
- __rng3[__rng3_idx] = __rng1[__rng1_idx++];
184
+ {
185
+ // TODO required to understand why the usual if-else is slower then ternary operator
186
+ if (!__rng1_idx_less_n1 ||
187
+ __rng1_idx_less_n1 && __rng2_idx_less_n2 && __comp (__rng2[__rng2_idx], __rng1[__rng1_idx]))
188
+ __rng3[__rng3_idx] = __rng2[__rng2_idx++];
189
+ else
190
+ __rng3[__rng3_idx] = __rng1[__rng1_idx++];
191
+ }
164
192
}
165
193
}
166
194
0 commit comments