5
5
#include " IRMutator.h"
6
6
#include " IROperator.h"
7
7
#include " IRPrinter.h"
8
+ #include " Util.h"
8
9
9
10
namespace Halide {
10
11
namespace Internal {
11
- namespace ApproxImpl {
12
12
13
+ namespace {
13
14
constexpr double PI = 3.14159265358979323846 ;
14
15
constexpr double ONE_OVER_PI = 1.0 / PI;
15
16
constexpr double TWO_OVER_PI = 2.0 / PI;
16
17
constexpr double PI_OVER_TWO = PI / 2 ;
17
18
19
+ float ulp_to_ae (float max, int ulp) {
20
+ internal_assert (max > 0.0 );
21
+ uint32_t n = reinterpret_bits<uint32_t >(max);
22
+ float fn = reinterpret_bits<float >(n + ulp);
23
+ return fn - max;
24
+ }
25
+
26
+ uint32_t ae_to_ulp (float smallest, float ae) {
27
+ internal_assert (smallest >= 0.0 );
28
+ float fn = smallest + ae;
29
+ return reinterpret_bits<uint32_t >(fn) - reinterpret_bits<uint32_t >(smallest);
30
+ }
31
+ } // namespace
32
+
33
+ namespace ApproxImpl {
34
+
18
35
std::pair<float , float > split_float (double value) {
19
36
float high = float (value); // Convert to single precision
20
37
float low = float (value - double (high)); // Compute the residual part
@@ -152,7 +169,7 @@ Expr fast_sin(const Expr &x_full, ApproximationPrecision precision) {
152
169
Expr k = cast<int >(k_real);
153
170
Expr k_mod4 = k % 4 ; // Halide mod is always positive!
154
171
Expr mirror = (k_mod4 == 1 ) || (k_mod4 == 3 );
155
- Expr flip_sign = (k_mod4 > 1 ) ^ (x_full < 0 );
172
+ Expr flip_sign = (k_mod4 > 1 ) != (x_full < 0 );
156
173
157
174
// Reduce the angle modulo pi/2: i.e., to the angle within the quadrant.
158
175
Expr x = x_abs - k_real * make_const (type, PI_OVER_TWO);
@@ -417,7 +434,7 @@ Expr fast_tanh(const Expr &x, ApproximationPrecision prec) {
417
434
Expr arg_exp = select(flip_exp, -abs_x, abs_x);
418
435
Expr exp2xm1 = Halide::fast_expm1(2 * arg_exp, prec);
419
436
Expr tanh = (exp2xm1) / (exp2xm1 + make_const(type, 2));
420
- tanh = select(flip_exp ^ flip_sign, -tanh, tanh);
437
+ tanh = select(flip_exp != flip_sign, -tanh, tanh);
421
438
return common_subexpression_elimination(tanh, true);
422
439
#else
423
440
// expm1 is devloped around 0 and is ULP accurate in [-ln(2)/2, ln(2)/2].
@@ -465,6 +482,19 @@ struct IntrinsicsInfo {
465
482
} intrinsic;
466
483
};
467
484
485
+ IntrinsicsInfo::NativeFunc MAE_func (bool fast, float mae, float smallest_output = 0 .0f ) {
486
+ return IntrinsicsInfo::NativeFunc{fast, OO::MAE, mae, ae_to_ulp (smallest_output, mae)};
487
+ }
488
+ IntrinsicsInfo::NativeFunc MULPE_func (bool fast, uint64_t mulpe, float largest_output) {
489
+ return IntrinsicsInfo::NativeFunc{fast, OO::MULPE, ulp_to_ae (largest_output, mulpe), mulpe};
490
+ }
491
+ IntrinsicsInfo::IntrinsicImpl MAE_intrinsic (float mae, float smallest_output = 0 .0f ) {
492
+ return IntrinsicsInfo::IntrinsicImpl{OO::MAE, mae, ae_to_ulp (smallest_output, mae)};
493
+ }
494
+ IntrinsicsInfo::IntrinsicImpl MULPE_intrinsic (uint64_t mulpe, float largest_output) {
495
+ return IntrinsicsInfo::IntrinsicImpl{OO::MULPE, ulp_to_ae (largest_output, mulpe), mulpe};
496
+ }
497
+
468
498
struct IntrinsicsInfoPerDeviceAPI {
469
499
OO reasonable_behavior; // A reasonable optimization objective for a given function.
470
500
float default_mae; // A reasonable desirable MAE (if specified)
@@ -475,37 +505,45 @@ struct IntrinsicsInfoPerDeviceAPI {
475
505
// clang-format off
476
506
IntrinsicsInfoPerDeviceAPI ii_sin{
477
507
OO::MAE, 1e-5f , 0 , {
478
- {DeviceAPI::Vulkan, { true } , {}},
479
- {DeviceAPI::CUDA, {false }, {OO::MAE, 5e-7f , 1'000'000 } },
480
- {DeviceAPI::Metal, {true }, {OO::MAE, 6e- 5f , 400'000 }},
508
+ {DeviceAPI::Vulkan, MAE_func ( true , 5e- 4f ) , {}},
509
+ {DeviceAPI::CUDA, {false }, MAE_intrinsic ( 5e-7f ) },
510
+ {DeviceAPI::Metal, {true }, MAE_intrinsic ( 1 .2e- 4f )}, // 2^-13
481
511
{DeviceAPI::WebGPU, {true }, {}},
482
- {DeviceAPI::OpenCL, {false }, {OO::MAE, 5e-7f , 1'000'000 } },
512
+ {DeviceAPI::OpenCL, {false }, MAE_intrinsic ( 5e-7f ) },
483
513
}};
484
514
485
515
IntrinsicsInfoPerDeviceAPI ii_cos{
486
516
OO::MAE, 1e-5f , 0 , {
487
- {DeviceAPI::Vulkan, { true } , {}},
488
- {DeviceAPI::CUDA, {false }, {OO::MAE, 5e-7f , 1'000'000 } },
489
- {DeviceAPI::Metal, {true }, {OO::MAE, 7e-7f , 5'000 }},
517
+ {DeviceAPI::Vulkan, MAE_func ( true , 5e- 4f ) , {}},
518
+ {DeviceAPI::CUDA, {false }, MAE_intrinsic ( 5e-7f ) },
519
+ {DeviceAPI::Metal, {true }, MAE_intrinsic ( 1 .2e- 4f )}, // Seems to be 7e-7, but spec says 2^-13...
490
520
{DeviceAPI::WebGPU, {true }, {}},
491
- {DeviceAPI::OpenCL, {false }, {OO::MAE, 5e-7f , 1'000'000 } },
521
+ {DeviceAPI::OpenCL, {false }, MAE_intrinsic ( 5e-7f ) },
492
522
}};
493
523
494
- IntrinsicsInfoPerDeviceAPI ii_atan_atan2 {
524
+ IntrinsicsInfoPerDeviceAPI ii_atan {
495
525
OO::MAE, 1e-5f , 0 , {
496
526
// no intrinsics available
497
527
{DeviceAPI::Vulkan, {false }, {}},
498
- {DeviceAPI::Metal, {true }, {OO::MAE, 5e-6f }},
528
+ {DeviceAPI::Metal, {true }, MULPE_intrinsic (5 , float (PI * 0.501 ))}, // They claim <= 5 ULP!
529
+ {DeviceAPI::WebGPU, {true }, {}},
530
+ }};
531
+
532
+ IntrinsicsInfoPerDeviceAPI ii_atan2{
533
+ OO::MAE, 1e-5f , 0 , {
534
+ // no intrinsics available
535
+ {DeviceAPI::Vulkan, {false }, {}},
536
+ {DeviceAPI::Metal, {true }, MAE_intrinsic (5e-6f , 0 .0f )},
499
537
{DeviceAPI::WebGPU, {true }, {}},
500
538
}};
501
539
502
540
IntrinsicsInfoPerDeviceAPI ii_tan{
503
541
OO::MULPE, 0 .0f , 2000 , {
504
- {DeviceAPI::Vulkan, { true , OO::MAE, 2e-6f , 1'000'000 }, {}}, // Vulkan tan seems to mimic our CUDA implementation
505
- {DeviceAPI::CUDA, {false }, {OO::MAE, 2e-6f , 1'000'000 } },
506
- {DeviceAPI::Metal, {true }, {OO::MULPE, 2e-6f , 1'000'000 }},
542
+ {DeviceAPI::Vulkan, MAE_func ( true , 2e-6f ), {}}, // Vulkan tan() seems to mimic our CUDA implementation
543
+ {DeviceAPI::CUDA, {false }, MAE_intrinsic ( 2e-6f ) },
544
+ {DeviceAPI::Metal, {true }, MAE_intrinsic ( 2e-6f )}, // sin()/cos()
507
545
{DeviceAPI::WebGPU, {true }, {}},
508
- {DeviceAPI::OpenCL, {false }, {OO::MAE, 2e-6f , 1'000'000 } },
546
+ {DeviceAPI::OpenCL, {false }, MAE_intrinsic ( 2e-6f ) },
509
547
}};
510
548
511
549
IntrinsicsInfoPerDeviceAPI ii_expm1{
@@ -514,16 +552,16 @@ IntrinsicsInfoPerDeviceAPI ii_expm1{
514
552
515
553
IntrinsicsInfoPerDeviceAPI ii_exp{
516
554
OO::MULPE, 0 .0f , 50 , {
517
- {DeviceAPI::Vulkan, { true } , {}},
518
- {DeviceAPI::CUDA, {false }, {OO::MULPE, 0 .0f , 5 } },
519
- {DeviceAPI::Metal, {true }, {OO::MULPE, 0 .0f , 5 } }, // precise::exp() is fast on metal
555
+ {DeviceAPI::Vulkan, MULPE_func ( true , 3 + 2 * 2 , 2 . 0f ) , {}},
556
+ {DeviceAPI::CUDA, {false }, MULPE_intrinsic ( 5 , 2 .0f ) },
557
+ {DeviceAPI::Metal, {true }, MULPE_intrinsic ( 5 , 2 .0f ) }, // precise::exp() is fast on metal
520
558
{DeviceAPI::WebGPU, {true }, {}},
521
- {DeviceAPI::OpenCL, {true }, {OO::MULPE, 0 .0f , 5 } }, // Both exp() and native_exp() are faster than polys.
559
+ {DeviceAPI::OpenCL, {true }, MULPE_intrinsic ( 5 , 2 .0f ) }, // Both exp() and native_exp() are faster than polys.
522
560
}};
523
561
524
562
IntrinsicsInfoPerDeviceAPI ii_log{
525
563
OO::MAE, 1e-5f , 1000 , {
526
- {DeviceAPI::Vulkan, {true }, {}},
564
+ {DeviceAPI::Vulkan, {true , ApproximationPrecision::MULPE, 5e- 7f , 3 }, {}}, // Precision piecewise defined: 3 ULP outside the range [0.5,2.0]. Absolute error < 2^−21 inside the range [0.5,2.0].
527
565
{DeviceAPI::CUDA, {false }, {OO::MAE, 0 .0f , 3'800'000 }},
528
566
{DeviceAPI::Metal, {false }, {OO::MAE, 0 .0f , 3'800'000 }}, // slow log() on metal
529
567
{DeviceAPI::WebGPU, {true }, {}},
@@ -551,6 +589,7 @@ IntrinsicsInfoPerDeviceAPI ii_asin_acos{
551
589
OO::MULPE, 1e-5f , 500 , {
552
590
{DeviceAPI::Vulkan, {true }, {}},
553
591
{DeviceAPI::CUDA, {true }, {}},
592
+ {DeviceAPI::Metal, {true }, MULPE_intrinsic (5 , PI)},
554
593
{DeviceAPI::OpenCL, {true }, {}},
555
594
}};
556
595
// clang-format on
@@ -559,8 +598,10 @@ bool fast_math_func_has_intrinsic_based_implementation(Call::IntrinsicOp op, Dev
559
598
const IntrinsicsInfoPerDeviceAPI *iipda = nullptr ;
560
599
switch (op) {
561
600
case Call::fast_atan:
601
+ iipda = &ii_atan;
602
+ break ;
562
603
case Call::fast_atan2:
563
- iipda = &ii_atan_atan2 ;
604
+ iipda = &ii_atan2 ;
564
605
break ;
565
606
case Call::fast_cos:
566
607
iipda = &ii_cos;
@@ -858,20 +899,24 @@ class LowerFastMathFunctions : public IRMutator {
858
899
859
900
// No known fast version available, we will expand our own approximation.
860
901
return ApproxImpl::fast_cos (mutate (op->args [0 ]), prec);
861
- } else if (op->is_intrinsic (Call::fast_atan) || op-> is_intrinsic (Call::fast_atan2) ) {
902
+ } else if (op->is_intrinsic (Call::fast_atan)) {
862
903
// Handle fast_atan and fast_atan2 together!
863
904
ApproximationPrecision prec = extract_approximation_precision (op);
864
- IntrinsicsInfo ii = resolve_precision (prec, ii_atan_atan2 , for_device_api);
905
+ IntrinsicsInfo ii = resolve_precision (prec, ii_atan , for_device_api);
865
906
if (ii.native_func .is_fast && native_func_satisfies_precision (ii, prec)) {
866
907
// The native atan is fast: fall back to native and continue lowering.
867
908
return to_native_func (op);
868
909
}
869
-
870
- if (op->is_intrinsic (Call::fast_atan)) {
871
- return ApproxImpl::fast_atan (mutate (op->args [0 ]), prec);
872
- } else {
873
- return ApproxImpl::fast_atan2 (mutate (op->args [0 ]), mutate (op->args [1 ]), prec);
910
+ return ApproxImpl::fast_atan (mutate (op->args [0 ]), prec);
911
+ } else if (op->is_intrinsic (Call::fast_atan2)) {
912
+ // Handle fast_atan and fast_atan2 together!
913
+ ApproximationPrecision prec = extract_approximation_precision (op);
914
+ IntrinsicsInfo ii = resolve_precision (prec, ii_atan2, for_device_api);
915
+ if (ii.native_func .is_fast && native_func_satisfies_precision (ii, prec)) {
916
+ // The native atan2 is fast: fall back to native and continue lowering.
917
+ return to_native_func (op);
874
918
}
919
+ return ApproxImpl::fast_atan2 (mutate (op->args [0 ]), mutate (op->args [1 ]), prec);
875
920
} else if (op->is_intrinsic (Call::fast_tan)) {
876
921
ApproximationPrecision prec = extract_approximation_precision (op);
877
922
IntrinsicsInfo ii = resolve_precision (prec, ii_tan, for_device_api);
@@ -913,7 +958,7 @@ class LowerFastMathFunctions : public IRMutator {
913
958
return append_type_suffix (op);
914
959
}
915
960
if (ii.native_func .is_fast && native_func_satisfies_precision (ii, prec)) {
916
- // The native atan is fast: fall back to native and continue lowering.
961
+ // The native exp is fast: fall back to native and continue lowering.
917
962
return to_native_func (op);
918
963
}
919
964
0 commit comments