Skip to content

Commit 4cbf18c

Browse files
authored
[oneDPL][ranges] support size limit for output for merge algorithm (#1942)
1 parent 528c0f0 commit 4cbf18c

11 files changed

+403
-127
lines changed

include/oneapi/dpl/pstl/algorithm_impl.h

+115
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
#include <type_traits>
2121
#include <functional>
2222
#include <algorithm>
23+
#include <cassert>
24+
#include <cmath>
2325

2426
#include "algorithm_fwd.h"
2527

@@ -31,6 +33,7 @@
3133
#include "parallel_backend.h"
3234
#include "parallel_impl.h"
3335
#include "iterator_impl.h"
36+
#include "../functional" //for oneapi::dpl::identity
3437

3538
#if _ONEDPL_HETERO_BACKEND
3639
# include "hetero/algorithm_impl_hetero.h" // for __pattern_fill_n, __pattern_generate_n
@@ -2947,6 +2950,40 @@ __pattern_remove_if(__parallel_tag<_IsVector> __tag, _ExecutionPolicy&& __exec,
29472950
//------------------------------------------------------------------------
29482951
// merge
29492952
//------------------------------------------------------------------------
2953+
// Serial version of ___merge_path_out_lim merges the 1st sequence and the 2nd sequence in "reverse order":
2954+
// the identical elements from the 2nd sequence are merged first.
2955+
template <typename _Iterator1, typename _Iterator2, typename _Iterator3, typename _Comp>
2956+
std::pair<_Iterator1, _Iterator2>
2957+
__serial_merge_out_lim(_Iterator1 __x, _Iterator1 __x_e, _Iterator2 __y, _Iterator2 __y_e, _Iterator3 __out,
2958+
_Iterator3 __out_e, _Comp __comp)
2959+
{
2960+
for (_Iterator3 __k = __out; __k != __out_e; ++__k)
2961+
{
2962+
if (__x == __x_e)
2963+
{
2964+
assert(__y != __y_e);
2965+
*__k = *__y;
2966+
++__y;
2967+
}
2968+
else if (__y == __y_e)
2969+
{
2970+
assert(__x != __x_e);
2971+
*__k = *__x;
2972+
++__x;
2973+
}
2974+
else if (std::invoke(__comp, *__x, *__y))
2975+
{
2976+
*__k = *__x;
2977+
++__x;
2978+
}
2979+
else
2980+
{
2981+
*__k = *__y;
2982+
++__y;
2983+
}
2984+
}
2985+
return {__x, __y};
2986+
}
29502987

29512988
template <class _ForwardIterator1, class _ForwardIterator2, class _OutputIterator, class _Compare>
29522989
_OutputIterator
@@ -2980,6 +3017,84 @@ __pattern_merge(_Tag, _ExecutionPolicy&&, _ForwardIterator1 __first1, _ForwardIt
29803017
typename _Tag::__is_vector{});
29813018
}
29823019

3020+
template <typename _Tag, typename _ExecutionPolicy, typename _It1, typename _Index1, typename _It2, typename _Index2,
3021+
typename _OutIt, typename _Index3, typename _Comp>
3022+
std::pair<_It1, _It2>
3023+
___merge_path_out_lim(_Tag, _ExecutionPolicy&& __exec, _It1 __it_1, _Index1 __n_1, _It2 __it_2, _Index2 __n_2,
3024+
_OutIt __it_out, _Index3 __n_out, _Comp __comp)
3025+
{
3026+
return __serial_merge_out_lim(__it_1, __it_1 + __n_1, __it_2, __it_2 + __n_2, __it_out, __it_out + __n_out, __comp);
3027+
}
3028+
3029+
inline constexpr std::size_t __merge_path_cut_off = 2000;
3030+
3031+
// Parallel version of ___merge_path_out_lim merges the 1st sequence and the 2nd sequence in "reverse order":
3032+
// the identical elements from the 2nd sequence are merged first.
3033+
template <typename _IsVector, typename _ExecutionPolicy, typename _It1, typename _Index1, typename _It2,
3034+
typename _Index2, typename _OutIt, typename _Index3, typename _Comp>
3035+
std::pair<_It1, _It2>
3036+
___merge_path_out_lim(__parallel_tag<_IsVector>, _ExecutionPolicy&& __exec, _It1 __it_1, _Index1 __n_1, _It2 __it_2,
3037+
_Index2 __n_2, _OutIt __it_out, _Index3 __n_out, _Comp __comp)
3038+
{
3039+
using __backend_tag = typename __parallel_tag<_IsVector>::__backend_tag;
3040+
3041+
_It1 __it_res_1;
3042+
_It2 __it_res_2;
3043+
3044+
__internal::__except_handler([&]() {
3045+
__par_backend::__parallel_for(
3046+
__backend_tag{}, std::forward<_ExecutionPolicy>(__exec), _Index3(0), __n_out,
3047+
[=, &__it_res_1, &__it_res_2](_Index3 __i, _Index3 __j) {
3048+
//a start merging point on the merge path; for each thread
3049+
_Index1 __r = 0; //row index
3050+
_Index2 __c = 0; //column index
3051+
3052+
if (__i > 0)
3053+
{
3054+
//calc merge path intersection:
3055+
const _Index3 __d_size =
3056+
std::abs(std::max<_Index2>(0, __i - __n_2) - (std::min<_Index1>(__i, __n_1) - 1)) + 1;
3057+
3058+
auto __get_row = [__i, __n_1](auto __d) { return std::min<_Index1>(__i, __n_1) - __d - 1; };
3059+
auto __get_column = [__i, __n_1](auto __d) {
3060+
return std::max<_Index1>(0, __i - __n_1 - 1) + __d + (__i / (__n_1 + 1) > 0 ? 1 : 0);
3061+
};
3062+
3063+
oneapi::dpl::counting_iterator<_Index3> __it_d(0);
3064+
3065+
auto __res_d = *std::lower_bound(__it_d, __it_d + __d_size, 1, [&](auto __d, auto __val) {
3066+
auto __r = __get_row(__d);
3067+
auto __c = __get_column(__d);
3068+
3069+
oneapi::dpl::__internal::__compare<_Comp, oneapi::dpl::identity> __cmp{__comp,
3070+
oneapi::dpl::identity{}};
3071+
const auto __res = __cmp(__it_1[__r], __it_2[__c]) ? 1 : 0;
3072+
3073+
return __res < __val;
3074+
});
3075+
3076+
//intersection point
3077+
__r = __get_row(__res_d);
3078+
__c = __get_column(__res_d);
3079+
++__r; //to get a merge matrix ceil, lying on the current diagonal
3080+
}
3081+
3082+
//serial merge n elements, starting from input x and y, to [i, j) output range
3083+
const std::pair __res = __serial_merge_out_lim(__it_1 + __r, __it_1 + __n_1, __it_2 + __c,
3084+
__it_2 + __n_2, __it_out + __i, __it_out + __j, __comp);
3085+
3086+
if (__j == __n_out)
3087+
{
3088+
__it_res_1 = __res.first;
3089+
__it_res_2 = __res.second;
3090+
}
3091+
},
3092+
__merge_path_cut_off); //grainsize
3093+
});
3094+
3095+
return {__it_res_1, __it_res_2};
3096+
}
3097+
29833098
template <class _IsVector, class _ExecutionPolicy, class _RandomAccessIterator1, class _RandomAccessIterator2,
29843099
class _RandomAccessIterator3, class _Compare>
29853100
_RandomAccessIterator3

include/oneapi/dpl/pstl/algorithm_ranges_impl.h

+22-17
Original file line numberDiff line numberDiff line change
@@ -448,31 +448,36 @@ auto
448448
__pattern_merge(_Tag __tag, _ExecutionPolicy&& __exec, _R1&& __r1, _R2&& __r2, _OutRange&& __out_r, _Comp __comp,
449449
_Proj1 __proj1, _Proj2 __proj2)
450450
{
451-
static_assert(__is_parallel_tag_v<_Tag> || typename _Tag::__is_vector{});
452-
assert(std::ranges::size(__r1) + std::ranges::size(__r2) <= std::ranges::size(__out_r)); // for debug purposes only
451+
using __return_type =
452+
std::ranges::merge_result<std::ranges::borrowed_iterator_t<_R1>, std::ranges::borrowed_iterator_t<_R2>,
453+
std::ranges::borrowed_iterator_t<_OutRange>>;
453454

454455
auto __comp_2 = [__comp, __proj1, __proj2](auto&& __val1, auto&& __val2) { return std::invoke(__comp,
455456
std::invoke(__proj1, std::forward<decltype(__val1)>(__val1)), std::invoke(__proj2,
456457
std::forward<decltype(__val2)>(__val2)));};
457458

458-
auto __res = oneapi::dpl::__internal::__pattern_merge(__tag, std::forward<_ExecutionPolicy>(__exec),
459-
std::ranges::begin(__r1), std::ranges::begin(__r1) + std::ranges::size(__r1), std::ranges::begin(__r2),
460-
std::ranges::begin(__r2) + std::ranges::size(__r2), std::ranges::begin(__out_r), __comp_2);
459+
using _Index1 = std::ranges::range_difference_t<_R1>;
460+
using _Index2 = std::ranges::range_difference_t<_R2>;
461+
using _Index3 = std::ranges::range_difference_t<_OutRange>;
461462

462-
using __return_type = std::ranges::merge_result<std::ranges::borrowed_iterator_t<_R1>, std::ranges::borrowed_iterator_t<_R2>,
463-
std::ranges::borrowed_iterator_t<_OutRange>>;
463+
const _Index1 __n_1 = std::ranges::size(__r1);
464+
const _Index2 __n_2 = std::ranges::size(__r2);
465+
const _Index3 __n_out = std::min<_Index3>(__n_1 + __n_2, std::ranges::size(__out_r));
464466

465-
return __return_type{std::ranges::begin(__r1) + std::ranges::size(__r1), std::ranges::begin(__r2) + std::ranges::size(__r2), __res};
466-
}
467+
auto __it_1 = std::ranges::begin(__r1);
468+
auto __it_2 = std::ranges::begin(__r2);
469+
auto __it_out = std::ranges::begin(__out_r);
467470

468-
template<typename _ExecutionPolicy, typename _R1, typename _R2, typename _OutRange, typename _Comp,
469-
typename _Proj1, typename _Proj2>
470-
auto
471-
__pattern_merge(__serial_tag</*IsVector*/std::false_type>, _ExecutionPolicy&& __exec, _R1&& __r1, _R2&& __r2, _OutRange&& __out_r, _Comp __comp,
472-
_Proj1 __proj1, _Proj2 __proj2)
473-
{
474-
return std::ranges::merge(std::forward<_R1>(__r1), std::forward<_R2>(__r2), std::ranges::begin(__out_r), __comp, __proj1,
475-
__proj2);
471+
if (__n_out == 0)
472+
return __return_type{__it_1, __it_2, __it_out};
473+
474+
// Parallel and serial versions of ___merge_path_out_lim merge the 1st sequence and the 2nd sequence in "reverse order":
475+
// the identical elements from the 2nd sequence are merged first.
476+
// So, the call to ___merge_path_out_lim swaps the order of sequences.
477+
std::pair __res = ___merge_path_out_lim(__tag, std::forward<_ExecutionPolicy>(__exec), __it_2, __n_2, __it_1, __n_1,
478+
__it_out, __n_out, __comp_2);
479+
480+
return __return_type{__res.second, __res.first, __it_out + __n_out};
476481
}
477482

478483
} // namespace __ranges

