@@ -51,17 +51,19 @@ namespace __ranges
51
51
// ------------------------------------------------------------------------
52
52
53
53
// NOTE: this span is a rendered diff hunk, not plain source. Lines starting with "- " belong to the
// removed (old) version, lines starting with "+ " to the added (new) version; the bare numbers appear
// to be the old/new line-number gutters of the diff view.
//
// __pattern_walk_n (hetero backend overload): when the element count __n is positive, launches
// __parallel_for with unseq_backend::walk_n<_ExecutionPolicy, _Function>{__f}, passing __n and the
// forwarded ranges, and then calls __deferrable_wait() on the object returned by __parallel_for.
//   __exec  - execution policy, perfectly forwarded to __parallel_for.
//   __f     - user functor, wrapped into unseq_backend::walk_n.
//   __rngs  - ranges, perfectly forwarded to __parallel_for.
// Old version: returns void; __n comes from __get_first_range_size(__rngs...).
// New version: returns __n (type _Size, defined in the body), which is zero when nothing was launched.
template <typename _BackendTag, typename _ExecutionPolicy, typename _Function, typename ... _Ranges>
54
- void
54
+ auto /* see _Size inside the function */
55
55
__pattern_walk_n (__hetero_tag<_BackendTag>, _ExecutionPolicy&& __exec, _Function __f, _Ranges&&... __rngs)
56
56
{
57
- auto __n = oneapi::dpl::__ranges::__get_first_range_size (__rngs...);
57
// New: _Size is the unsigned counterpart of the common type of every range's difference type.
+ using _Size = std::make_unsigned_t <std::common_type_t <oneapi::dpl::__internal::__difference_t <_Ranges>...>>;
58
// New: __n is the smallest size() among all passed ranges (each size is converted to _Size first),
// instead of the size of the first range only.
+ const _Size __n = std::min ({_Size (__rngs.size ())...});
58
59
// No kernel is submitted for an empty iteration space.
if (__n > 0 )
59
60
{
60
61
oneapi::dpl::__par_backend_hetero::__parallel_for (_BackendTag{}, ::std::forward<_ExecutionPolicy>(__exec),
61
62
unseq_backend::walk_n<_ExecutionPolicy, _Function>{__f}, __n,
62
63
::std::forward<_Ranges>(__rngs)...)
63
64
.__deferrable_wait ();
64
65
}
66
// New: hand the element count back to the caller.
+ return __n;
65
67
}
66
68
67
69
#if _ONEDPL_CPP20_RANGES_PRESENT
@@ -678,46 +680,51 @@ struct __copy1_wrapper;
678
680
// Forward declaration only (no definition is visible in this hunk). It is used further down as the
// template argument of make_wrapped_policy<__copy2_wrapper>(...) for the copy that __pattern_merge
// launches when __n2 == 0. The bare numbers between the lines appear to be diff line-number gutters.
template <typename _Name>
679
681
struct __copy2_wrapper ;
680
682
683
// Added by this diff ("+ " lines): an empty tag type derived from std::true_type. __pattern_merge
// passes a default-constructed instance as the last argument of __parallel_merge.
// NOTE(review): presumably it tells the backend to respect the size of the output range; the
// backend side is not visible here - confirm against __parallel_merge.
+ struct __out_size_limit : public std ::true_type
684
+ {
685
+ };
686
+
681
687
// NOTE: this span is a rendered diff hunk, not plain source. Lines starting with "- " belong to the
// removed (old) version, lines starting with "+ " to the added (new) version; the bare numbers appear
// to be the old/new line-number gutters of the diff view.
//
// __pattern_merge (hetero backend overload): merges __rng1 and __rng2 into __rng3 with comparator __comp.
//   __tag   - backend dispatch tag, passed on to __pattern_walk_n in the copy branches.
//   __exec  - execution policy, forwarded to exactly one of the launched patterns.
//   __rng1, __rng2 - input ranges; __rng3 - output range.
//   __comp  - comparator handed to __parallel_merge.
// Old version: returns __n1 + __n2 as __difference_t<_Range3> (0 when both inputs are empty) and waits
//              on __parallel_merge with __deferrable_wait().
// New version: returns std::pair<__difference_t<_Range1>, __difference_t<_Range2>>. In the copy
//              branches the pair holds 0 for the empty input and the count returned by
//              __pattern_walk_n for the other one; in the general case it is taken from the
//              value produced by __parallel_merge(...).get().
template <typename _BackendTag, typename _ExecutionPolicy, typename _Range1, typename _Range2, typename _Range3,
682
688
typename _Compare>
683
- oneapi::dpl::__internal::__difference_t <_Range3 >
689
+ std::pair< oneapi::dpl::__internal::__difference_t <_Range1>, oneapi::dpl::__internal:: __difference_t <_Range2> >
684
690
__pattern_merge (__hetero_tag<_BackendTag> __tag, _ExecutionPolicy&& __exec, _Range1&& __rng1, _Range2&& __rng2,
685
691
_Range3&& __rng3, _Compare __comp)
686
692
{
687
- auto __n1 = __rng1. size ();
688
- auto __n2 = __rng2. size () ;
689
- auto __n = __n1 + __n2;
690
- if (__n == 0 )
691
- return 0 ;
693
// New: the early exit is keyed on the OUTPUT range being empty (the old code checked whether both
// inputs were empty); it reports zero for both inputs.
+ if (__rng3. empty ())
694
+ return { 0 , 0 } ;
695
+
696
+ const auto __n1 = __rng1. size ();
697
+ const auto __n2 = __rng2. size () ;
692
698
693
699
// To consider the direct copying pattern call in case just one of sequences is empty.
694
700
// First input empty: copy __rng2 into __rng3 with __brick_copy through __pattern_walk_n, using the
// __copy1_wrapper-wrapped policy.
if (__n1 == 0 )
695
701
{
696
- oneapi::dpl::__internal::__ranges::__pattern_walk_n (
702
+ auto __res = oneapi::dpl::__internal::__ranges::__pattern_walk_n (
697
703
__tag,
698
704
oneapi::dpl::__par_backend_hetero::make_wrapped_policy<__copy1_wrapper>(
699
705
::std::forward<_ExecutionPolicy>(__exec)),
700
706
oneapi::dpl::__internal::__brick_copy<__hetero_tag<_BackendTag>, _ExecutionPolicy>{},
701
707
::std::forward<_Range2>(__rng2), ::std::forward<_Range3>(__rng3));
708
// New: first element is 0 (for __rng1); second is the count returned by __pattern_walk_n.
+ return {0 , __res};
702
709
}
703
- else if (__n2 == 0 )
710
+
711
// New: a plain `if` replaces `else if`; the branch above now always returns.
// Second input empty: copy __rng1 into __rng3, using the __copy2_wrapper-wrapped policy.
+ if (__n2 == 0 )
704
712
{
705
- oneapi::dpl::__internal::__ranges::__pattern_walk_n (
713
+ auto __res = oneapi::dpl::__internal::__ranges::__pattern_walk_n (
706
714
__tag,
707
715
oneapi::dpl::__par_backend_hetero::make_wrapped_policy<__copy2_wrapper>(
708
716
::std::forward<_ExecutionPolicy>(__exec)),
709
717
oneapi::dpl::__internal::__brick_copy<__hetero_tag<_BackendTag>, _ExecutionPolicy>{},
710
718
::std::forward<_Range1>(__rng1), ::std::forward<_Range3>(__rng3));
711
- }
712
- else
713
- {
714
- __par_backend_hetero::__parallel_merge (_BackendTag{}, ::std::forward<_ExecutionPolicy>(__exec),
715
- ::std::forward<_Range1>(__rng1), ::std::forward<_Range2>(__rng2),
716
- ::std::forward<_Range3>(__rng3), __comp)
717
- .__deferrable_wait ();
719
// New: first element is the count returned by __pattern_walk_n; second is 0 (for __rng2).
+ return {__res, 0 };
718
720
}
719
721
720
- return __n;
722
// New: general case (both inputs non-empty). __out_size_limit{} is passed as an extra trailing tag
// argument, and the result is read with get() instead of calling __deferrable_wait().
+ auto __res = __par_backend_hetero::__parallel_merge (
723
+ _BackendTag{}, ::std::forward<_ExecutionPolicy>(__exec), ::std::forward<_Range1>(__rng1),
724
+ ::std::forward<_Range2>(__rng2), ::std::forward<_Range3>(__rng3), __comp, __out_size_limit{});
725
+
726
+ auto __val = __res.get ();
727
// NOTE(review): presumably .first/.second are the numbers of elements merged from __rng1 and
// __rng2 respectively - confirm against the backend's __parallel_merge result type.
+ return {__val.first , __val.second };
721
728
}
722
729
723
730
#if _ONEDPL_CPP20_RANGES_PRESENT
@@ -727,21 +734,27 @@ auto
727
734
__pattern_merge (__hetero_tag<_BackendTag> __tag, _ExecutionPolicy&& __exec, _R1&& __r1, _R2&& __r2, _OutRange&& __out_r,
728
735
_Comp __comp, _Proj1 __proj1, _Proj2 __proj2)
729
736
{
730
- assert (std::ranges::size (__r1) + std::ranges::size (__r2) <= std::ranges::size (__out_r)); // for debug purposes only
731
-
732
737
auto __comp_2 = [__comp, __proj1, __proj2](auto && __val1, auto && __val2) { return std::invoke (__comp,
733
738
std::invoke (__proj1, std::forward<decltype (__val1)>(__val1)),
734
739
std::invoke (__proj2, std::forward<decltype (__val2)>(__val2)));};
735
740
736
- auto __res = oneapi::dpl::__internal::__ranges::__pattern_merge (__tag, std::forward<_ExecutionPolicy>(__exec),
737
- oneapi::dpl::__ranges::views::all_read (__r1), oneapi::dpl::__ranges::views::all_read (__r2),
738
- oneapi::dpl::__ranges::views::all_write (__out_r), __comp_2);
741
+ using _Index1 = std::ranges::range_difference_t <_R1>;
742
+ using _Index2 = std::ranges::range_difference_t <_R2>;
743
+ using _Index3 = std::ranges::range_difference_t <_OutRange>;
744
+
745
+ const _Index1 __n_1 = std::ranges::size (__r1);
746
+ const _Index2 __n_2 = std::ranges::size (__r2);
747
+ const _Index3 __n_out = std::min<_Index3>(__n_1 + __n_2, std::ranges::size (__out_r));
748
+
749
+ const std::pair __res = oneapi::dpl::__internal::__ranges::__pattern_merge (
750
+ __tag, std::forward<_ExecutionPolicy>(__exec), oneapi::dpl::__ranges::views::all_read (__r1),
751
+ oneapi::dpl::__ranges::views::all_read (__r2), oneapi::dpl::__ranges::views::all_write (__out_r), __comp_2);
739
752
740
753
using __return_t = std::ranges::merge_result<std::ranges::borrowed_iterator_t <_R1>, std::ranges::borrowed_iterator_t <_R2>,
741
754
std::ranges::borrowed_iterator_t <_OutRange>>;
742
755
743
- return __return_t {std::ranges::begin (__r1) + std::ranges::size (__r1) , std::ranges::begin (__r2) +
744
- std::ranges::size (__r2), std::ranges::begin (__out_r) + __res };
756
+ return __return_t {std::ranges::begin (__r1) + __res. first , std::ranges::begin (__r2) + __res. second ,
757
+ std::ranges::begin (__out_r) + __n_out };
745
758
}
746
759
#endif // _ONEDPL_CPP20_RANGES_PRESENT
747
760
0 commit comments