diff --git a/core/test/base/half.cpp b/core/test/base/half.cpp index e242d17fb50..b7ef18a689d 100644 --- a/core/test/base/half.cpp +++ b/core/test/base/half.cpp @@ -131,6 +131,77 @@ TEST(FloatToHalf, TruncatesLargeNumberRoundToEven) } +TEST(FloatToHalf, ConvertsRandomPositiveNumber) +{ + half x = create_from_bits("0" "01101011" "00011011000001011110010"); + + ASSERT_EQ(get_bits(x), get_bits("0" "00000" "0000010010")); +} + + +TEST(FloatToHalf, RoundsUpToEvenNumber) +{ + half x = create_from_bits("0" "01101001" "11100000000000000000000"); + + ASSERT_EQ(get_bits(x), get_bits("0" "00000" "0000001000")); +} + + +TEST(FloatToHalf, RoundsDownToEvenNumber) +{ + half x = create_from_bits("1" "01101100" "11010100000000000000000"); + + ASSERT_EQ(get_bits(x), get_bits("1" "00000" "0000111010")); +} + +TEST(FloatToHalf, RoundsDownToEvenNumber2) +{ + half x = create_from_bits("1" "01101100" "11010100000000000000001"); + + ASSERT_EQ(get_bits(x), get_bits("1" "00000" "0000111011")); +} + + +TEST(FloatToHalf, LargestNumberThatConvertsToZero) +{ + half x = create_from_bits("0" "01100110" "00000000000000000000000"); + + ASSERT_EQ(get_bits(x), get_bits("0" "00000" "0000000000")); +} + + +TEST(FloatToHalf, SmallestNumberThatDoesntConvertToZero) +{ + half x = create_from_bits("0" "01100110" "00000000000000000000001"); + + ASSERT_EQ(get_bits(x), get_bits("0" "00000" "0000000001")); +} + + +TEST(FloatToHalf, RandomNumberThatConvertsToPositiveZero) +{ + half x = create_from_bits("0" "01001101" "10010100101111001101010"); + + ASSERT_EQ(get_bits(x), get_bits("0" "00000" "0000000000")); +} + + +TEST(FloatToHalf, RandomNumberThatConvertsToNegativeZero) +{ + half x = create_from_bits("1" "01001000" "10000010110100001010001"); + + ASSERT_EQ(get_bits(x), get_bits("1" "00000" "0000000000")); +} + + +TEST(FloatToHalf, BoundaryTest) +{ + half x = create_from_bits("0" "01110000" "11111111110000000000000"); + + ASSERT_EQ(get_bits(x), get_bits("0" "00001" "0000000000")); +} + + TEST(HalfToFloat, ConvertsOne) { float x = create_from_bits("0" "01111" "0000000000"); @@ -195,4 +266,52 @@ TEST(HalfToFloat, ExtendsLargeNumber) } +TEST(HalfToFloat, ConvertsPositiveRandomDenormal) +{ + float x = create_from_bits("0" "00000" "1110101100"); + + ASSERT_EQ(get_bits(x), get_bits("0" "01110000" "11010110000000000000000")); +} + + +TEST(HalfToFloat, ConvertsNegativeRandomDenormal) +{ + float x = create_from_bits("1" "00000" "0010111101"); + + ASSERT_EQ(get_bits(x), get_bits("1" "01101110" "01111010000000000000000")); +} + + +TEST(HalfToFloat, ConvertsSmallestPositiveDenormal) +{ + float x = create_from_bits("0" "00000" "0000000001"); + + ASSERT_EQ(get_bits(x), get_bits("0" "01100111" "00000000000000000000000")); +} + + +TEST(HalfToFloat, ConvertsSmallestNegativeDenormal) +{ + float x = create_from_bits("1" "00000" "1111111111"); + + ASSERT_EQ(get_bits(x), get_bits("1" "01110000" "11111111100000000000000")); +} + + +TEST(HalfToFloat, ConvertsLargestPositiveDenormal) +{ + float x = create_from_bits("0" "00000" "1111111111"); + + ASSERT_EQ(get_bits(x), get_bits("0" "01110000" "11111111100000000000000")); +} + + +TEST(HalfToFloat, ConvertsLargestNegativeDenormal) +{ + float x = create_from_bits("1" "00000" "0000000001"); + + ASSERT_EQ(get_bits(x), get_bits("1" "01100111" "00000000000000000000000")); +} + + // clang-format on diff --git a/include/ginkgo/core/base/half.hpp b/include/ginkgo/core/base/half.hpp index 4b492cf7df7..2aeed57d949 100644 --- a/include/ginkgo/core/base/half.hpp +++ b/include/ginkgo/core/base/half.hpp @@ -415,8 +415,43 @@ class alignas(std::uint16_t) half { if (f16_traits::is_inf(exp)) { return conv::shift_sign(data_) | exp; } else if (f16_traits::is_denom(exp)) { - // TODO: handle denormals - return conv::shift_sign(data_); + // gap to fp16 denormal exponents (+1 from normal to denormal + // exponent base) + const auto gap_to_fp16 = + ((conv::bias_change - + ((data_ & f32_traits::exponent_mask) >> + conv::significand_offset)) >> + f16_traits::significand_bits) + + 1; + + // get the tail length which will be rounding + const auto tail_len = gap_to_fp16 + conv::significand_offset; + + if (tail_len > f32_traits::significand_bits + 1) { + return conv::shift_sign(data_); + } + + // all significant (including implicitly leading 1) will be + // moved after representation field more than one digit (less + // than half) such that it will rounding to zero. + const auto explicit_significand = + (data_ & f32_traits::significand_mask) | + (1 << f32_traits::significand_bits); + + const auto tail = + explicit_significand & + static_cast((1 << tail_len) - 1); + + auto new_significand = explicit_significand >> tail_len; + + const auto result = + conv::shift_sign(data_) | exp | new_significand; + + const auto half = + static_cast(1 << (tail_len - 1)); + + return result + + (tail > half || ((tail == half) && (result & 1))); } else { // Rounding to even const auto result = conv::shift_sign(data_) | exp | @@ -442,8 +477,42 @@ class alignas(std::uint16_t) half { return conv::shift_sign(data_) | f32_traits::exponent_mask | f32_traits::significand_mask; } else if (f16_traits::is_denom(data_)) { - // TODO: handle denormals - return conv::shift_sign(data_); + if (!(data_ & f16_traits::significand_mask)) { + return conv::shift_sign(data_); + } + + int leading_zeros{}; + +// Counts leading zeros in the significand to determine the +// normalization shift +#if defined(_MSC_VER) + unsigned long index{}; + _BitScanReverse(&index, static_cast( + f16_traits::significand_mask & data_)); + + leading_zeros = f16_traits::significand_bits - index - 1; +#else + leading_zeros = __builtin_clz(static_cast( + f16_traits::significand_mask & data_)) - + f16_traits::exponent_bits - f16_traits::sign_bits - + CHAR_BIT * (sizeof(conv::result_bits) - + sizeof(conv::source_bits)); +#endif + + // Computes the new exponent, 0xxxxxxxx000...00 + auto new_exponent = + ((conv::bias_change >> f32_traits::significand_bits) - + leading_zeros) + << f32_traits::significand_bits; + + // Shifts the original significand to normalize it, remove the + // implicit '1', and align it in the new 23-bit field + auto new_significand = + (static_cast(data_) + << (conv::significand_offset + leading_zeros + 1)) & + f32_traits::significand_mask; + + return conv::shift_sign(data_) | new_exponent | new_significand; } else { return conv::shift_sign(data_) | conv::shift_exponent(data_) | conv::shift_significand(data_);