include/oneapi/dpl/pstl/glue_algorithm_ranges_impl.h

+5-2
Original file line numberDiff line numberDiff line change
@@ -1173,9 +1173,12 @@ merge(_ExecutionPolicy&& __exec, _Range1&& __rng1, _Range2&& __rng2, _Range3&& _
11731173
{
11741174
const auto __dispatch_tag = oneapi::dpl::__ranges::__select_backend(__exec, __rng1, __rng2, __rng3);
11751175

1176-
return oneapi::dpl::__internal::__ranges::__pattern_merge(
1176+
auto __view_res = views::all_write(::std::forward<_Range3>(__rng3));
1177+
oneapi::dpl::__internal::__ranges::__pattern_merge(
11771178
__dispatch_tag, ::std::forward<_ExecutionPolicy>(__exec), views::all_read(::std::forward<_Range1>(__rng1)),
1178-
views::all_read(::std::forward<_Range2>(__rng2)), views::all_write(::std::forward<_Range3>(__rng3)), __comp);
1179+
views::all_read(::std::forward<_Range2>(__rng2)), __view_res, __comp);
1180+
1181+
return __view_res.size();
11791182
}
11801183

11811184
template <typename _ExecutionPolicy, typename _Range1, typename _Range2, typename _Range3>

include/oneapi/dpl/pstl/hetero/algorithm_ranges_impl_hetero.h

+39-26
Original file line numberDiff line numberDiff line change
@@ -51,17 +51,19 @@ namespace __ranges
5151
//------------------------------------------------------------------------
5252

5353
template <typename _BackendTag, typename _ExecutionPolicy, typename _Function, typename... _Ranges>
54-
void
54+
auto /* see _Size inside the function */
5555
__pattern_walk_n(__hetero_tag<_BackendTag>, _ExecutionPolicy&& __exec, _Function __f, _Ranges&&... __rngs)
5656
{
57-
auto __n = oneapi::dpl::__ranges::__get_first_range_size(__rngs...);
57+
using _Size = std::make_unsigned_t<std::common_type_t<oneapi::dpl::__internal::__difference_t<_Ranges>...>>;
58+
const _Size __n = std::min({_Size(__rngs.size())...});
5859
if (__n > 0)
5960
{
6061
oneapi::dpl::__par_backend_hetero::__parallel_for(_BackendTag{}, ::std::forward<_ExecutionPolicy>(__exec),
6162
unseq_backend::walk_n<_ExecutionPolicy, _Function>{__f}, __n,
6263
::std::forward<_Ranges>(__rngs)...)
6364
.__deferrable_wait();
6465
}
66+
return __n;
6567
}
6668

6769
#if _ONEDPL_CPP20_RANGES_PRESENT
@@ -678,46 +680,51 @@ struct __copy1_wrapper;
678680
template <typename _Name>
679681
struct __copy2_wrapper;
680682

683+
struct __out_size_limit : public std::true_type
684+
{
685+
};
686+
681687
template <typename _BackendTag, typename _ExecutionPolicy, typename _Range1, typename _Range2, typename _Range3,
682688
typename _Compare>
683-
oneapi::dpl::__internal::__difference_t<_Range3>
689+
std::pair<oneapi::dpl::__internal::__difference_t<_Range1>, oneapi::dpl::__internal::__difference_t<_Range2>>
684690
__pattern_merge(__hetero_tag<_BackendTag> __tag, _ExecutionPolicy&& __exec, _Range1&& __rng1, _Range2&& __rng2,
685691
_Range3&& __rng3, _Compare __comp)
686692
{
687-
auto __n1 = __rng1.size();
688-
auto __n2 = __rng2.size();
689-
auto __n = __n1 + __n2;
690-
if (__n == 0)
691-
return 0;
693+
if (__rng3.empty())
694+
return {0, 0};
695+
696+
const auto __n1 = __rng1.size();
697+
const auto __n2 = __rng2.size();
692698

693699
//To consider the direct copying pattern call in case just one of sequences is empty.
694700
if (__n1 == 0)
695701
{
696-
oneapi::dpl::__internal::__ranges::__pattern_walk_n(
702+
auto __res = oneapi::dpl::__internal::__ranges::__pattern_walk_n(
697703
__tag,
698704
oneapi::dpl::__par_backend_hetero::make_wrapped_policy<__copy1_wrapper>(
699705
::std::forward<_ExecutionPolicy>(__exec)),
700706
oneapi::dpl::__internal::__brick_copy<__hetero_tag<_BackendTag>, _ExecutionPolicy>{},
701707
::std::forward<_Range2>(__rng2), ::std::forward<_Range3>(__rng3));
708+
return {0, __res};
702709
}
703-
else if (__n2 == 0)
710+
711+
if (__n2 == 0)
704712
{
705-
oneapi::dpl::__internal::__ranges::__pattern_walk_n(
713+
auto __res = oneapi::dpl::__internal::__ranges::__pattern_walk_n(
706714
__tag,
707715
oneapi::dpl::__par_backend_hetero::make_wrapped_policy<__copy2_wrapper>(
708716
::std::forward<_ExecutionPolicy>(__exec)),
709717
oneapi::dpl::__internal::__brick_copy<__hetero_tag<_BackendTag>, _ExecutionPolicy>{},
710718
::std::forward<_Range1>(__rng1), ::std::forward<_Range3>(__rng3));
711-
}
712-
else
713-
{
714-
__par_backend_hetero::__parallel_merge(_BackendTag{}, ::std::forward<_ExecutionPolicy>(__exec),
715-
::std::forward<_Range1>(__rng1), ::std::forward<_Range2>(__rng2),
716-
::std::forward<_Range3>(__rng3), __comp)
717-
.__deferrable_wait();
719+
return {__res, 0};
718720
}
719721

720-
return __n;
722+
auto __res = __par_backend_hetero::__parallel_merge(
723+
_BackendTag{}, ::std::forward<_ExecutionPolicy>(__exec), ::std::forward<_Range1>(__rng1),
724+
::std::forward<_Range2>(__rng2), ::std::forward<_Range3>(__rng3), __comp, __out_size_limit{});
725+
726+
auto __val = __res.get();
727+
return {__val.first, __val.second};
721728
}
722729

723730
#if _ONEDPL_CPP20_RANGES_PRESENT
@@ -727,21 +734,27 @@ auto
727734
__pattern_merge(__hetero_tag<_BackendTag> __tag, _ExecutionPolicy&& __exec, _R1&& __r1, _R2&& __r2, _OutRange&& __out_r,
728735
_Comp __comp, _Proj1 __proj1, _Proj2 __proj2)
729736
{
730-
assert(std::ranges::size(__r1) + std::ranges::size(__r2) <= std::ranges::size(__out_r)); // for debug purposes only
731-
732737
auto __comp_2 = [__comp, __proj1, __proj2](auto&& __val1, auto&& __val2) { return std::invoke(__comp,
733738
std::invoke(__proj1, std::forward<decltype(__val1)>(__val1)),
734739
std::invoke(__proj2, std::forward<decltype(__val2)>(__val2)));};
735740

736-
auto __res = oneapi::dpl::__internal::__ranges::__pattern_merge(__tag, std::forward<_ExecutionPolicy>(__exec),
737-
oneapi::dpl::__ranges::views::all_read(__r1), oneapi::dpl::__ranges::views::all_read(__r2),
738-
oneapi::dpl::__ranges::views::all_write(__out_r), __comp_2);
741+
using _Index1 = std::ranges::range_difference_t<_R1>;
742+
using _Index2 = std::ranges::range_difference_t<_R2>;
743+
using _Index3 = std::ranges::range_difference_t<_OutRange>;
744+
745+
const _Index1 __n_1 = std::ranges::size(__r1);
746+
const _Index2 __n_2 = std::ranges::size(__r2);
747+
const _Index3 __n_out = std::min<_Index3>(__n_1 + __n_2, std::ranges::size(__out_r));
748+
749+
const std::pair __res = oneapi::dpl::__internal::__ranges::__pattern_merge(
750+
__tag, std::forward<_ExecutionPolicy>(__exec), oneapi::dpl::__ranges::views::all_read(__r1),
751+
oneapi::dpl::__ranges::views::all_read(__r2), oneapi::dpl::__ranges::views::all_write(__out_r), __comp_2);
739752

740753
using __return_t = std::ranges::merge_result<std::ranges::borrowed_iterator_t<_R1>, std::ranges::borrowed_iterator_t<_R2>,
741754
std::ranges::borrowed_iterator_t<_OutRange>>;
742755

743-
return __return_t{std::ranges::begin(__r1) + std::ranges::size(__r1), std::ranges::begin(__r2) +
744-
std::ranges::size(__r2), std::ranges::begin(__out_r) + __res};
756+
return __return_t{std::ranges::begin(__r1) + __res.first, std::ranges::begin(__r2) + __res.second,
757+
std::ranges::begin(__out_r) + __n_out};
745758
}
746759
#endif //_ONEDPL_CPP20_RANGES_PRESENT
747760

0 commit comments

Comments
 (0)