Skip to content

Commit d443dbe

Browse files
committed
[oneDPL][ranges][merge] support size limit for output; + draft for merge path for the host backend (__pattern_merge_2)
1 parent 85032a7 commit d443dbe

File tree

3 files changed

+141
-92
lines changed

3 files changed

+141
-92
lines changed

include/oneapi/dpl/pstl/algorithm_impl.h

+135-1
Original file line numberDiff line numberDiff line change
@@ -2948,6 +2948,40 @@ __pattern_remove_if(__parallel_tag<_IsVector> __tag, _ExecutionPolicy&& __exec,
29482948
// merge
29492949
//------------------------------------------------------------------------
29502950

2951+
template<std::random_access_iterator It1, std::random_access_iterator It2, std::random_access_iterator ItOut, typename _Comp>
2952+
std::pair<It1, It2>
2953+
__brick_merge(It1 __it_1, It1 __it_1_e, It2 __it_2, It2 __it_2_e, ItOut __it_out, ItOut __it_out_e, _Comp __comp)
2954+
{
2955+
while(__it_1 != __it_1_e && __it_2 != __it_2_e)
2956+
{
2957+
if (__comp(*__it_1, *__it_2))
2958+
{
2959+
*__it_out = *__it_1;
2960+
++__it_out, ++__it_1;
2961+
}
2962+
else
2963+
{
2964+
*__it_out = *__it_2;
2965+
++__it_out, ++__it_2;
2966+
}
2967+
if(__it_out == __it_out_e)
2968+
return {__it_1, __it_2};
2969+
}
2970+
2971+
if(__it_1 == __it_1_e)
2972+
{
2973+
for(; __it_2 != __it_2_e && __it_out != __it_out_e; ++__it_2, ++__it_out)
2974+
*__it_out = *__it_2;
2975+
}
2976+
else
2977+
{
2978+
//assert(__it_2 == __it_2_e);
2979+
for(; __it_1 != __it_1_e && __it_out != __it_out_e; ++__it_1, ++__it_out)
2980+
*__it_out = *__it_1;
2981+
}
2982+
return {__it_1, __it_2};
2983+
}
2984+
29512985
template <class _ForwardIterator1, class _ForwardIterator2, class _OutputIterator, class _Compare>
29522986
_OutputIterator
29532987
__brick_merge(_ForwardIterator1 __first1, _ForwardIterator1 __last1, _ForwardIterator2 __first2,
@@ -2980,13 +3014,106 @@ __pattern_merge(_Tag, _ExecutionPolicy&&, _ForwardIterator1 __first1, _ForwardIt
29803014
typename _Tag::__is_vector{});
29813015
}
29823016

3017+
// Binary search for the first position in [first, last) whose element does not satisfy
// comp(element, value).
//
// first, last - the range to search
// value       - the value passed as the second argument of comp
// comp        - binary predicate called as comp(*it, value)
//
// Returns the first iterator `it` for which comp(*it, value) is false, or `last` when there is no
// such element (including the case of an empty range).
//
// No diagnostic output is produced: printing *first after the search would dereference `last`
// whenever the search result is the end of the range.
template <class ForwardIt, class T = typename std::iterator_traits<ForwardIt>::value_type, class Compare>
ForwardIt
lower_bound_2(ForwardIt first, ForwardIt last, const T& value, Compare comp)
{
    typename std::iterator_traits<ForwardIt>::difference_type count = std::distance(first, last);

    while (count > 0)
    {
        const auto step = count / 2;
        ForwardIt it = std::next(first, step);

        if (comp(*it, value))
        {
            // The answer lies strictly to the right of the probed element.
            first = ++it;
            count -= step + 1;
        }
        else
        {
            // The probed element is a candidate; keep searching the left half.
            count = step;
        }
    }

    return first;
}
3045+
3046+
template<typename _IsVector, typename _ExecutionPolicy, typename _It1, typename _Index1, typename _It2,
3047+
typename _Index2, typename _OutIt, typename _Index3, typename _Comp>
3048+
std::pair<_It1, _It2>
3049+
__pattern_merge_2(__parallel_tag<_IsVector>, _ExecutionPolicy&& __exec, _It1 __it_1, _Index1 __n_1, _It2 __it_2,
3050+
_Index2 __n_2, _OutIt __it_out, _Index3 __n_out, _Comp __comp)
3051+
{
3052+
using __backend_tag = typename __parallel_tag<_IsVector>::__backend_tag;
3053+
3054+
_It1 __it_res_1;
3055+
_It2 __it_res_2;
3056+
3057+
__internal::__except_handler([&]() {
3058+
__par_backend::__parallel_for(__backend_tag{}, std::forward<_ExecutionPolicy>(__exec), _Index3(0), __n_out,
3059+
[=, &__it_res_1, &__it_res_2](_Index3 __i, _Index3 __j)
3060+
{
3061+
//a start merging point on the merge path; for each thread
3062+
_Index1 __r = 0; //row index
3063+
_Index2 __c = 0; //column index
3064+
3065+
if(__i > 0)
3066+
{
3067+
//calc merge path intersection:
3068+
const _Index3 __d_size = std::abs(std::max<_Index2>(0, __i - __n_2) - (std::min<_Index1>(__i, __n_1) - 1)) + 1;
3069+
3070+
auto __get_row = [__i, __n_1](auto __d) { return std::min<_Index1>(__i, __n_1) - __d - 1; };
3071+
auto __get_column = [__i, __n_1](auto __d) { return std::max<_Index1>(0, __i - __n_1 - 1) + __d + (__i / (__n_1 + 1) > 0 ? 1 : 0); };
3072+
3073+
oneapi::dpl::counting_iterator<_Index3> __it_d(0);
3074+
3075+
auto __res_d = *std::lower_bound(__it_d, __it_d + __d_size, 1,
3076+
[&](auto __d, auto __val) {
3077+
auto __r = __get_row(__d);
3078+
auto __c = __get_column(__d);
3079+
3080+
oneapi::dpl::__internal::__compare<_Comp, std::identity> __cmp{__comp, std::identity{}};
3081+
const auto __res = (__cmp(__it_1[__r], __it_2[__c]) ? 1 : 0);
3082+
3083+
return __res < __val;
3084+
}
3085+
);
3086+
3087+
//intersection point
3088+
__r = __get_row(__res_d);
3089+
__c = __get_column(__res_d);
3090+
++__r; //to get a merge matrix ceil, lying on the current diagonal
3091+
}
3092+
3093+
//serial merge n elements, starting from input x and y, to [i, j) output range
3094+
auto __res = __brick_merge(__it_1 + __r, __it_1 + __n_1,
3095+
__it_2 + __c, __it_2 + __n_2,
3096+
__it_out + __i, __it_out + __j, __comp);
3097+
3098+
if(__j == __n_out)
3099+
{
3100+
__it_res_1 = __res.first;
3101+
__it_res_2 = __res.second;
3102+
}
3103+
}, /*_ONEDPL_MERGE_CUT_OFF*/10);
3104+
});
3105+
3106+
return {__it_res_1, __it_res_2};
3107+
}
3108+
29833109
template <class _IsVector, class _ExecutionPolicy, class _RandomAccessIterator1, class _RandomAccessIterator2,
29843110
class _RandomAccessIterator3, class _Compare>
29853111
_RandomAccessIterator3
2986-
__pattern_merge(__parallel_tag<_IsVector>, _ExecutionPolicy&& __exec, _RandomAccessIterator1 __first1,
3112+
__pattern_merge(__parallel_tag<_IsVector> __tag, _ExecutionPolicy&& __exec, _RandomAccessIterator1 __first1,
29873113
_RandomAccessIterator1 __last1, _RandomAccessIterator2 __first2, _RandomAccessIterator2 __last2,
29883114
_RandomAccessIterator3 __d_first, _Compare __comp)
29893115
{
3116+
#if 0
29903117
using __backend_tag = typename __parallel_tag<_IsVector>::__backend_tag;
29913118

29923119
return __internal::__except_handler([&]() {
@@ -2999,6 +3126,13 @@ __pattern_merge(__parallel_tag<_IsVector>, _ExecutionPolicy&& __exec, _RandomAcc
29993126
});
30003127
return __d_first + (__last1 - __first1) + (__last2 - __first2);
30013128
});
3129+
#else
3130+
auto __n_1 = __last1 - __first1;
3131+
auto __n_2 = __last2 - __first2;
3132+
auto __n_3 = __n_1 + __n_2;
3133+
__pattern_merge_2(__tag, std::forward<_ExecutionPolicy>(__exec), __first1, __n_1, __first2, __n_2, __d_first, __n_3, __comp);
3134+
return __d_first + __n_3;
3135+
#endif
30023136
}
30033137

30043138
//------------------------------------------------------------------------

include/oneapi/dpl/pstl/algorithm_ranges_impl.h

+3-89
Original file line numberDiff line numberDiff line change
@@ -465,48 +465,12 @@ __pattern_merge(_Tag __tag, _ExecutionPolicy&& __exec, _R1&& __r1, _R2&& __r2, _
465465
return __return_type{std::ranges::begin(__r1) + std::ranges::size(__r1), std::ranges::begin(__r2) + std::ranges::size(__r2), __res};
466466
}
467467

468-
template<std::random_access_iterator It1, std::random_access_iterator It2, std::random_access_iterator ItOut, typename _Comp>
469-
std::pair<It1, It2>
470-
__brick_merge(It1 __it_1, It1 __it_1_e, It2 __it_2, It2 __it_2_e, ItOut __it_out, ItOut __it_out_e, _Comp __comp)
471-
{
472-
while(__it_1 != __it_1_e && __it_2 != __it_2_e)
473-
{
474-
if (__comp(*__it_1, *__it_2))
475-
{
476-
*__it_out = *__it_1;
477-
++__it_out, ++__it_1;
478-
}
479-
else
480-
{
481-
*__it_out = *__it_2;
482-
++__it_out, ++__it_2;
483-
}
484-
if(__it_out == __it_out_e)
485-
return {__it_1, __it_2};
486-
}
487-
488-
if(__it_1 == __it_1_e)
489-
{
490-
for(; __it_2 != __it_2_e && __it_out != __it_out_e; ++__it_2, ++__it_out)
491-
*__it_out = *__it_2;
492-
}
493-
else
494-
{
495-
//assert(__it_2 == __it_2_e);
496-
for(; __it_1 != __it_1_e && __it_out != __it_out_e; ++__it_1, ++__it_out)
497-
*__it_out = *__it_1;
498-
}
499-
return {__it_1, __it_2};
500-
}
501-
502468
template<typename _IsVector, typename _ExecutionPolicy, typename _R1, typename _R2, typename _OutRange, typename _Comp,
503469
typename _Proj1, typename _Proj2>
504470
auto
505-
__pattern_merge(__parallel_tag<_IsVector>, _ExecutionPolicy&& __exec, _R1&& __r1, _R2&& __r2, _OutRange&& __out_r, _Comp __comp,
471+
__pattern_merge(__parallel_tag<_IsVector> __tag, _ExecutionPolicy&& __exec, _R1&& __r1, _R2&& __r2, _OutRange&& __out_r, _Comp __comp,
506472
_Proj1 __proj1, _Proj2 __proj2)
507473
{
508-
using __backend_tag = typename __parallel_tag<_IsVector>::__backend_tag;
509-
510474
auto __comp_2 = [__comp, __proj1, __proj2](auto&& __val1, auto&& __val2) { return std::invoke(__comp,
511475
std::invoke(__proj1, std::forward<decltype(__val1)>(__val1)), std::invoke(__proj2,
512476
std::forward<decltype(__val2)>(__val2)));};
@@ -517,68 +481,18 @@ __pattern_merge(__parallel_tag<_IsVector>, _ExecutionPolicy&& __exec, _R1&& __r1
517481

518482
_Index1 __n_1 = std::ranges::size(__r1);
519483
_Index2 __n_2 = std::ranges::size(__r2);
520-
521484
_Index3 __n_out = std::min<_Index3>(__n_1 + __n_2, std::ranges::size(__out_r));
522485

523486
auto __it_1 = std::ranges::begin(__r1);
524487
auto __it_2 = std::ranges::begin(__r2);
525488
auto __it_out = std::ranges::begin(__out_r);
526489

527-
std::ranges::borrowed_iterator_t<_R1> __it_res_1;
528-
std::ranges::borrowed_iterator_t<_R1> __it_res_2;
529-
530-
__internal::__except_handler([&]() {
531-
__par_backend::__parallel_for(__backend_tag{}, ::std::forward<_ExecutionPolicy>(__exec), _Index3(0), __n_out,
532-
[=, &__r1, &__r2, &__out_r, &__it_res_1, &__it_res_2](_Index3 __i, _Index3 __j)
533-
{/*...*/
534-
535-
//a start merging point on the merge path; for each thread
536-
std::ranges::range_difference_t<_R1> __x = 0;
537-
std::ranges::range_difference_t<_R2> __y = 0;
538-
539-
if(__i > 0)
540-
{
541-
//calc merge path intersection:
542-
const _Index3 __d_size = std::abs(std::max<_Index2>(0, __i - __n_2) - (std::min<_Index1>(__i, __n_1) - 1)) + 1;
543-
544-
auto __get_x = [__i, __n_1](auto __d) { return std::min<_Index1>(__i, __n_1) - __d; };
545-
auto __get_y = [__i, __n_1](auto __d) { return std::max<_Index1>(0, __i - __n_1) + __d; };
546-
547-
oneapi::dpl::counting_iterator<_Index3> __it_d(0);
548-
auto __res_d = *std::lower_bound(__it_d, __it_d + __d_size, 1,
549-
[&](auto __d, auto __val) {
550-
auto __x = __get_x(__d);
551-
auto __y = __get_y(__d);
552-
553-
const auto __res = __comp_2(__r1[__x], __r2[__y]) ? 0 : 1;
554-
return __res < __val;
555-
}
556-
);
557-
//intersection point
558-
__x = __get_x(__res_d);
559-
__y = __get_y(__res_d);
560-
561-
}
562-
563-
const _Index3 __n = __j - __i;
564-
565-
//serial merge n elements, starting from input x and y, to [i, j) output range
566-
auto __res = __brick_merge(__it_1 + __x, __it_1 + __n_1,
567-
__it_2 + __y, __it_2 + __n_2,
568-
__it_out + __i, __it_out + __j, __comp_2);
569-
570-
if(__j == __n_out)
571-
{
572-
__it_res_1 = __res.first;
573-
__it_res_2 = __res.second;
574-
}
575-
});
576-
});
490+
auto __res = __pattern_merge_2(__tag, std::forward<_ExecutionPolicy>(__exec), __it_1, __n_1, __it_2, __n_2, __it_out, __n_out, __comp_2);
577491

578492
using __return_type = std::ranges::merge_result<std::ranges::borrowed_iterator_t<_R1>, std::ranges::borrowed_iterator_t<_R2>,
579493
std::ranges::borrowed_iterator_t<_OutRange>>;
580494

581-
return __return_type{__it_res_1, __it_res_2, std::ranges::begin(__out_r) + __n_out};
495+
return __return_type{__res.first, __res.second, __it_out + __n_out};
582496
}
583497

584498
//TODO:

include/oneapi/dpl/pstl/parallel_backend_tbb.h

+3-2
Original file line numberDiff line numberDiff line change
@@ -92,10 +92,11 @@ class __parallel_for_body
9292
// wrapper over tbb::parallel_for
9393
template <class _ExecutionPolicy, class _Index, class _Fp>
9494
void
95-
__parallel_for(oneapi::dpl::__internal::__tbb_backend_tag, _ExecutionPolicy&&, _Index __first, _Index __last, _Fp __f)
95+
__parallel_for(oneapi::dpl::__internal::__tbb_backend_tag, _ExecutionPolicy&&, _Index __first, _Index __last, _Fp __f,
96+
std::size_t __grainsize = 1)
9697
{
9798
tbb::this_task_arena::isolate([=]() {
98-
tbb::parallel_for(tbb::blocked_range<_Index>(__first, __last), __parallel_for_body<_Index, _Fp>(__f));
99+
tbb::parallel_for(tbb::blocked_range<_Index>(__first, __last, __grainsize), __parallel_for_body<_Index, _Fp>(__f));
99100
});
100101
}
101102

0 commit comments

Comments
 (0)