Skip to content

Commit 7000f21

Browse files
committed
Update some precision info on math intrinsics for Vulkan and Metal.
1 parent 58a6d7c commit 7000f21

File tree

1 file changed

+76
-31
lines changed

1 file changed

+76
-31
lines changed

src/FastMathFunctions.cpp

+76-31
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,33 @@
55
#include "IRMutator.h"
66
#include "IROperator.h"
77
#include "IRPrinter.h"
8+
#include "Util.h"
89

910
namespace Halide {
1011
namespace Internal {
11-
namespace ApproxImpl {
1212

13+
namespace {
1314
constexpr double PI = 3.14159265358979323846;
1415
constexpr double ONE_OVER_PI = 1.0 / PI;
1516
constexpr double TWO_OVER_PI = 2.0 / PI;
1617
constexpr double PI_OVER_TWO = PI / 2;
1718

19+
float ulp_to_ae(float max, int ulp) {
20+
internal_assert(max > 0.0);
21+
uint32_t n = reinterpret_bits<uint32_t>(max);
22+
float fn = reinterpret_bits<float>(n + ulp);
23+
return fn - max;
24+
}
25+
26+
uint32_t ae_to_ulp(float smallest, float ae) {
27+
internal_assert(smallest >= 0.0);
28+
float fn = smallest + ae;
29+
return reinterpret_bits<uint32_t>(fn) - reinterpret_bits<uint32_t>(smallest);
30+
}
31+
} // namespace
32+
33+
namespace ApproxImpl {
34+
1835
std::pair<float, float> split_float(double value) {
1936
float high = float(value); // Convert to single precision
2037
float low = float(value - double(high)); // Compute the residual part
@@ -152,7 +169,7 @@ Expr fast_sin(const Expr &x_full, ApproximationPrecision precision) {
152169
Expr k = cast<int>(k_real);
153170
Expr k_mod4 = k % 4; // Halide mod is always positive!
154171
Expr mirror = (k_mod4 == 1) || (k_mod4 == 3);
155-
Expr flip_sign = (k_mod4 > 1) ^ (x_full < 0);
172+
Expr flip_sign = (k_mod4 > 1) != (x_full < 0);
156173

157174
// Reduce the angle modulo pi/2: i.e., to the angle within the quadrant.
158175
Expr x = x_abs - k_real * make_const(type, PI_OVER_TWO);
@@ -417,7 +434,7 @@ Expr fast_tanh(const Expr &x, ApproximationPrecision prec) {
417434
Expr arg_exp = select(flip_exp, -abs_x, abs_x);
418435
Expr exp2xm1 = Halide::fast_expm1(2 * arg_exp, prec);
419436
Expr tanh = (exp2xm1) / (exp2xm1 + make_const(type, 2));
420-
tanh = select(flip_exp ^ flip_sign, -tanh, tanh);
437+
tanh = select(flip_exp != flip_sign, -tanh, tanh);
421438
return common_subexpression_elimination(tanh, true);
422439
#else
423440
// expm1 is developed around 0 and is ULP accurate in [-ln(2)/2, ln(2)/2].
@@ -465,6 +482,19 @@ struct IntrinsicsInfo {
465482
} intrinsic;
466483
};
467484

485+
/** Build a NativeFunc entry whose precision is specified as a maximum
 * absolute error. The matching ULP figure is derived with ae_to_ulp(),
 * evaluated at `smallest_output` (defaults to 0). */
IntrinsicsInfo::NativeFunc MAE_func(bool fast, float mae, float smallest_output = 0.0f) {
    const uint32_t derived_mulpe = ae_to_ulp(smallest_output, mae);
    return IntrinsicsInfo::NativeFunc{fast, OO::MAE, mae, derived_mulpe};
}
488+
/** Build a NativeFunc entry whose precision is specified as a maximum ULP
 * error. The matching absolute error is derived with ulp_to_ae(), evaluated
 * at `largest_output`. */
IntrinsicsInfo::NativeFunc MULPE_func(bool fast, uint64_t mulpe, float largest_output) {
    // NOTE(review): ulp_to_ae() takes an `int`, so `mulpe` is implicitly
    // narrowed from uint64_t here -- confirm callers only pass small values.
    const float derived_mae = ulp_to_ae(largest_output, mulpe);
    return IntrinsicsInfo::NativeFunc{fast, OO::MULPE, derived_mae, mulpe};
}
491+
/** Build an IntrinsicImpl entry whose precision is specified as a maximum
 * absolute error. The matching ULP figure is derived with ae_to_ulp(),
 * evaluated at `smallest_output` (defaults to 0). */
IntrinsicsInfo::IntrinsicImpl MAE_intrinsic(float mae, float smallest_output = 0.0f) {
    const uint32_t derived_mulpe = ae_to_ulp(smallest_output, mae);
    return IntrinsicsInfo::IntrinsicImpl{OO::MAE, mae, derived_mulpe};
}
494+
/** Build an IntrinsicImpl entry whose precision is specified as a maximum
 * ULP error. The matching absolute error is derived with ulp_to_ae(),
 * evaluated at `largest_output`. */
IntrinsicsInfo::IntrinsicImpl MULPE_intrinsic(uint64_t mulpe, float largest_output) {
    // NOTE(review): ulp_to_ae() takes an `int`, so `mulpe` is implicitly
    // narrowed from uint64_t here -- confirm callers only pass small values.
    const float derived_mae = ulp_to_ae(largest_output, mulpe);
    return IntrinsicsInfo::IntrinsicImpl{OO::MULPE, derived_mae, mulpe};
}
497+
468498
struct IntrinsicsInfoPerDeviceAPI {
469499
OO reasonable_behavior; // A reasonable optimization objective for a given function.
470500
float default_mae; // A reasonable desirable MAE (if specified)
@@ -475,37 +505,45 @@ struct IntrinsicsInfoPerDeviceAPI {
475505
// clang-format off
476506
IntrinsicsInfoPerDeviceAPI ii_sin{
477507
OO::MAE, 1e-5f, 0, {
478-
{DeviceAPI::Vulkan, {true}, {}},
479-
{DeviceAPI::CUDA, {false}, {OO::MAE, 5e-7f, 1'000'000}},
480-
{DeviceAPI::Metal, {true}, {OO::MAE, 6e-5f, 400'000}},
508+
{DeviceAPI::Vulkan, MAE_func(true, 5e-4f), {}},
509+
{DeviceAPI::CUDA, {false}, MAE_intrinsic(5e-7f)},
510+
{DeviceAPI::Metal, {true}, MAE_intrinsic(1.2e-4f)}, // 2^-13
481511
{DeviceAPI::WebGPU, {true}, {}},
482-
{DeviceAPI::OpenCL, {false}, {OO::MAE, 5e-7f, 1'000'000}},
512+
{DeviceAPI::OpenCL, {false}, MAE_intrinsic(5e-7f)},
483513
}};
484514

485515
IntrinsicsInfoPerDeviceAPI ii_cos{
486516
OO::MAE, 1e-5f, 0, {
487-
{DeviceAPI::Vulkan, {true}, {}},
488-
{DeviceAPI::CUDA, {false}, {OO::MAE, 5e-7f, 1'000'000}},
489-
{DeviceAPI::Metal, {true}, {OO::MAE, 7e-7f, 5'000}},
517+
{DeviceAPI::Vulkan, MAE_func(true, 5e-4f), {}},
518+
{DeviceAPI::CUDA, {false}, MAE_intrinsic(5e-7f)},
519+
{DeviceAPI::Metal, {true}, MAE_intrinsic(1.2e-4f)}, // Seems to be 7e-7, but spec says 2^-13...
490520
{DeviceAPI::WebGPU, {true}, {}},
491-
{DeviceAPI::OpenCL, {false}, {OO::MAE, 5e-7f, 1'000'000}},
521+
{DeviceAPI::OpenCL, {false}, MAE_intrinsic(5e-7f)},
492522
}};
493523

494-
IntrinsicsInfoPerDeviceAPI ii_atan_atan2{
524+
IntrinsicsInfoPerDeviceAPI ii_atan{
495525
OO::MAE, 1e-5f, 0, {
496526
// no intrinsics available
497527
{DeviceAPI::Vulkan, {false}, {}},
498-
{DeviceAPI::Metal, {true}, {OO::MAE, 5e-6f}},
528+
{DeviceAPI::Metal, {true}, MULPE_intrinsic(5, float(PI * 0.501))}, // They claim <= 5 ULP!
529+
{DeviceAPI::WebGPU, {true}, {}},
530+
}};
531+
532+
// Per-device-API precision table for fast_atan2.
IntrinsicsInfoPerDeviceAPI ii_atan2{
    OO::MAE, 1e-5f, 0, {
        // Vulkan: native function not flagged as fast, no intrinsic info recorded.
        {DeviceAPI::Vulkan, {false}, {}},
        // Metal: intrinsic precision recorded as an MAE of 5e-6 (measured from 0, the default).
        {DeviceAPI::Metal, {true}, MAE_intrinsic(5e-6f)},
        {DeviceAPI::WebGPU, {true}, {}},
    }};
501539

502540
IntrinsicsInfoPerDeviceAPI ii_tan{
503541
OO::MULPE, 0.0f, 2000, {
504-
{DeviceAPI::Vulkan, {true, OO::MAE, 2e-6f, 1'000'000}, {}}, // Vulkan tan seems to mimic our CUDA implementation
505-
{DeviceAPI::CUDA, {false}, {OO::MAE, 2e-6f, 1'000'000}},
506-
{DeviceAPI::Metal, {true}, {OO::MULPE, 2e-6f, 1'000'000}},
542+
{DeviceAPI::Vulkan, MAE_func(true, 2e-6f), {}}, // Vulkan tan() seems to mimic our CUDA implementation
543+
{DeviceAPI::CUDA, {false}, MAE_intrinsic(2e-6f)},
544+
{DeviceAPI::Metal, {true}, MAE_intrinsic(2e-6f)}, // sin()/cos()
507545
{DeviceAPI::WebGPU, {true}, {}},
508-
{DeviceAPI::OpenCL, {false}, {OO::MAE, 2e-6f, 1'000'000}},
546+
{DeviceAPI::OpenCL, {false}, MAE_intrinsic(2e-6f)},
509547
}};
510548

511549
IntrinsicsInfoPerDeviceAPI ii_expm1{
@@ -514,16 +552,16 @@ IntrinsicsInfoPerDeviceAPI ii_expm1{
514552

515553
IntrinsicsInfoPerDeviceAPI ii_exp{
516554
OO::MULPE, 0.0f, 50, {
517-
{DeviceAPI::Vulkan, {true}, {}},
518-
{DeviceAPI::CUDA, {false}, {OO::MULPE, 0.0f, 5}},
519-
{DeviceAPI::Metal, {true}, {OO::MULPE, 0.0f, 5}}, // precise::exp() is fast on metal
555+
{DeviceAPI::Vulkan, MULPE_func(true, 3 + 2 * 2, 2.0f), {}},
556+
{DeviceAPI::CUDA, {false}, MULPE_intrinsic(5, 2.0f)},
557+
{DeviceAPI::Metal, {true}, MULPE_intrinsic(5, 2.0f)}, // precise::exp() is fast on metal
520558
{DeviceAPI::WebGPU, {true}, {}},
521-
{DeviceAPI::OpenCL, {true}, {OO::MULPE, 0.0f, 5}}, // Both exp() and native_exp() are faster than polys.
559+
{DeviceAPI::OpenCL, {true}, MULPE_intrinsic(5, 2.0f)}, // Both exp() and native_exp() are faster than polys.
522560
}};
523561

524562
IntrinsicsInfoPerDeviceAPI ii_log{
525563
OO::MAE, 1e-5f, 1000, {
526-
{DeviceAPI::Vulkan, {true}, {}},
564+
{DeviceAPI::Vulkan, {true, ApproximationPrecision::MULPE, 5e-7f, 3}, {}}, // Precision piecewise defined: 3 ULP outside the range [0.5,2.0]. Absolute error < 2^−21 inside the range [0.5,2.0].
527565
{DeviceAPI::CUDA, {false}, {OO::MAE, 0.0f, 3'800'000}},
528566
{DeviceAPI::Metal, {false}, {OO::MAE, 0.0f, 3'800'000}}, // slow log() on metal
529567
{DeviceAPI::WebGPU, {true}, {}},
@@ -551,6 +589,7 @@ IntrinsicsInfoPerDeviceAPI ii_asin_acos{
551589
OO::MULPE, 1e-5f, 500, {
552590
{DeviceAPI::Vulkan, {true}, {}},
553591
{DeviceAPI::CUDA, {true}, {}},
592+
{DeviceAPI::Metal, {true}, MULPE_intrinsic(5, PI)},
554593
{DeviceAPI::OpenCL, {true}, {}},
555594
}};
556595
// clang-format on
@@ -559,8 +598,10 @@ bool fast_math_func_has_intrinsic_based_implementation(Call::IntrinsicOp op, Dev
559598
const IntrinsicsInfoPerDeviceAPI *iipda = nullptr;
560599
switch (op) {
561600
case Call::fast_atan:
601+
iipda = &ii_atan;
602+
break;
562603
case Call::fast_atan2:
563-
iipda = &ii_atan_atan2;
604+
iipda = &ii_atan2;
564605
break;
565606
case Call::fast_cos:
566607
iipda = &ii_cos;
@@ -858,20 +899,24 @@ class LowerFastMathFunctions : public IRMutator {
858899

859900
// No known fast version available, we will expand our own approximation.
860901
return ApproxImpl::fast_cos(mutate(op->args[0]), prec);
861-
} else if (op->is_intrinsic(Call::fast_atan) || op->is_intrinsic(Call::fast_atan2)) {
902+
} else if (op->is_intrinsic(Call::fast_atan)) {
862903
// Handle fast_atan (fast_atan2 has its own branch below).
863904
ApproximationPrecision prec = extract_approximation_precision(op);
864-
IntrinsicsInfo ii = resolve_precision(prec, ii_atan_atan2, for_device_api);
905+
IntrinsicsInfo ii = resolve_precision(prec, ii_atan, for_device_api);
865906
if (ii.native_func.is_fast && native_func_satisfies_precision(ii, prec)) {
866907
// The native atan is fast: fall back to native and continue lowering.
867908
return to_native_func(op);
868909
}
869-
870-
if (op->is_intrinsic(Call::fast_atan)) {
871-
return ApproxImpl::fast_atan(mutate(op->args[0]), prec);
872-
} else {
873-
return ApproxImpl::fast_atan2(mutate(op->args[0]), mutate(op->args[1]), prec);
910+
return ApproxImpl::fast_atan(mutate(op->args[0]), prec);
911+
} else if (op->is_intrinsic(Call::fast_atan2)) {
912+
// Handle fast_atan2 (fast_atan has its own branch above).
913+
ApproximationPrecision prec = extract_approximation_precision(op);
914+
IntrinsicsInfo ii = resolve_precision(prec, ii_atan2, for_device_api);
915+
if (ii.native_func.is_fast && native_func_satisfies_precision(ii, prec)) {
916+
// The native atan2 is fast: fall back to native and continue lowering.
917+
return to_native_func(op);
874918
}
919+
return ApproxImpl::fast_atan2(mutate(op->args[0]), mutate(op->args[1]), prec);
875920
} else if (op->is_intrinsic(Call::fast_tan)) {
876921
ApproximationPrecision prec = extract_approximation_precision(op);
877922
IntrinsicsInfo ii = resolve_precision(prec, ii_tan, for_device_api);
@@ -913,7 +958,7 @@ class LowerFastMathFunctions : public IRMutator {
913958
return append_type_suffix(op);
914959
}
915960
if (ii.native_func.is_fast && native_func_satisfies_precision(ii, prec)) {
916-
// The native atan is fast: fall back to native and continue lowering.
961+
// The native exp is fast: fall back to native and continue lowering.
917962
return to_native_func(op);
918963
}
919964

0 commit comments

Comments
 (0)