@@ -37,13 +37,19 @@ enum class search_algorithm
37
37
binary_search
38
38
};
39
39
40
- template <typename Comp, typename T, search_algorithm func>
41
- struct custom_brick
40
+ #if _ONEDPL_BACKEND_SYCL
41
+ template <typename Comp, typename T, typename _Range, search_algorithm func>
42
+ struct __custom_brick : oneapi::dpl::unseq_backend::walk_scalar_base<_Range>
42
43
{
43
44
Comp comp;
44
45
T size;
45
46
bool use_32bit_indexing;
46
47
48
+ __custom_brick (Comp comp, T size, bool use_32bit_indexing)
49
+ : comp(std::move(comp)), size(size), use_32bit_indexing(use_32bit_indexing)
50
+ {
51
+ }
52
+
47
53
template <typename _Size, typename _ItemId, typename _Acc>
48
54
void
49
55
search_impl (_ItemId idx, _Acc acc) const
@@ -68,17 +74,23 @@ struct custom_brick
68
74
get<2 >(acc[idx]) = (value != end_orig) && (get<1 >(acc[idx]) == get<0 >(acc[value]));
69
75
}
70
76
}
71
-
72
- template <typename _ItemId, typename _Acc>
77
+ template <typename _IsFull, typename _ItemId, typename _Acc>
73
78
void
74
- operator ()( _ItemId idx, _Acc acc) const
79
+ __scalar_path_impl (_IsFull, _ItemId idx, _Acc acc) const
75
80
{
76
81
if (use_32bit_indexing)
77
82
search_impl<std::uint32_t >(idx, acc);
78
83
else
79
84
search_impl<std::uint64_t >(idx, acc);
80
85
}
86
+ template <typename _IsFull, typename _ItemId, typename _Acc>
87
+ void
88
+ operator ()(_IsFull __is_full, _ItemId idx, _Acc acc) const
89
+ {
90
+ __scalar_path_impl (__is_full, idx, acc);
91
+ }
81
92
};
93
+ #endif
82
94
83
95
template <class _Tag , typename Policy, typename InputIterator1, typename InputIterator2, typename OutputIterator,
84
96
typename StrictWeakOrdering>
@@ -155,7 +167,8 @@ lower_bound_impl(__internal::__hetero_tag<_BackendTag>, Policy&& policy, InputIt
155
167
const bool use_32bit_indexing = size <= std::numeric_limits<std::uint32_t >::max ();
156
168
__bknd::__parallel_for (
157
169
_BackendTag{}, ::std::forward<decltype (policy)>(policy),
158
- custom_brick<StrictWeakOrdering, decltype (size), search_algorithm::lower_bound>{comp, size, use_32bit_indexing},
170
+ __custom_brick<StrictWeakOrdering, decltype (size), decltype (zip_vw), search_algorithm::lower_bound>{
171
+ comp, size, use_32bit_indexing},
159
172
value_size, zip_vw)
160
173
.__deferrable_wait ();
161
174
return result + value_size;
@@ -187,7 +200,8 @@ upper_bound_impl(__internal::__hetero_tag<_BackendTag>, Policy&& policy, InputIt
187
200
const bool use_32bit_indexing = size <= std::numeric_limits<std::uint32_t >::max ();
188
201
__bknd::__parallel_for (
189
202
_BackendTag{}, std::forward<decltype (policy)>(policy),
190
- custom_brick<StrictWeakOrdering, decltype (size), search_algorithm::upper_bound>{comp, size, use_32bit_indexing},
203
+ __custom_brick<StrictWeakOrdering, decltype (size), decltype (zip_vw), search_algorithm::upper_bound>{
204
+ comp, size, use_32bit_indexing},
191
205
value_size, zip_vw)
192
206
.__deferrable_wait ();
193
207
return result + value_size;
@@ -217,10 +231,11 @@ binary_search_impl(__internal::__hetero_tag<_BackendTag>, Policy&& policy, Input
217
231
auto result_buf = keep_result (result, result + value_size);
218
232
auto zip_vw = make_zip_view (input_buf.all_view (), value_buf.all_view (), result_buf.all_view ());
219
233
const bool use_32bit_indexing = size <= std::numeric_limits<std::uint32_t >::max ();
220
- __bknd::__parallel_for (_BackendTag{}, std::forward<decltype (policy)>(policy),
221
- custom_brick<StrictWeakOrdering, decltype (size), search_algorithm::binary_search>{
222
- comp, size, use_32bit_indexing},
223
- value_size, zip_vw)
234
+ __bknd::__parallel_for (
235
+ _BackendTag{}, std::forward<decltype (policy)>(policy),
236
+ __custom_brick<StrictWeakOrdering, decltype (size), decltype (zip_vw), search_algorithm::binary_search>{
237
+ comp, size, use_32bit_indexing},
238
+ value_size, zip_vw)
224
239
.__deferrable_wait ();
225
240
return result + value_size;
226
241
}
0 commit comments