Skip to content

Commit c0c8ba4

Browse files
committed
[oneDPL][ranges][merge] support size limit for output
1 parent 5f26aea commit c0c8ba4

13 files changed

+370
-112
lines changed

include/oneapi/dpl/pstl/algorithm_impl.h

+122-1
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
#include "parallel_backend.h"
3232
#include "parallel_impl.h"
3333
#include "iterator_impl.h"
34+
#include "../functional"
3435

3536
#if _ONEDPL_HETERO_BACKEND
3637
# include "hetero/algorithm_impl_hetero.h" // for __pattern_fill_n, __pattern_generate_n
@@ -2948,6 +2949,49 @@ __pattern_remove_if(__parallel_tag<_IsVector> __tag, _ExecutionPolicy&& __exec,
29482949
// merge
29492950
//------------------------------------------------------------------------
29502951

2952+
// Serial brick of the size-limited merge.
//
// Merges [__it_1, __it_1_e) and [__it_2, __it_2_e) into [__it_out, __it_out_e) and stops as soon as
// either both inputs are exhausted or the output range is full.
//
// Tie-breaking: an element is taken from the first sequence only when __comp(*__it_1, *__it_2) is true;
// in every other case (including equivalent elements) the element of the second sequence is taken.
//
// Returns {position of the first element of sequence 1 that was not copied,
//          position of the first element of sequence 2 that was not copied}.
template <typename It1, typename It2, typename ItOut, typename _Comp>
std::pair<It1, It2>
__brick_merge_2(It1 __it_1, It1 __it_1_e, It2 __it_2, It2 __it_2_e, ItOut __it_out, ItOut __it_out_e, _Comp __comp,
                /* __is_vector = */ std::false_type)
{
    // The output bound is tested before every store, so an empty output range is never written to.
    while (__it_1 != __it_1_e && __it_2 != __it_2_e && __it_out != __it_out_e)
    {
        if (__comp(*__it_1, *__it_2))
        {
            *__it_out = *__it_1;
            ++__it_1;
        }
        else
        {
            *__it_out = *__it_2;
            ++__it_2;
        }
        ++__it_out;
    }

    // Copy the tail of whichever input still has elements. At most one of the two loops below does any work:
    // if the output is already full both are no-ops, otherwise at least one input is exhausted.
    for (; __it_2 != __it_2_e && __it_out != __it_out_e; ++__it_2, ++__it_out)
        *__it_out = *__it_2;

    for (; __it_1 != __it_1_e && __it_out != __it_out_e; ++__it_1, ++__it_out)
        *__it_out = *__it_1;

    return {__it_1, __it_2};
}
2986+
2987+
template<typename It1, typename It2, typename ItOut, typename _Comp>
2988+
std::pair<It1, It2>
2989+
__brick_merge_2(It1 __it_1, It1 __it_1_e, It2 __it_2, It2 __it_2_e, ItOut __it_out, ItOut __it_out_e, _Comp __comp,
2990+
/* __is_vector = */ std::true_type)
2991+
{
2992+
return __unseq_backend::__simd_merge(__it_1, __it_1_e, __it_2, __it_2_e, __it_out, __it_out_e, __comp);
2993+
}
2994+
29512995
template <class _ForwardIterator1, class _ForwardIterator2, class _OutputIterator, class _Compare>
29522996
_OutputIterator
29532997
__brick_merge(_ForwardIterator1 __first1, _ForwardIterator1 __last1, _ForwardIterator2 __first2,
@@ -2980,10 +3024,87 @@ __pattern_merge(_Tag, _ExecutionPolicy&&, _ForwardIterator1 __first1, _ForwardIt
29803024
typename _Tag::__is_vector{});
29813025
}
29823026

3027+
// Size-limited merge pattern for tags without a dedicated overload.
// Converts the (iterator, count) description of the two inputs and of the output into iterator ranges
// and runs the merge brick selected by the tag's __is_vector member. The execution policy is not used here.
// Returns whatever the brick returns: the pair of positions reached in the two input sequences.
template <class _Tag, typename _ExecutionPolicy, typename _It1, typename _Index1, typename _It2, typename _Index2,
          typename _OutIt, typename _Index3, typename _Comp>
std::pair<_It1, _It2>
__pattern_merge_2(_Tag, _ExecutionPolicy&& __exec, _It1 __it_1, _Index1 __n_1, _It2 __it_2, _Index2 __n_2,
                  _OutIt __it_out, _Index3 __n_out, _Comp __comp)
{
    using _IsVector = typename _Tag::__is_vector;

    // one-past-the-end positions of both inputs and of the output
    auto __end_1 = __it_1 + __n_1;
    auto __end_2 = __it_2 + __n_2;
    auto __end_out = __it_out + __n_out;

    return __brick_merge_2(__it_1, __end_1, __it_2, __end_2, __it_out, __end_out, __comp, _IsVector{});
}
3036+
3037+
template<typename _IsVector, typename _ExecutionPolicy, typename _It1, typename _Index1, typename _It2,
3038+
typename _Index2, typename _OutIt, typename _Index3, typename _Comp>
3039+
std::pair<_It1, _It2>
3040+
__pattern_merge_2(__parallel_tag<_IsVector>, _ExecutionPolicy&& __exec, _It1 __it_1, _Index1 __n_1, _It2 __it_2,
3041+
_Index2 __n_2, _OutIt __it_out, _Index3 __n_out, _Comp __comp)
3042+
{
3043+
using __backend_tag = typename __parallel_tag<_IsVector>::__backend_tag;
3044+
3045+
_It1 __it_res_1;
3046+
_It2 __it_res_2;
3047+
3048+
__internal::__except_handler([&]() {
3049+
__par_backend::__parallel_for(__backend_tag{}, std::forward<_ExecutionPolicy>(__exec), _Index3(0), __n_out,
3050+
[=, &__it_res_1, &__it_res_2](_Index3 __i, _Index3 __j)
3051+
{
3052+
//a start merging point on the merge path; for each thread
3053+
_Index1 __r = 0; //row index
3054+
_Index2 __c = 0; //column index
3055+
3056+
if(__i > 0)
3057+
{
3058+
//calc merge path intersection:
3059+
const _Index3 __d_size =
3060+
std::abs(std::max<_Index2>(0, __i - __n_2) - (std::min<_Index1>(__i, __n_1) - 1)) + 1;
3061+
3062+
auto __get_row = [__i, __n_1](auto __d)
3063+
{ return std::min<_Index1>(__i, __n_1) - __d - 1; };
3064+
auto __get_column = [__i, __n_1](auto __d)
3065+
{ return std::max<_Index1>(0, __i - __n_1 - 1) + __d + (__i / (__n_1 + 1) > 0 ? 1 : 0); };
3066+
3067+
oneapi::dpl::counting_iterator<_Index3> __it_d(0);
3068+
3069+
auto __res_d = *std::lower_bound(__it_d, __it_d + __d_size, 1,
3070+
[&](auto __d, auto __val) {
3071+
auto __r = __get_row(__d);
3072+
auto __c = __get_column(__d);
3073+
3074+
oneapi::dpl::__internal::__compare<_Comp, oneapi::dpl::identity>
3075+
__cmp{__comp, oneapi::dpl::identity{}};
3076+
const auto __res = (__cmp(__it_1[__r], __it_2[__c]) ? 1 : 0);
3077+
3078+
return __res < __val;
3079+
}
3080+
);
3081+
3082+
//intersection point
3083+
__r = __get_row(__res_d);
3084+
__c = __get_column(__res_d);
3085+
++__r; //to get a merge matrix ceil, lying on the current diagonal
3086+
}
3087+
3088+
//serial merge n elements, starting from input x and y, to [i, j) output range
3089+
auto __res = __brick_merge_2(__it_1 + __r, __it_1 + __n_1,
3090+
__it_2 + __c, __it_2 + __n_2,
3091+
__it_out + __i, __it_out + __j, __comp, _IsVector{});
3092+
3093+
if(__j == __n_out)
3094+
{
3095+
__it_res_1 = __res.first;
3096+
__it_res_2 = __res.second;
3097+
}
3098+
}, _ONEDPL_MERGE_CUT_OFF); //grainsize
3099+
});
3100+
3101+
return {__it_res_1, __it_res_2};
3102+
}
3103+
29833104
template <class _IsVector, class _ExecutionPolicy, class _RandomAccessIterator1, class _RandomAccessIterator2,
29843105
class _RandomAccessIterator3, class _Compare>
29853106
_RandomAccessIterator3
2986-
__pattern_merge(__parallel_tag<_IsVector>, _ExecutionPolicy&& __exec, _RandomAccessIterator1 __first1,
3107+
__pattern_merge(__parallel_tag<_IsVector> __tag, _ExecutionPolicy&& __exec, _RandomAccessIterator1 __first1,
29873108
_RandomAccessIterator1 __last1, _RandomAccessIterator2 __first2, _RandomAccessIterator2 __last2,
29883109
_RandomAccessIterator3 __d_first, _Compare __comp)
29893110
{

include/oneapi/dpl/pstl/algorithm_ranges_impl.h

+18-18
Original file line numberDiff line numberDiff line change
@@ -448,31 +448,31 @@ auto
448448
__pattern_merge(_Tag __tag, _ExecutionPolicy&& __exec, _R1&& __r1, _R2&& __r2, _OutRange&& __out_r, _Comp __comp,
449449
_Proj1 __proj1, _Proj2 __proj2)
450450
{
451-
static_assert(__is_parallel_tag_v<_Tag> || typename _Tag::__is_vector{});
452-
assert(std::ranges::size(__r1) + std::ranges::size(__r2) <= std::ranges::size(__out_r)); // for debug purposes only
453-
451+
using __return_type = std::ranges::merge_result<std::ranges::borrowed_iterator_t<_R1>, std::ranges::borrowed_iterator_t<_R2>,
452+
std::ranges::borrowed_iterator_t<_OutRange>>;
453+
454454
auto __comp_2 = [__comp, __proj1, __proj2](auto&& __val1, auto&& __val2) { return std::invoke(__comp,
455455
std::invoke(__proj1, std::forward<decltype(__val1)>(__val1)), std::invoke(__proj2,
456456
std::forward<decltype(__val2)>(__val2)));};
457457

458-
auto __res = oneapi::dpl::__internal::__pattern_merge(__tag, std::forward<_ExecutionPolicy>(__exec),
459-
std::ranges::begin(__r1), std::ranges::begin(__r1) + std::ranges::size(__r1), std::ranges::begin(__r2),
460-
std::ranges::begin(__r2) + std::ranges::size(__r2), std::ranges::begin(__out_r), __comp_2);
458+
using _Index1 = std::ranges::range_difference_t<_R1>;
459+
using _Index2 = std::ranges::range_difference_t<_R2>;
460+
using _Index3 = std::ranges::range_difference_t<_OutRange>;
461461

462-
using __return_type = std::ranges::merge_result<std::ranges::borrowed_iterator_t<_R1>, std::ranges::borrowed_iterator_t<_R2>,
463-
std::ranges::borrowed_iterator_t<_OutRange>>;
462+
_Index1 __n_1 = std::ranges::size(__r1);
463+
_Index2 __n_2 = std::ranges::size(__r2);
464+
_Index3 __n_out = std::min<_Index3>(__n_1 + __n_2, std::ranges::size(__out_r));
464465

465-
return __return_type{std::ranges::begin(__r1) + std::ranges::size(__r1), std::ranges::begin(__r2) + std::ranges::size(__r2), __res};
466-
}
466+
auto __it_1 = std::ranges::begin(__r1);
467+
auto __it_2 = std::ranges::begin(__r2);
468+
auto __it_out = std::ranges::begin(__out_r);
467469

468-
template<typename _ExecutionPolicy, typename _R1, typename _R2, typename _OutRange, typename _Comp,
469-
typename _Proj1, typename _Proj2>
470-
auto
471-
__pattern_merge(__serial_tag</*IsVector*/std::false_type>, _ExecutionPolicy&& __exec, _R1&& __r1, _R2&& __r2, _OutRange&& __out_r, _Comp __comp,
472-
_Proj1 __proj1, _Proj2 __proj2)
473-
{
474-
return std::ranges::merge(std::forward<_R1>(__r1), std::forward<_R2>(__r2), std::ranges::begin(__out_r), __comp, __proj1,
475-
__proj2);
470+
if(__n_out == 0)
471+
return __return_type{__it_1, __it_2, __it_out};
472+
473+
auto __res = __pattern_merge_2(__tag, std::forward<_ExecutionPolicy>(__exec), __it_2, __n_2, __it_1, __n_1, __it_out, __n_out, __comp_2);
474+
475+
return __return_type{__res.second, __res.first, __it_out + __n_out};
476476
}
477477

478478
} // namespace __ranges

include/oneapi/dpl/pstl/glue_algorithm_ranges_impl.h

+5-2
Original file line numberDiff line numberDiff line change
@@ -1173,9 +1173,12 @@ merge(_ExecutionPolicy&& __exec, _Range1&& __rng1, _Range2&& __rng2, _Range3&& _
11731173
{
11741174
const auto __dispatch_tag = oneapi::dpl::__ranges::__select_backend(__exec, __rng1, __rng2, __rng3);
11751175

1176-
return oneapi::dpl::__internal::__ranges::__pattern_merge(
1176+
auto __view_res = views::all_write(::std::forward<_Range3>(__rng3));
1177+
oneapi::dpl::__internal::__ranges::__pattern_merge(
11771178
__dispatch_tag, ::std::forward<_ExecutionPolicy>(__exec), views::all_read(::std::forward<_Range1>(__rng1)),
1178-
views::all_read(::std::forward<_Range2>(__rng2)), views::all_write(::std::forward<_Range3>(__rng3)), __comp);
1179+
views::all_read(::std::forward<_Range2>(__rng2)), __view_res, __comp);
1180+
1181+
return __view_res.size();
11791182
}
11801183

11811184
template <typename _ExecutionPolicy, typename _Range1, typename _Range2, typename _Range3>

include/oneapi/dpl/pstl/hetero/algorithm_ranges_impl_hetero.h

+29-21
Original file line numberDiff line numberDiff line change
@@ -51,17 +51,19 @@ namespace __ranges
5151
//------------------------------------------------------------------------
5252

5353
template <typename _BackendTag, typename _ExecutionPolicy, typename _Function, typename... _Ranges>
54-
void
54+
auto
5555
__pattern_walk_n(__hetero_tag<_BackendTag>, _ExecutionPolicy&& __exec, _Function __f, _Ranges&&... __rngs)
5656
{
57-
auto __n = oneapi::dpl::__ranges::__get_first_range_size(__rngs...);
57+
using _Size = std::make_unsigned_t<std::common_type_t<oneapi::dpl::__internal::__difference_t<_Ranges>...>>;
58+
auto __n = std::min({_Size(__rngs.size())...});
5859
if (__n > 0)
5960
{
6061
oneapi::dpl::__par_backend_hetero::__parallel_for(_BackendTag{}, ::std::forward<_ExecutionPolicy>(__exec),
6162
unseq_backend::walk_n<_ExecutionPolicy, _Function>{__f}, __n,
6263
::std::forward<_Ranges>(__rngs)...)
6364
.__deferrable_wait();
6465
}
66+
return __n;
6567
}
6668

6769
#if _ONEDPL_CPP20_RANGES_PRESENT
@@ -680,44 +682,44 @@ struct __copy2_wrapper;
680682

681683
template <typename _BackendTag, typename _ExecutionPolicy, typename _Range1, typename _Range2, typename _Range3,
682684
typename _Compare>
683-
oneapi::dpl::__internal::__difference_t<_Range3>
685+
std::pair<oneapi::dpl::__internal::__difference_t<_Range1>, oneapi::dpl::__internal::__difference_t<_Range2>>
684686
__pattern_merge(__hetero_tag<_BackendTag> __tag, _ExecutionPolicy&& __exec, _Range1&& __rng1, _Range2&& __rng2,
685687
_Range3&& __rng3, _Compare __comp)
686688
{
687689
auto __n1 = __rng1.size();
688690
auto __n2 = __rng2.size();
689-
auto __n = __n1 + __n2;
690-
if (__n == 0)
691-
return 0;
691+
if (__rng3.size() == 0)
692+
return {0, 0};
692693

693694
//To consider the direct copying pattern call in case just one of sequences is empty.
694695
if (__n1 == 0)
695696
{
696-
oneapi::dpl::__internal::__ranges::__pattern_walk_n(
697+
auto __res = oneapi::dpl::__internal::__ranges::__pattern_walk_n(
697698
__tag,
698699
oneapi::dpl::__par_backend_hetero::make_wrapped_policy<__copy1_wrapper>(
699700
::std::forward<_ExecutionPolicy>(__exec)),
700701
oneapi::dpl::__internal::__brick_copy<__hetero_tag<_BackendTag>, _ExecutionPolicy>{},
701702
::std::forward<_Range2>(__rng2), ::std::forward<_Range3>(__rng3));
703+
return {0, __res};
702704
}
703-
else if (__n2 == 0)
705+
706+
if (__n2 == 0)
704707
{
705-
oneapi::dpl::__internal::__ranges::__pattern_walk_n(
708+
auto __res = oneapi::dpl::__internal::__ranges::__pattern_walk_n(
706709
__tag,
707710
oneapi::dpl::__par_backend_hetero::make_wrapped_policy<__copy2_wrapper>(
708711
::std::forward<_ExecutionPolicy>(__exec)),
709712
oneapi::dpl::__internal::__brick_copy<__hetero_tag<_BackendTag>, _ExecutionPolicy>{},
710713
::std::forward<_Range1>(__rng1), ::std::forward<_Range3>(__rng3));
711-
}
712-
else
713-
{
714-
__par_backend_hetero::__parallel_merge(_BackendTag{}, ::std::forward<_ExecutionPolicy>(__exec),
715-
::std::forward<_Range1>(__rng1), ::std::forward<_Range2>(__rng2),
716-
::std::forward<_Range3>(__rng3), __comp)
717-
.__deferrable_wait();
714+
return {__res, 0};
718715
}
719716

720-
return __n;
717+
auto __res = __par_backend_hetero::__parallel_merge(_BackendTag{}, ::std::forward<_ExecutionPolicy>(__exec),
718+
::std::forward<_Range1>(__rng1), ::std::forward<_Range2>(__rng2),
719+
::std::forward<_Range3>(__rng3), __comp);
720+
721+
auto __val = __res.get();
722+
return {__val.first, __val.second};
721723
}
722724

723725
#if _ONEDPL_CPP20_RANGES_PRESENT
@@ -727,21 +729,27 @@ auto
727729
__pattern_merge(__hetero_tag<_BackendTag> __tag, _ExecutionPolicy&& __exec, _R1&& __r1, _R2&& __r2, _OutRange&& __out_r,
728730
_Comp __comp, _Proj1 __proj1, _Proj2 __proj2)
729731
{
730-
assert(std::ranges::size(__r1) + std::ranges::size(__r2) <= std::ranges::size(__out_r)); // for debug purposes only
731-
732732
auto __comp_2 = [__comp, __proj1, __proj2](auto&& __val1, auto&& __val2) { return std::invoke(__comp,
733733
std::invoke(__proj1, std::forward<decltype(__val1)>(__val1)),
734734
std::invoke(__proj2, std::forward<decltype(__val2)>(__val2)));};
735735

736+
using _Index1 = std::ranges::range_difference_t<_R1>;
737+
using _Index2 = std::ranges::range_difference_t<_R2>;
738+
using _Index3 = std::ranges::range_difference_t<_OutRange>;
739+
740+
_Index1 __n_1 = std::ranges::size(__r1);
741+
_Index2 __n_2 = std::ranges::size(__r2);
742+
_Index3 __n_out = std::min<_Index3>(__n_1 + __n_2, std::ranges::size(__out_r));
743+
736744
auto __res = oneapi::dpl::__internal::__ranges::__pattern_merge(__tag, std::forward<_ExecutionPolicy>(__exec),
737745
oneapi::dpl::__ranges::views::all_read(__r1), oneapi::dpl::__ranges::views::all_read(__r2),
738746
oneapi::dpl::__ranges::views::all_write(__out_r), __comp_2);
739747

740748
using __return_t = std::ranges::merge_result<std::ranges::borrowed_iterator_t<_R1>, std::ranges::borrowed_iterator_t<_R2>,
741749
std::ranges::borrowed_iterator_t<_OutRange>>;
742750

743-
return __return_t{std::ranges::begin(__r1) + std::ranges::size(__r1), std::ranges::begin(__r2) +
744-
std::ranges::size(__r2), std::ranges::begin(__out_r) + __res};
751+
return __return_t{std::ranges::begin(__r1) + __res.first, std::ranges::begin(__r2) + __res.second,
752+
std::ranges::begin(__out_r) + __n_out};
745753
}
746754
#endif //_ONEDPL_CPP20_RANGES_PRESENT
747755

0 commit comments

Comments
 (0)