|
31 | 31 | #include "parallel_backend.h"
|
32 | 32 | #include "parallel_impl.h"
|
33 | 33 | #include "iterator_impl.h"
|
| 34 | +#include "../functional" |
34 | 35 |
|
35 | 36 | #if _ONEDPL_HETERO_BACKEND
|
36 | 37 | # include "hetero/algorithm_impl_hetero.h" // for __pattern_fill_n, __pattern_generate_n
|
@@ -2948,6 +2949,49 @@ __pattern_remove_if(__parallel_tag<_IsVector> __tag, _ExecutionPolicy&& __exec,
|
2948 | 2949 | // merge
|
2949 | 2950 | //------------------------------------------------------------------------
|
2950 | 2951 |
|
// Serial brick of a merge with a bounded output range.
//
// Merges [__it_1, __it_1_e) and [__it_2, __it_2_e) into [__it_out, __it_out_e) and stops as soon as the
// output range is full or both inputs are exhausted.
//
// While both inputs still have elements, the element of the first input is taken when
// __comp(*__it_1, *__it_2) is true; otherwise the element of the second input is taken.
//
// Returns the pair of input positions at which merging stopped, i.e. the first element of each input
// that has not been written to the output.
//
// An empty output range is supported: nothing is written and {__it_1, __it_2} is returned.
template <typename It1, typename It2, typename ItOut, typename _Comp>
std::pair<It1, It2>
__brick_merge_2(It1 __it_1, It1 __it_1_e, It2 __it_2, It2 __it_2_e, ItOut __it_out, ItOut __it_out_e, _Comp __comp,
                /* __is_vector = */ std::false_type)
{
    // The output bound is tested before every write, so __it_out_e itself is never dereferenced
    // (in particular when the output range is empty on entry).
    while (__it_out != __it_out_e && __it_1 != __it_1_e && __it_2 != __it_2_e)
    {
        if (__comp(*__it_1, *__it_2))
        {
            *__it_out = *__it_1;
            ++__it_1;
        }
        else
        {
            *__it_out = *__it_2;
            ++__it_2;
        }
        ++__it_out;
    }

    // Copy the tail of whichever input is left. If the output is already full neither loop runs;
    // otherwise at least one input is exhausted, so at most one of the loops does any work.
    for (; __it_1 != __it_1_e && __it_out != __it_out_e; ++__it_1, ++__it_out)
        *__it_out = *__it_1;

    for (; __it_2 != __it_2_e && __it_out != __it_out_e; ++__it_2, ++__it_out)
        *__it_out = *__it_2;

    return {__it_1, __it_2};
}
| 2986 | + |
| 2987 | +template<typename It1, typename It2, typename ItOut, typename _Comp> |
| 2988 | +std::pair<It1, It2> |
| 2989 | +__brick_merge_2(It1 __it_1, It1 __it_1_e, It2 __it_2, It2 __it_2_e, ItOut __it_out, ItOut __it_out_e, _Comp __comp, |
| 2990 | + /* __is_vector = */ std::true_type) |
| 2991 | +{ |
| 2992 | + return __unseq_backend::__simd_merge(__it_1, __it_1_e, __it_2, __it_2_e, __it_out, __it_out_e, __comp); |
| 2993 | +} |
| 2994 | + |
2951 | 2995 | template <class _ForwardIterator1, class _ForwardIterator2, class _OutputIterator, class _Compare>
|
2952 | 2996 | _OutputIterator
|
2953 | 2997 | __brick_merge(_ForwardIterator1 __first1, _ForwardIterator1 __last1, _ForwardIterator2 __first2,
|
@@ -2980,10 +3024,87 @@ __pattern_merge(_Tag, _ExecutionPolicy&&, _ForwardIterator1 __first1, _ForwardIt
|
2980 | 3024 | typename _Tag::__is_vector{});
|
2981 | 3025 | }
|
2982 | 3026 |
|
// Generic-tag overload of the bounded merge pattern.
//
// Converts the (iterator, length) descriptions of both inputs and of the output into iterator ranges
// and runs __brick_merge_2 on them directly, selecting the serial or vector brick through
// _Tag::__is_vector. The execution policy argument is accepted but not used here.
//
// Returns whatever __brick_merge_2 returns: the pair of input positions at which merging stopped.
template <class _Tag, typename _ExecutionPolicy, typename _It1, typename _Index1, typename _It2, typename _Index2,
          typename _OutIt, typename _Index3, typename _Comp>
std::pair<_It1, _It2>
__pattern_merge_2(_Tag, _ExecutionPolicy&& __exec, _It1 __it_1, _Index1 __n_1, _It2 __it_2, _Index2 __n_2,
                  _OutIt __it_out, _Index3 __n_out, _Comp __comp)
{
    using __is_vector_t = typename _Tag::__is_vector;

    // One-past-the-end positions of the three ranges.
    auto __last_1 = __it_1 + __n_1;
    auto __last_2 = __it_2 + __n_2;
    auto __last_out = __it_out + __n_out;

    return __brick_merge_2(__it_1, __last_1, __it_2, __last_2, __it_out, __last_out, __comp, __is_vector_t{});
}
| 3036 | + |
| 3037 | +template<typename _IsVector, typename _ExecutionPolicy, typename _It1, typename _Index1, typename _It2, |
| 3038 | + typename _Index2, typename _OutIt, typename _Index3, typename _Comp> |
| 3039 | +std::pair<_It1, _It2> |
| 3040 | +__pattern_merge_2(__parallel_tag<_IsVector>, _ExecutionPolicy&& __exec, _It1 __it_1, _Index1 __n_1, _It2 __it_2, |
| 3041 | + _Index2 __n_2, _OutIt __it_out, _Index3 __n_out, _Comp __comp) |
| 3042 | +{ |
| 3043 | + using __backend_tag = typename __parallel_tag<_IsVector>::__backend_tag; |
| 3044 | + |
| 3045 | + _It1 __it_res_1; |
| 3046 | + _It2 __it_res_2; |
| 3047 | + |
| 3048 | + __internal::__except_handler([&]() { |
| 3049 | + __par_backend::__parallel_for(__backend_tag{}, std::forward<_ExecutionPolicy>(__exec), _Index3(0), __n_out, |
| 3050 | + [=, &__it_res_1, &__it_res_2](_Index3 __i, _Index3 __j) |
| 3051 | + { |
| 3052 | + //a start merging point on the merge path; for each thread |
| 3053 | + _Index1 __r = 0; //row index |
| 3054 | + _Index2 __c = 0; //column index |
| 3055 | + |
| 3056 | + if(__i > 0) |
| 3057 | + { |
| 3058 | + //calc merge path intersection: |
| 3059 | + const _Index3 __d_size = |
| 3060 | + std::abs(std::max<_Index2>(0, __i - __n_2) - (std::min<_Index1>(__i, __n_1) - 1)) + 1; |
| 3061 | + |
| 3062 | + auto __get_row = [__i, __n_1](auto __d) |
| 3063 | + { return std::min<_Index1>(__i, __n_1) - __d - 1; }; |
| 3064 | + auto __get_column = [__i, __n_1](auto __d) |
| 3065 | + { return std::max<_Index1>(0, __i - __n_1 - 1) + __d + (__i / (__n_1 + 1) > 0 ? 1 : 0); }; |
| 3066 | + |
| 3067 | + oneapi::dpl::counting_iterator<_Index3> __it_d(0); |
| 3068 | + |
| 3069 | + auto __res_d = *std::lower_bound(__it_d, __it_d + __d_size, 1, |
| 3070 | + [&](auto __d, auto __val) { |
| 3071 | + auto __r = __get_row(__d); |
| 3072 | + auto __c = __get_column(__d); |
| 3073 | + |
| 3074 | + oneapi::dpl::__internal::__compare<_Comp, oneapi::dpl::identity> |
| 3075 | + __cmp{__comp, oneapi::dpl::identity{}}; |
| 3076 | + const auto __res = (__cmp(__it_1[__r], __it_2[__c]) ? 1 : 0); |
| 3077 | + |
| 3078 | + return __res < __val; |
| 3079 | + } |
| 3080 | + ); |
| 3081 | + |
| 3082 | + //intersection point |
| 3083 | + __r = __get_row(__res_d); |
| 3084 | + __c = __get_column(__res_d); |
| 3085 | + ++__r; //to get a merge matrix ceil, lying on the current diagonal |
| 3086 | + } |
| 3087 | + |
| 3088 | + //serial merge n elements, starting from input x and y, to [i, j) output range |
| 3089 | + auto __res = __brick_merge_2(__it_1 + __r, __it_1 + __n_1, |
| 3090 | + __it_2 + __c, __it_2 + __n_2, |
| 3091 | + __it_out + __i, __it_out + __j, __comp, _IsVector{}); |
| 3092 | + |
| 3093 | + if(__j == __n_out) |
| 3094 | + { |
| 3095 | + __it_res_1 = __res.first; |
| 3096 | + __it_res_2 = __res.second; |
| 3097 | + } |
| 3098 | + }, _ONEDPL_MERGE_CUT_OFF); //grainsize |
| 3099 | + }); |
| 3100 | + |
| 3101 | + return {__it_res_1, __it_res_2}; |
| 3102 | +} |
| 3103 | + |
2983 | 3104 | template <class _IsVector, class _ExecutionPolicy, class _RandomAccessIterator1, class _RandomAccessIterator2,
|
2984 | 3105 | class _RandomAccessIterator3, class _Compare>
|
2985 | 3106 | _RandomAccessIterator3
|
2986 |
| -__pattern_merge(__parallel_tag<_IsVector>, _ExecutionPolicy&& __exec, _RandomAccessIterator1 __first1, |
| 3107 | +__pattern_merge(__parallel_tag<_IsVector> __tag, _ExecutionPolicy&& __exec, _RandomAccessIterator1 __first1, |
2987 | 3108 | _RandomAccessIterator1 __last1, _RandomAccessIterator2 __first2, _RandomAccessIterator2 __last2,
|
2988 | 3109 | _RandomAccessIterator3 __d_first, _Compare __comp)
|
2989 | 3110 | {
|
|
0 commit comments