Skip to content

Commit 6ec465c

Browse files
committed
[oneDPL][ranges][merge] support size limit for output; fixes and refactoring
1 parent 30ebe24 commit 6ec465c

File tree

3 files changed

+32
-87
lines changed

3 files changed

+32
-87
lines changed

include/oneapi/dpl/pstl/algorithm_impl.h

+23-12
Original file line numberDiff line numberDiff line change
@@ -2950,7 +2950,8 @@ __pattern_remove_if(__parallel_tag<_IsVector> __tag, _ExecutionPolicy&& __exec,
29502950

29512951
template<std::random_access_iterator It1, std::random_access_iterator It2, std::random_access_iterator ItOut, typename _Comp>
29522952
std::pair<It1, It2>
2953-
__brick_merge(It1 __it_1, It1 __it_1_e, It2 __it_2, It2 __it_2_e, ItOut __it_out, ItOut __it_out_e, _Comp __comp)
2953+
__brick_merge_2(It1 __it_1, It1 __it_1_e, It2 __it_2, It2 __it_2_e, ItOut __it_out, ItOut __it_out_e, _Comp __comp,
2954+
/* __is_vector = */ std::false_type)
29542955
{
29552956
while(__it_1 != __it_1_e && __it_2 != __it_2_e)
29562957
{
@@ -2982,6 +2983,14 @@ __brick_merge(It1 __it_1, It1 __it_1_e, It2 __it_2, It2 __it_2_e, ItOut __it_out
29822983
return {__it_1, __it_2};
29832984
}
29842985

2986+
// Vectorized overload of __brick_merge_2: chosen by tag dispatch when the trailing
// argument is std::true_type (the __is_vector tag).
// Passes the two input iterator pairs (__it_1/__it_1_e, __it_2/__it_2_e), the output
// iterator pair (__it_out/__it_out_e) and the comparator unchanged to
// __unseq_backend::__simd_merge and returns that call's result.
// NOTE(review): the returned pair is presumably the positions reached in each input,
// mirroring the std::false_type overload's `return {__it_1, __it_2};` - confirm against
// __simd_merge, whose return statement is not visible here.
template<std::random_access_iterator It1, std::random_access_iterator It2, std::random_access_iterator ItOut, typename _Comp>
2987+
std::pair<It1, It2>
2988+
__brick_merge_2(It1 __it_1, It1 __it_1_e, It2 __it_2, It2 __it_2_e, ItOut __it_out, ItOut __it_out_e, _Comp __comp,
2989+
/* __is_vector = */ std::true_type)
2990+
{
2991+
// Pure forwarding: all merging work happens in the SIMD backend helper.
return __unseq_backend::__simd_merge(__it_1, __it_1_e, __it_2, __it_2_e, __it_out, __it_out_e, __comp);
2992+
}
2993+
29852994
template <class _ForwardIterator1, class _ForwardIterator2, class _OutputIterator, class _Compare>
29862995
_OutputIterator
29872996
__brick_merge(_ForwardIterator1 __first1, _ForwardIterator1 __last1, _ForwardIterator2 __first2,
@@ -3014,6 +3023,16 @@ __pattern_merge(_Tag, _ExecutionPolicy&&, _ForwardIterator1 __first1, _ForwardIt
30143023
typename _Tag::__is_vector{});
30153024
}
30163025

3026+
// Generic-tag overload of __pattern_merge_2.
// Turns each (start iterator, element count) argument pair into an iterator pair by
// adding the count to the start iterator, then delegates to __brick_merge_2, passing
// `typename _Tag::__is_vector{}` so overload resolution picks the matching brick.
// __exec is accepted but not used in this body.
// Returns the std::pair<_It1, _It2> produced by __brick_merge_2.
template<class _Tag, typename _ExecutionPolicy, typename _It1, typename _Index1, typename _It2,
3027+
typename _Index2, typename _OutIt, typename _Index3, typename _Comp>
3028+
std::pair<_It1, _It2>
3029+
__pattern_merge_2(_Tag, _ExecutionPolicy&& __exec, _It1 __it_1, _Index1 __n_1, _It2 __it_2,
3030+
_Index2 __n_2, _OutIt __it_out, _Index3 __n_out, _Comp __comp)
3031+
{
3032+
// __n_out sets the end of the output passed down as __it_out + __n_out.
return __brick_merge_2(__it_1, __it_1 + __n_1, __it_2, __it_2 + __n_2, __it_out, __it_out + __n_out, __comp,
3033+
typename _Tag::__is_vector{});
3034+
}
3035+
30173036
template<typename _IsVector, typename _ExecutionPolicy, typename _It1, typename _Index1, typename _It2,
30183037
typename _Index2, typename _OutIt, typename _Index3, typename _Comp>
30193038
std::pair<_It1, _It2>
@@ -3062,16 +3081,16 @@ __pattern_merge_2(__parallel_tag<_IsVector>, _ExecutionPolicy&& __exec, _It1 __i
30623081
}
30633082

30643083
//serial merge n elements, starting from input x and y, to [i, j) output range
3065-
auto __res = __brick_merge(__it_1 + __r, __it_1 + __n_1,
3084+
auto __res = __brick_merge_2(__it_1 + __r, __it_1 + __n_1,
30663085
__it_2 + __c, __it_2 + __n_2,
3067-
__it_out + __i, __it_out + __j, __comp);
3086+
__it_out + __i, __it_out + __j, __comp, _IsVector{});
30683087

30693088
if(__j == __n_out)
30703089
{
30713090
__it_res_1 = __res.first;
30723091
__it_res_2 = __res.second;
30733092
}
3074-
}, _ONEDPL_MERGE_CUT_OFF);
3093+
}, _ONEDPL_MERGE_CUT_OFF); //grainsize
30753094
});
30763095

