Skip to content

Commit 85032a7

Browse files
committed
[oneDPL][ranges][merge] + fixes
1 parent 907859e commit 85032a7

File tree

2 files changed

+26
-22
lines changed

2 files changed

+26
-22
lines changed

include/oneapi/dpl/pstl/algorithm_ranges_impl.h

+18-14
Original file line numberDiff line numberDiff line change
@@ -492,7 +492,7 @@ __brick_merge(It1 __it_1, It1 __it_1_e, It2 __it_2, It2 __it_2_e, ItOut __it_out
492492
}
493493
else
494494
{
495-
assert(__it_2 == __it_2_e);
495+
//assert(__it_2 == __it_2_e);
496496
for(; __it_1 != __it_1_e && __it_out != __it_out_e; ++__it_1, ++__it_out)
497497
*__it_out = *__it_1;
498498
}
@@ -511,9 +511,14 @@ __pattern_merge(__parallel_tag<_IsVector>, _ExecutionPolicy&& __exec, _R1&& __r1
511511
std::invoke(__proj1, std::forward<decltype(__val1)>(__val1)), std::invoke(__proj2,
512512
std::forward<decltype(__val2)>(__val2)));};
513513

514-
auto __n_1 = std::ranges::size(__r1);
515-
auto __n_2 = std::ranges::size(__r2);
516-
auto __n_out = std::min(__n_1 + __n_2, std::ranges::size(__out_r));
514+
using _Index1 = std::ranges::range_difference_t<_R1>;
515+
using _Index2 = std::ranges::range_difference_t<_R2>;
516+
using _Index3 = std::ranges::range_difference_t<_OutRange>;
517+
518+
_Index1 __n_1 = std::ranges::size(__r1);
519+
_Index2 __n_2 = std::ranges::size(__r2);
520+
521+
_Index3 __n_out = std::min<_Index3>(__n_1 + __n_2, std::ranges::size(__out_r));
517522

