@@ -845,26 +845,26 @@ struct __lazy_load_transform_op
     }
 };
 
-template <std::uint16_t __vec_size>
+template <std::uint8_t __vec_size>
 struct __vector_load
 {
-    static_assert(__vec_size <= 4);
+    static_assert(__vec_size <= 4, "Only vector sizes of 4 or less are supported");
     std::size_t __n;
     template <typename _IdxType, typename _LoadOp, typename... _Acc>
     void
     operator()(std::true_type, _IdxType __start_idx, _LoadOp __load_op, _Acc... __acc) const
     {
         _ONEDPL_PRAGMA_UNROLL
-        for (std::uint16_t __i = 0; __i < __vec_size; ++__i)
+        for (std::uint8_t __i = 0; __i < __vec_size; ++__i)
             __load_op(__start_idx + __i, __i, __acc...);
     }
 
     template <typename _IdxType, typename _LoadOp, typename... _Acc>
     void
     operator()(std::false_type, _IdxType __start_idx, _LoadOp __load_op, _Acc... __acc) const
     {
-        std::uint16_t __elements = std::min(__vec_size, decltype(__vec_size)(__n - __start_idx));
-        for (std::uint16_t __i = 0; __i < __elements; ++__i)
+        std::uint8_t __elements = std::min(std::size_t{__vec_size}, std::size_t{__n - __start_idx});
+        for (std::uint8_t __i = 0; __i < __elements; ++__i)
             __load_op(__start_idx + __i, __i, __acc...);
     }
 };
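
Side note on the remainder computation above: besides shrinking the index type, the patch stops narrowing __n - __start_idx into the index type before std::min and instead compares in std::size_t. With an 8-bit index type the old decltype(__vec_size)(...) cast would wrap at 256 rather than 65536, so doing the comparison in std::size_t keeps the clamp correct regardless of how many elements remain. A small standalone sketch of that narrowing hazard (names and values below are hypothetical, not part of the patch):

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <iostream>

int main()
{
    constexpr std::uint8_t vec_size = 4;
    const std::size_t n = 1024;        // hypothetical total number of elements
    const std::size_t start_idx = 768; // hypothetical start index of the last block

    // Old pattern: the remainder is narrowed to the index type before std::min.
    // Here (n - start_idx) == 256 wraps to 0 as std::uint8_t, so the clamp would
    // report zero elements left even though 256 remain.
    const std::uint8_t wrapped = static_cast<std::uint8_t>(n - start_idx);
    std::cout << +std::min(vec_size, wrapped) << '\n'; // prints 0

    // New pattern: compare in std::size_t and only then store the (always small) result.
    const std::uint8_t elements = std::min(std::size_t{vec_size}, std::size_t{n - start_idx});
    std::cout << +elements << '\n'; // prints 4
}
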
@@ -891,20 +891,19 @@ struct __lazy_store_transform_op
     }
 };
 
-template <std::uint16_t __vec_size>
+template <std::uint8_t __vec_size>
 struct __vector_walk
 {
-    static_assert(__vec_size <= 4);
+    static_assert(__vec_size <= 4, "Only vector sizes of 4 or less are supported");
     std::size_t __n;
 
     template <typename _IdxType, typename _WalkFunction, typename... _Rngs>
     void
     operator()(std::true_type, _IdxType __idx, _WalkFunction __f, _Rngs&&... __rngs) const
     {
         _ONEDPL_PRAGMA_UNROLL
-        for (std::uint16_t __i = 0; __i < __vec_size; ++__i)
+        for (std::uint8_t __i = 0; __i < __vec_size; ++__i)
         {
-
             __f(__rngs[__idx + __i]...);
         }
     }
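
For readers unfamiliar with the tag dispatch used by __vector_walk: the std::true_type overload is taken when a full vector of __vec_size elements is guaranteed, so the loop has a compile-time trip count and can be unrolled without a bounds check, while the std::false_type overload clamps against __n. A simplified, self-contained sketch of the same pattern (every name below is illustrative only, not oneDPL API):

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <type_traits>
#include <vector>

// Simplified stand-in for __vector_walk; the single-range signature is a deliberate simplification.
template <std::uint8_t vec_size>
struct vector_walk_sketch
{
    std::size_t n;

    // Full block: the caller guarantees vec_size valid elements, so no bounds check is
    // needed and the trip count is a compile-time constant the compiler can unroll.
    template <typename F, typename Rng>
    void
    operator()(std::true_type, std::size_t idx, F f, Rng& rng) const
    {
        for (std::uint8_t i = 0; i < vec_size; ++i)
            f(rng[idx + i]);
    }

    // Partial block at the end of the range: clamp the element count against n first.
    template <typename F, typename Rng>
    void
    operator()(std::false_type, std::size_t idx, F f, Rng& rng) const
    {
        std::uint8_t elements = std::min(std::size_t{vec_size}, std::size_t{n - idx});
        for (std::uint8_t i = 0; i < elements; ++i)
            f(rng[idx + i]);
    }
};

int main()
{
    std::vector<int> data(10, 1);
    vector_walk_sketch<4> walk{data.size()};
    auto negate = [](int& x) { x = -x; };

    // Indices 0..3 and 4..7 are full blocks; indices 8..9 form the partial tail.
    walk(std::true_type{}, 0, negate, data);
    walk(std::true_type{}, 4, negate, data);
    walk(std::false_type{}, 8, negate, data);

    for (int x : data)
        std::cout << x << ' '; // prints -1 ten times
    std::cout << '\n';
}
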
@@ -914,61 +913,63 @@ struct __vector_walk
     void
     operator()(std::false_type, _IdxType __idx, _WalkFunction __f, _Rngs&&... __rngs) const
     {
-        std::uint16_t __elements = std::min(__vec_size, decltype(__vec_size)(__n - __idx));
-        for (std::uint16_t __i = 0; __i < __elements; ++__i)
+        std::uint8_t __elements = std::min(std::size_t{__vec_size}, std::size_t{__n - __idx});
+        for (std::uint8_t __i = 0; __i < __elements; ++__i)
         {
             __f(__rngs[__idx + __i]...);
         }
     }
 };
 
-template <std::uint16_t __vec_size>
+template <std::uint8_t __vec_size>
 struct __vector_store
 {
+    static_assert(__vec_size <= 4, "Only vector sizes of 4 or less are supported");
     std::size_t __n;
-    static_assert(__vec_size <= 4);
-    template <typename _IdxType, typename _StoreOp, typename... _Acc>
+
+    template <typename _IdxType, typename _StoreOp, typename... _Rngs>
     void
-    operator()(std::true_type, _IdxType __start_idx, _StoreOp __store_op, _Acc... __acc) const
+    operator()(std::true_type, _IdxType __start_idx, _StoreOp __store_op, _Rngs... __rngs) const
     {
         _ONEDPL_PRAGMA_UNROLL
-        for (std::uint16_t __i = 0; __i < __vec_size; ++__i)
-            __store_op(__i, __start_idx + __i, __acc...);
+        for (std::uint8_t __i = 0; __i < __vec_size; ++__i)
+            __store_op(__i, __start_idx + __i, __rngs...);
     }
-    template <typename _IdxType, typename _StoreOp, typename... _Acc>
+    template <typename _IdxType, typename _StoreOp, typename... _Rngs>
     void
-    operator()(std::false_type, _IdxType __start_idx, _StoreOp __store_op, _Acc... __acc) const
+    operator()(std::false_type, _IdxType __start_idx, _StoreOp __store_op, _Rngs... __rngs) const
     {
-        std::uint16_t __elements = std::min(__vec_size, decltype(__vec_size)(__n - __start_idx));
-        for (std::uint16_t __i = 0; __i < __elements; ++__i)
-            __store_op(__i, __start_idx + __i, __acc...);
+        std::uint8_t __elements = std::min(std::size_t{__vec_size}, std::size_t{__n - __start_idx});
+        for (std::uint8_t __i = 0; __i < __elements; ++__i)
+            __store_op(__i, __start_idx + __i, __rngs...);
     }
 };
 
-template <std::uint16_t __vec_size>
+template <std::uint8_t __vec_size>
 struct __vector_reverse
 {
+    static_assert(__vec_size <= 4, "Only vector sizes of 4 or less are supported");
     template <typename _IsFull, typename _Idx, typename _Array>
     void
     operator()(_IsFull __is_full, const _Idx __elements_to_process, _Array __array) const
     {
         if constexpr (__is_full)
         {
             _ONEDPL_PRAGMA_UNROLL
-            for (std::uint16_t __i = 0; __i != __vec_size / 2; ++__i)
+            for (std::uint8_t __i = 0; __i < __vec_size / 2; ++__i)
                 std::swap(__array[__i].__v, __array[__vec_size - __i - 1].__v);
         }
         else
         {
-            for (std::uint16_t __i = 0; __i != __elements_to_process / 2; ++__i)
+            for (std::uint8_t __i = 0; __i < __elements_to_process / 2; ++__i)
                 std::swap(__array[__i].__v, __array[__elements_to_process - __i - 1].__v);
         }
     }
 };
 
 // Processes a loop with a given stride. Intended to be used with sub-group / work-group strides for good memory access patterns
 // (potentially with vectorization)
-template <std::uint16_t __num_strides>
+template <std::uint8_t __num_strides>
 struct __strided_loop
 {
     std::size_t __n;
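
__vector_reverse above swaps mirrored element pairs, so reversing __vec_size elements takes only __vec_size / 2 iterations; the partial path applies the same idea to the first __elements_to_process entries, and the switch from != to < in the loop condition is simply the more defensive form of the bound check. A tiny standalone illustration of the half-swap idea (the real code swaps wrapped elements through their .__v member; names here are illustrative):

#include <array>
#include <cstdint>
#include <iostream>
#include <utility>

int main()
{
    constexpr std::uint8_t vec_size = 4;
    std::array<int, vec_size> reg{1, 2, 3, 4}; // stands in for the per-work-item element array

    // Same idea as the full path of __vector_reverse: swap mirrored pairs,
    // so reversing vec_size elements needs only vec_size / 2 iterations.
    for (std::uint8_t i = 0; i < vec_size / 2; ++i)
        std::swap(reg[i], reg[vec_size - i - 1]);

    for (int x : reg)
        std::cout << x << ' '; // prints 4 3 2 1
    std::cout << '\n';
}
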
@@ -978,7 +979,7 @@ struct __strided_loop
                _Ranges&&... __rngs) const
     {
         _ONEDPL_PRAGMA_UNROLL
-        for (std::uint16_t __i = 0; __i < __num_strides; ++__i)
+        for (std::uint8_t __i = 0; __i < __num_strides; ++__i)
         {
             __loop_body_op(std::true_type{}, __idx, __rngs...);
             __idx += __stride;
@@ -992,7 +993,7 @@ struct __strided_loop
         // Constrain the number of iterations as much as possible and then pass the knowledge that we are not a full loop to the body operation
         const std::uint8_t __adjusted_iters_per_work_item =
             oneapi::dpl::__internal::__dpl_ceiling_div(__n - __idx, __stride);
-        for (std::uint16_t __i = 0; __i < __adjusted_iters_per_work_item; ++__i)
+        for (std::uint8_t __i = 0; __i < __adjusted_iters_per_work_item; ++__i)
         {
             __loop_body_op(std::false_type{}, __idx, __rngs...);
             __idx += __stride;
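
The non-full path of __strided_loop bounds the loop with __dpl_ceiling_div(__n - __idx, __stride), i.e. the number of strided steps needed to cover the remaining elements, with the last step possibly partial. A small sketch of that computation (the ceiling_div helper below is a presumed equivalent of oneapi::dpl::__internal::__dpl_ceiling_div, written out here only for illustration):

#include <cstddef>
#include <cstdint>
#include <iostream>

// Presumed equivalent of __dpl_ceiling_div: round-up integer division.
constexpr std::size_t ceiling_div(std::size_t num, std::size_t den)
{
    return (num + den - 1) / den;
}

int main()
{
    const std::size_t n = 1000;    // hypothetical element count
    const std::size_t idx = 960;   // hypothetical work-item starting index
    const std::size_t stride = 16; // hypothetical sub-group / work-group stride

    // 40 remaining elements at stride 16 -> 3 strided iterations (the last one partial),
    // matching how __adjusted_iters_per_work_item bounds the non-full loop.
    const std::uint8_t iters = ceiling_div(n - idx, stride);
    std::cout << +iters << '\n'; // prints 3
}
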