30773096
return {__it_res_1, __it_res_2};
@@ -3084,7 +3103,6 @@ __pattern_merge(__parallel_tag<_IsVector> __tag, _ExecutionPolicy&& __exec, _Ran
30843103
_RandomAccessIterator1 __last1, _RandomAccessIterator2 __first2, _RandomAccessIterator2 __last2,
30853104
_RandomAccessIterator3 __d_first, _Compare __comp)
30863105
{
3087-
#if 0
30883106
using __backend_tag = typename __parallel_tag<_IsVector>::__backend_tag;
30893107

30903108
return __internal::__except_handler([&]() {
@@ -3097,13 +3115,6 @@ __pattern_merge(__parallel_tag<_IsVector> __tag, _ExecutionPolicy&& __exec, _Ran
30973115
});
30983116
return __d_first + (__last1 - __first1) + (__last2 - __first2);
30993117
});
3100-
#else
3101-
auto __n_1 = __last1 - __first1;
3102-
auto __n_2 = __last2 - __first2;
3103-
auto __n_3 = __n_1 + __n_2;
3104-
__pattern_merge_2(__tag, std::forward<_ExecutionPolicy>(__exec), __first2, __n_2, __first1, __n_1, __d_first, __n_3, __comp);
3105-
return __d_first + __n_3;
3106-
#endif
31073118
}
31083119

31093120
//------------------------------------------------------------------------

include/oneapi/dpl/pstl/algorithm_ranges_impl.h

