@@ -343,8 +343,35 @@ Expr fast_exp(const Expr &x_full, ApproximationPrecision prec) {
 
     // Shift the bits up into the exponent field and reinterpret this
     // thing as float.
-    Expr two_to_the_n = reinterpret<float>(biased << 23);
-    result *= two_to_the_n;
+    Expr two_to_the_k = reinterpret<float>(biased << 23);
+    result *= two_to_the_k;
+    result = common_subexpression_elimination(result, true);
+    return result;
+}
+
+Expr fast_expm1(const Expr &x_full, ApproximationPrecision prec) {
+    Type type = x_full.type();
+    user_assert(x_full.type() == Float(32)) << "fast_expm1 only works for Float(32)";
+
+    Expr log2 = make_const(type, std::log(2.0));
+
+    Expr scaled = x_full / log2;
+    Expr k_real = round(scaled);  // Here we round instead of floor, to reduce to [-log(2)/2, log(2)/2].
+    Expr k = cast<int>(k_real);
+    Expr x = x_full - k_real * log2;
+
+    const Internal::Approximation *approx = Internal::ApproximationTables::best_expm1_approximation(prec, type);
+    Expr result = eval_approx(approx, x);
+
+    // Compute 2^k.
+    int fpbias = 127;
+    Expr biased = clamp(k + fpbias, 0, 255);
+
+    // Shift the bits up into the exponent field and reinterpret this
+    // thing as float.
+    Expr two_to_the_k = reinterpret<float>(biased << 23);
+
+    result = select(k == 0, result, (result + 1) * two_to_the_k - 1);
     result = common_subexpression_elimination(result, true);
     return result;
 }
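
For reference, the reconstruction above uses the identity expm1(x) = 2^k * (expm1(r) + 1) - 1 for x = k*ln(2) + r. The following is a minimal scalar sketch (plain C++, not Halide IR; the name `sketch_expm1` is illustrative), with `std::expm1` on the reduced argument standing in for the fitted polynomial that `eval_approx` evaluates:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstring>

float sketch_expm1(float x) {
    const float ln2 = std::log(2.0f);
    // Round (not floor) so the reduced argument r lands in [-ln(2)/2, ln(2)/2],
    // where the polynomial approximation is accurate.
    float k_real = std::round(x / ln2);
    int k = static_cast<int>(k_real);
    float r = x - k_real * ln2;

    float p = std::expm1(r);  // stand-in for eval_approx(approx, x)

    // Build 2^k by writing the biased exponent straight into the IEEE-754
    // exponent field, which is what reinterpret<float>(biased << 23) does above.
    uint32_t biased = static_cast<uint32_t>(std::clamp(k + 127, 0, 255));
    uint32_t bits = biased << 23;
    float two_to_the_k;
    std::memcpy(&two_to_the_k, &bits, sizeof(bits));

    // expm1(x) = 2^k * (expm1(r) + 1) - 1. The k == 0 branch is kept
    // separate so tiny arguments keep the full precision of p.
    return k == 0 ? p : (p + 1.0f) * two_to_the_k - 1.0f;
}
```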
@@ -370,26 +397,37 @@ Expr fast_tanh(const Expr &x, ApproximationPrecision prec) {
     // Rewrite with definition:
     // tanh(x) = (exp(2x) - 1) / (exp(2x) + 1)
     //         = (1 - exp(-2x)) / (1 + exp(-2x))
+    //         = expm1(2x) / (expm1(2x) + 2)
     // But take abs(x) as the argument, and flip the sign when negative.
     Type type = x.type();
     Expr abs_x = abs(x);
     Expr flip_sign = x < 0;
     if (prec.optimized_for == ApproximationPrecision::MULPE) {
+#if 0
         // Positive arguments to exp() have better ULP precision.
         // So, we will rewrite the expression to always use exp(2*x)
         // instead of exp(-2*x) when we are close to zero.
         // Rewriting it like this is slightly more expensive, hence the branch
         // to only pay this extra cost in case we need MULPE-optimized approximations.
         Expr flip_exp = abs_x > make_const(type, 4);
         Expr arg_exp = select(flip_exp, -abs_x, abs_x);
-        Expr exp2x = Halide::fast_exp(2 * arg_exp, prec);
-        Expr tanh = (exp2x - make_const(type, 1.0)) / (exp2x + make_const(type, 1));
+        Expr exp2xm1 = Halide::fast_expm1(2 * arg_exp, prec);
+        Expr tanh = exp2xm1 / (exp2xm1 + make_const(type, 2));
         tanh = select(flip_exp ^ flip_sign, -tanh, tanh);
         return common_subexpression_elimination(tanh, true);
+#else
+        // expm1 is developed around 0 and is ULP accurate in [-ln(2)/2, ln(2)/2].
+        Expr exp2xm1 = Halide::fast_expm1(-2 * abs_x, prec);
+        Expr tanh = exp2xm1 / (exp2xm1 + make_const(type, 2));
+        tanh = select(flip_sign, tanh, -tanh);
+        return common_subexpression_elimination(tanh, true);
+#endif
     } else {
         // Even if we are optimizing for MAE, the nested call to exp()
         // should be MULPE optimized for accuracy, as we are taking ratios.
-        prec.optimized_for = ApproximationPrecision::MULPE;
+        if (prec.optimized_for == ApproximationPrecision::MAE) {
+            prec.optimized_for = ApproximationPrecision::MULPE;
+        }  // else it's on AUTO, and we want to keep that (AUTO tanh uses AUTO exp).
         Expr exp2x = Halide::fast_exp(-2 * abs_x, prec);
         Expr tanh = (make_const(type, 1) - exp2x) / (make_const(type, 1) + exp2x);
         tanh = select(flip_sign, -tanh, tanh);
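
The rewrite matters numerically: substituting exp(2x) = expm1(2x) + 1 into (exp(2x) - 1) / (exp(2x) + 1) gives expm1(2x) / (expm1(2x) + 2), which avoids the catastrophic cancellation of exp(2x) - 1 near zero. A small self-contained check, using the standard library in place of Halide's fast_* approximations (illustrative only):

```cpp
#include <cmath>
#include <cstdio>

int main() {
    for (float x : {1e-8f, 1e-4f, 0.5f}) {
        float e = std::exp(2 * x);
        float via_exp = (e - 1) / (e + 1);  // catastrophic cancellation for small x
        float m = std::expm1(2 * x);
        float via_expm1 = m / (m + 2);      // no cancellation near zero
        std::printf("x=%g  exp form=%.9g  expm1 form=%.9g  std::tanh=%.9g\n",
                    x, via_exp, via_expm1, std::tanh(x));
    }
}
```

At x = 1e-8 the exp form evaluates to 0 in float32, because exp(2e-8f) rounds to 1.0f, while the expm1 form returns roughly 1e-8, in agreement with std::tanh.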
@@ -466,6 +504,10 @@ IntrinsicsInfoPerDeviceAPI ii_tan{
         {DeviceAPI::OpenCL, {false}, {OO::MAE, 2e-6f, 1'000'000}},
 }};
 
+IntrinsicsInfoPerDeviceAPI ii_expm1{
+    OO::MULPE, 0.0f, 50, {/* No intrinsics on any backend. */
+    }};
+
 IntrinsicsInfoPerDeviceAPI ii_exp{
     OO::MULPE, 0.0f, 50, {
         {DeviceAPI::Vulkan, {true}, {}},
@@ -478,10 +520,10 @@ IntrinsicsInfoPerDeviceAPI ii_exp{
 IntrinsicsInfoPerDeviceAPI ii_log{
     OO::MAE, 1e-5f, 1000, {
         {DeviceAPI::Vulkan, {true}, {}},
-        {DeviceAPI::CUDA, {false}, {OO::MULPE, 0.0f, 3'800'000}},
+        {DeviceAPI::CUDA, {false}, {OO::MAE, 0.0f, 3'800'000}},
         {DeviceAPI::Metal, {false}, {OO::MAE, 0.0f, 3'800'000}},  // slow log() on metal
         {DeviceAPI::WebGPU, {true}, {}},
-        {DeviceAPI::OpenCL, {true}, {OO::MULPE, 0.0f, 3'800'000}},
+        {DeviceAPI::OpenCL, {true}, {OO::MAE, 0.0f, 3'800'000}},
 }};
 
 IntrinsicsInfoPerDeviceAPI ii_pow{
@@ -519,6 +561,9 @@ bool fast_math_func_has_intrinsic_based_implementation(Call::IntrinsicOp op, Dev
     case Call::fast_cos:
         iipda = &ii_cos;
         break;
+    case Call::fast_expm1:
+        iipda = &ii_expm1;
+        break;
     case Call::fast_exp:
         iipda = &ii_exp;
         break;
@@ -563,14 +608,17 @@ bool fast_math_func_has_intrinsic_based_implementation(Call::IntrinsicOp op, Dev
     return false;
 }
 
-IntrinsicsInfo resolve_precision(ApproximationPrecision &prec, const IntrinsicsInfoPerDeviceAPI &iida, DeviceAPI api) {
-    IntrinsicsInfo ii{};
+IntrinsicsInfo find_intrinsics_info_for_device_api(const IntrinsicsInfoPerDeviceAPI &iida, DeviceAPI api) {
     for (const auto &cand : iida.device_apis) {
         if (cand.device_api == api) {
-            ii = cand;
-            break;
+            return cand;
         }
     }
+    return {};
+}
+
+IntrinsicsInfo resolve_precision(ApproximationPrecision &prec, const IntrinsicsInfoPerDeviceAPI &iida, DeviceAPI api) {
+    IntrinsicsInfo ii = find_intrinsics_info_for_device_api(iida, api);
 
     if (prec.optimized_for == ApproximationPrecision::AUTO) {
         if (!ii.intrinsic.defined()) {
@@ -690,18 +738,6 @@ class LowerFastMathFunctions : public IRMutator {
         return for_device_api == DeviceAPI::CUDA && target.get_cuda_capability_lower_bound() >= 75;
     }
 
-    void adjust_precision_for_target(ApproximationPrecision &prec) {
-        if (for_device_api == DeviceAPI::None) {
-            if (target.arch == Target::Arch::X86) {
-                // If we do not have fused-multiply-add, we lose some precision.
-                if (target.bits == 32 || !target.has_feature(Target::Feature::FMA)) {
-                    prec.constraint_max_absolute_error *= 0.5f;
-                    prec.constraint_max_ulp_error /= 2;
-                }
-            }
-        }
-    }
-
     /** Strips the fast_ prefix, appends the type suffix, and
      * drops the precision argument from the end. */
     Expr to_native_func(const Call *op) {
@@ -720,7 +756,7 @@ class LowerFastMathFunctions : public IRMutator {
         std::vector<Expr> args;
         for (size_t i = 0; i < op->args.size() - 1; ++i) {
             const Expr &arg = op->args[i];
-            args.push_back(IRMutator::mutate(arg));
+            args.push_back(mutate(arg));
         }
         return Call::make(op->type, new_name, args, Call::PureExtern);
     }
@@ -738,7 +774,7 @@ class LowerFastMathFunctions : public IRMutator {
         std::vector<Expr> args;
         for (size_t i = 0; i < op->args.size() - 1; ++i) {
             const Expr &arg = op->args[i];
-            args.push_back(IRMutator::mutate(arg));
+            args.push_back(mutate(arg));
         }
         return Call::make(op->type, new_name, args, Call::PureExtern);
     }
@@ -792,7 +828,6 @@ class LowerFastMathFunctions : public IRMutator {
            }
 
            // No known fast version available, we will expand our own approximation.
-            adjust_precision_for_target(prec);
            return ApproxImpl::fast_sin(mutate(op->args[0]), prec);
        } else if (op->is_intrinsic(Call::fast_cos)) {
            ApproximationPrecision prec = extract_approximation_precision(op);
@@ -805,7 +840,6 @@ class LowerFastMathFunctions : public IRMutator {
            }
 
            // No known fast version available, we will expand our own approximation.
-            adjust_precision_for_target(prec);
            return ApproxImpl::fast_cos(mutate(op->args[0]), prec);
        } else if (op->is_intrinsic(Call::fast_atan) || op->is_intrinsic(Call::fast_atan2)) {
            // Handle fast_atan and fast_atan2 together!
@@ -816,7 +850,6 @@ class LowerFastMathFunctions : public IRMutator {
                return to_native_func(op);
            }
 
-            adjust_precision_for_target(prec);
            if (op->is_intrinsic(Call::fast_atan)) {
                return ApproxImpl::fast_atan(mutate(op->args[0]), prec);
            } else {
@@ -841,10 +874,12 @@ class LowerFastMathFunctions : public IRMutator {
                return to_native_func(op);
            }
 
-            adjust_precision_for_target(prec);
            return ApproxImpl::fast_tan(mutate(op->args[0]), prec);
+        } else if (op->is_intrinsic(Call::fast_expm1)) {
+            ApproximationPrecision prec = extract_approximation_precision(op);
+            resolve_precision(prec, ii_expm1, for_device_api);
+            return ApproxImpl::fast_expm1(mutate(op->args[0]), prec);
        } else if (op->is_intrinsic(Call::fast_exp)) {
-            // Handle fast_exp and fast_log together!
            ApproximationPrecision prec = extract_approximation_precision(op);
            IntrinsicsInfo ii = resolve_precision(prec, ii_exp, for_device_api);
            if (op->type == Float(32) && is_cuda_cc20() && intrinsic_satisfies_precision(ii, prec)) {
@@ -865,7 +900,6 @@ class LowerFastMathFunctions : public IRMutator {
                return to_native_func(op);
            }
 
-            adjust_precision_for_target(prec);
            return ApproxImpl::fast_exp(mutate(op->args[0]), prec);
        } else if (op->is_intrinsic(Call::fast_log)) {
            // Handle fast_exp and fast_log together!
@@ -887,10 +921,24 @@ class LowerFastMathFunctions : public IRMutator {
                return to_native_func(op);
            }
 
-            adjust_precision_for_target(prec);
            return ApproxImpl::fast_log(mutate(op->args[0]), prec);
        } else if (op->is_intrinsic(Call::fast_tanh)) {
            ApproximationPrecision prec = extract_approximation_precision(op);
+            // Special treatment for tanh(): on CUDA, tanh() can be rewritten to exp(), but
+            // that behaves like MAE instead of MULPE. MULPE is the default behavior of the
+            // tanh.approx.f32 intrinsic, so resolve_precision() would set the objective to
+            // MULPE to be able to use that intrinsic, but the intrinsic requires CC7.5.
+            // So we first check whether we are on CC < 7.5 with the objective on AUTO and
+            // no precision requirements; if so, we keep AUTO and rewrite immediately.
+            if (op->type == Float(32) && is_cuda_cc20() && !is_cuda_cc75()) {
+                if (prec.optimized_for == ApproximationPrecision::AUTO &&
+                    prec.constraint_max_absolute_error == 0 &&
+                    prec.constraint_max_ulp_error == 0 &&
+                    prec.force_halide_polynomial == 0) {
+                    return mutate(ApproxImpl::fast_tanh(op->args[0], prec));
+                }
+            }
+            // Now we know we're not in that case; proceed as usual.
            IntrinsicsInfo ii = resolve_precision(prec, ii_tanh, for_device_api);
            // We have a fast version on PTX with CC7.5
            if (op->type == Float(32) && is_cuda_cc75() && intrinsic_satisfies_precision(ii, prec)) {
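
The four clauses in the inner condition above together encode "AUTO objective with no precision requirements". If one wanted to name that predicate, it could be factored into a small helper along these lines (hypothetical sketch, not part of this change):

```cpp
// Hypothetical helper, not in the PR: true when the caller expressed no
// precision preference at all, so lowering is free to pick any strategy.
bool is_fully_auto(const ApproximationPrecision &prec) {
    return prec.optimized_for == ApproximationPrecision::AUTO &&
           prec.constraint_max_absolute_error == 0 &&
           prec.constraint_max_ulp_error == 0 &&
           prec.force_halide_polynomial == 0;
}
```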