518523
auto __it_1 = std::ranges::begin(__r1);
519524
auto __it_2 = std::ranges::begin(__r2);
@@ -523,8 +528,8 @@ __pattern_merge(__parallel_tag<_IsVector>, _ExecutionPolicy&& __exec, _R1&& __r1
523528
std::ranges::borrowed_iterator_t<_R1> __it_res_2;
524529

525530
__internal::__except_handler([&]() {
526-
__par_backend::__parallel_for(__backend_tag{}, ::std::forward<_ExecutionPolicy>(__exec), 0, __n_out,
527-
[=, &__r1, &__r2, &__out_r, &__it_res_1, &__it_res_2](auto __i, auto __j)
531+
__par_backend::__parallel_for(__backend_tag{}, ::std::forward<_ExecutionPolicy>(__exec), _Index3(0), __n_out,
532+
[=, &__r1, &__r2, &__out_r, &__it_res_1, &__it_res_2](_Index3 __i, _Index3 __j)
528533
{/*...*/
529534

530535
//a start merging point on the merge path; for each thread
@@ -534,19 +539,18 @@ __pattern_merge(__parallel_tag<_IsVector>, _ExecutionPolicy&& __exec, _R1&& __r1
534539
if(__i > 0)
535540
{
536541
//calc merge path intersection:
537-
using _Index3 = std::ranges::range_difference_t<_OutRange>;
538-
_Index3 __d_size = std::abs(std::max(0, __i - __n_2) - (std::min(__i, __n_1) - 1)) + 1;
542+
const _Index3 __d_size = std::abs(std::max<_Index2>(0, __i - __n_2) - (std::min<_Index1>(__i, __n_1) - 1)) + 1;
539543

540-
auto __get_x = [__i, __n_1](auto __d) { return std::min(__i, __n_1) - __d; };
541-
auto __get_y = [__i, __n_1](auto __d) { return std::max(0, __i - __n_1) + __d; };
544+
auto __get_x = [__i, __n_1](auto __d) { return std::min<_Index1>(__i, __n_1) - __d; };
545+
auto __get_y = [__i, __n_1](auto __d) { return std::max<_Index1>(0, __i - __n_1) + __d; };
542546

543547
oneapi::dpl::counting_iterator<_Index3> __it_d(0);
544548
auto __res_d = *std::lower_bound(__it_d, __it_d + __d_size, 1,
545-
[](auto __d, auto __val) {
549+
[&](auto __d, auto __val) {
546550
auto __x = __get_x(__d);
547551
auto __y = __get_y(__d);
548552

549-
const auto __res = __comp_2(__r1[__x], __r2[__y]) ? 0 : 1
553+
const auto __res = __comp_2(__r1[__x], __r2[__y]) ? 0 : 1;
550554
return __res < __val;
551555
}
552556
);
@@ -556,7 +560,7 @@ __pattern_merge(__parallel_tag<_IsVector>, _ExecutionPolicy&& __exec, _R1&& __r1
556560

557561
}
558562

559-
const auto __n = __j - __i;
563+
const _Index3 __n = __j - __i;
560564

561565
//serial merge n elements, starting from input x and y, to [i, j) output range
562566
auto __res = __brick_merge(__it_1 + __x, __it_1 + __n_1,
@@ -647,7 +651,7 @@ __pattern_merge(__serial_tag</*IsVector*/std::false_type>, _ExecutionPolicy&& __
647651
}
648652
else
649653
{
650-
assert(__it_2 == std::ranges::end(__r2);
654+
//assert(__it_2 == std::ranges::end(__r2);
651655
for(; __it_1 != std::ranges::end(__r1) && __it_out != std::ranges::end(__out_r); ++__it_1, ++__it_out)
652656
*__it_out = *__it_1;
653657
}

test/parallel_api/ranges/std_ranges_merge.pass.cpp

+8-8
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ main()
2626
auto merge_checker = [](std::ranges::random_access_range auto&& r_1,
2727
std::ranges::random_access_range auto&& r_2,
2828
std::ranges::random_access_range auto&& r_out, auto comp, auto proj1,
29-
auto proj2, auto&&... args)
29+
auto proj2)
3030
{
3131
using ret_type = std::ranges::merge_result<std::ranges::borrowed_iterator_t<decltype(r_1)>,
3232
std::ranges::borrowed_iterator_t<decltype(r_2)>, std::ranges::borrowed_iterator_t<decltype(r_out)>>;
@@ -37,19 +37,19 @@ main()
3737
for(;;)
3838
{
3939
if(it_out == std::ranges::end(r_out))
40-
return ret_type{res.in1, res.in2, res.out};
40+
return ret_type{it_1, it_2, it_out};
4141

4242
if(it_1 == std::ranges::end(r_1))
4343
{
44-
for(auto it_2 = std::ranges::begin(r_2) && it_out != std::ranges::end(r_out); ++it_2, ++it_out)
44+
for(auto it_2 = std::ranges::begin(r_2); it_out != std::ranges::end(r_out); ++it_2, ++it_out)
4545
*it_out = *it_2;
46-
return ret_type{res.in1, res.in2, res.out};
46+
return ret_type{it_1, it_2, it_out};
4747
}
4848
else if(it_2 == std::ranges::end(r_2))
4949
{
50-
for(auto it_1 = std::ranges::begin(r_1) && it_out != std::ranges::end(r_out); ++it_1, ++it_out)
50+
for(auto it_1 = std::ranges::begin(r_1); it_out != std::ranges::end(r_out); ++it_1, ++it_out)
5151
*it_out = *it_1;
52-
return ret_type{res.in1, res.in2, res.out};
52+
return ret_type{it_1, it_2, it_out};
5353
}
5454

5555
if (std::invoke(comp, std::invoke(proj1, *it_1), std::invoke(proj2, *it_2)))
@@ -64,10 +64,10 @@ main()
6464
}
6565
}
6666

67-
return ret_type{res.in1, res.in2, res.out};
67+
return ret_type{it_1, it_2, it_out};
6868
};
6969

70-
test_range_algo<0, int, data_in_in_out>{big_sz}(dpl_ranges::merge, merge_checker, std::ranges::less{});
70+
test_range_algo<0, int, data_in_in_out>{big_sz}(dpl_ranges::merge, merge_checker, std::ranges::less{}, std::identity{}, std::identity{});
7171

7272
test_range_algo<1, int, data_in_in_out>{}(dpl_ranges::merge, merge_checker, std::ranges::less{}, proj, proj);
7373
test_range_algo<2, P2, data_in_in_out>{}(dpl_ranges::merge, merge_checker, std::ranges::less{}, &P2::x, &P2::x);

0 commit comments

Comments
 (0)