-65
Original file line numberDiff line numberDiff line change
@@ -447,29 +447,6 @@ template<typename _Tag, typename _ExecutionPolicy, typename _R1, typename _R2, t
447447
auto
448448
__pattern_merge(_Tag __tag, _ExecutionPolicy&& __exec, _R1&& __r1, _R2&& __r2, _OutRange&& __out_r, _Comp __comp,
449449
_Proj1 __proj1, _Proj2 __proj2)
450-
{
451-
static_assert(__is_parallel_tag_v<_Tag> || typename _Tag::__is_vector{});
452-
assert(std::ranges::size(__r1) + std::ranges::size(__r2) <= std::ranges::size(__out_r)); // for debug purposes only
453-
454-
auto __comp_2 = [__comp, __proj1, __proj2](auto&& __val1, auto&& __val2) { return std::invoke(__comp,
455-
std::invoke(__proj1, std::forward<decltype(__val1)>(__val1)), std::invoke(__proj2,
456-
std::forward<decltype(__val2)>(__val2)));};
457-
458-
auto __res = oneapi::dpl::__internal::__pattern_merge(__tag, std::forward<_ExecutionPolicy>(__exec),
459-
std::ranges::begin(__r1), std::ranges::begin(__r1) + std::ranges::size(__r1), std::ranges::begin(__r2),
460-
std::ranges::begin(__r2) + std::ranges::size(__r2), std::ranges::begin(__out_r), __comp_2);
461-
462-
using __return_type = std::ranges::merge_result<std::ranges::borrowed_iterator_t<_R1>, std::ranges::borrowed_iterator_t<_R2>,
463-
std::ranges::borrowed_iterator_t<_OutRange>>;
464-
465-
return __return_type{std::ranges::begin(__r1) + std::ranges::size(__r1), std::ranges::begin(__r2) + std::ranges::size(__r2), __res};
466-
}
467-
468-
template<typename _IsVector, typename _ExecutionPolicy, typename _R1, typename _R2, typename _OutRange, typename _Comp,
469-
typename _Proj1, typename _Proj2>
470-
auto
471-
__pattern_merge(__parallel_tag<_IsVector> __tag, _ExecutionPolicy&& __exec, _R1&& __r1, _R2&& __r2, _OutRange&& __out_r, _Comp __comp,
472-
_Proj1 __proj1, _Proj2 __proj2)
473450
{
474451
auto __comp_2 = [__comp, __proj1, __proj2](auto&& __val1, auto&& __val2) { return std::invoke(__comp,
475452
std::invoke(__proj1, std::forward<decltype(__val1)>(__val1)), std::invoke(__proj2,
@@ -495,48 +472,6 @@ __pattern_merge(__parallel_tag<_IsVector> __tag, _ExecutionPolicy&& __exec, _R1&
495472
return __return_type{__res.second, __res.first, __it_out + __n_out};
496473
}
497474

498-
template<typename _ExecutionPolicy, typename _R1, typename _R2, typename _OutRange, typename _Comp,
499-
typename _Proj1, typename _Proj2>
500-
auto
501-
__pattern_merge(__serial_tag</*IsVector*/std::false_type>, _ExecutionPolicy&& __exec, _R1&& __r1, _R2&& __r2, _OutRange&& __out_r, _Comp __comp,
502-
_Proj1 __proj1, _Proj2 __proj2)
503-
{
504-
using __return_type = std::ranges::merge_result<std::ranges::borrowed_iterator_t<_R1>, std::ranges::borrowed_iterator_t<_R2>,
505-
std::ranges::borrowed_iterator_t<_OutRange>>;
506-
507-
auto __it_1 = std::ranges::begin(__r1);
508-
auto __it_2 = std::ranges::begin(__r2);
509-
auto __it_out = std::ranges::begin(__out_r);
510-
while(__it_1 != std::ranges::end(__r1) && __it_2 != std::ranges::end(__r2))
511-
{
512-
if (std::invoke(__comp, std::invoke(__proj1, *__it_1), std::invoke(__proj2, *__it_2)))
513-
{
514-
*__it_out = *__it_1;
515-
++__it_out, ++__it_1;
516-
}
517-
else
518-
{
519-
*__it_out = *__it_2;
520-
++__it_out, ++__it_2;
521-
}
522-
if(__it_out == std::ranges::end(__out_r))
523-
return __return_type{__it_1, __it_2, __it_out};
524-
}
525-
526-
if(__it_1 == std::ranges::end(__r1))
527-
{
528-
for(; __it_2 != std::ranges::end(__r2) && __it_out != std::ranges::end(__out_r); ++__it_2, ++__it_out)
529-
*__it_out = *__it_2;
530-
}
531-
else
532-
{
533-
//assert(__it_2 == std::ranges::end(__r2);
534-
for(; __it_1 != std::ranges::end(__r1) && __it_out != std::ranges::end(__out_r); ++__it_1, ++__it_out)
535-
*__it_out = *__it_1;
536-
}
537-
return __return_type{__it_1, __it_2, __it_out};
538-
}
539-
540475
} // namespace __ranges
541476
} // namespace __internal
542477
} // namespace dpl

include/oneapi/dpl/pstl/unseq_backend_simd.h

+9-10
Original file line numberDiff line numberDiff line change
@@ -880,34 +880,33 @@ __simd_remove_if(_RandomAccessIterator __first, _DifferenceType __n, _UnaryPredi
880880
return __current + __cnt;
881881
}
882882

883-
template<typename _Index1, typename _Index2, typename _Index3, typename _A, typename _B, typename _C>
884-
std::pair<_Index1, _Index2>
885-
__simd_merge(_Index1 __x, _Index1 __x_e, _Index2 __y, _Index2 __y_e, _Index3 __i, _Index3 __j,
886-
_A&& __a, _B&& __b, _C&& __c)
883+
template<typename _Iterator1, typename _Iterator2, typename _Iterator3, typename _Comp>
884+
std::pair<_Iterator1, _Iterator2>
885+
__simd_merge(_Iterator1 __x, _Iterator1 __x_e, _Iterator2 __y, _Iterator2 __y_e, _Iterator3 __i, _Iterator3 __j, _Comp __comp)
887886
{
888887
_ONEDPL_PRAGMA_SIMD
889-
for(_Index3 __k = __i; __k < __j; ++__k)
888+
for(_Iterator3 __k = __i; __k < __j; ++__k)
890889
{
891890
if(__x >= __x_e)
892891
{
893892
assert(__y < __y_e);
894-
__c[__k] = __b[__y];
893+
*__k = *__y;
895894
++__y;
896895
}
897896
else if(__y >= __y_e)
898897
{
899898
assert(__x < __x_e);
900-
__c[__k] = __a[__x];
899+
*__k = *__x;
901900
++__x;
902901
}
903-
else if(__a[__x] < __b[__y])
902+
else if(std::invoke(__comp, *__x, *__y))
904903
{
905-
__c[__k] = __a[__x];
904+
*__k = *__x;
906905
++__x;
907906
}
908907
else
909908
{
910-
__c[__k] = __b[__y];
909+
*__k = *__y;
911910
++__y;
912911
}
913912
}

0 commit comments

Comments
 (0)