Skip to content

Commit 23f6ff7

Browse files
committed
Implement expm1. Fix accuracy of tanh. Fix lowering of tanh on CUDA. Selectively disable some tests that require strict_float on GPU backends.
1 parent 4c08aa3 commit 23f6ff7

11 files changed

+381
-79
lines changed

src/ApproximationTables.cpp

+106
Original file line numberDiff line numberDiff line change
@@ -500,6 +500,108 @@ const std::vector<Approximation> table_tan = {
500500
},
501501
};
502502

503+
const std::vector<Approximation> table_expm1 = {
504+
/* MULPE optimized */
505+
{ /* Polynomial degree 2: 1*x + 0.5006693548784*x^2 */
506+
/* f16 */ {6.973743e-06, nan, 0},
507+
/* f32 */ {6.969223e-06, 0x1.ebb68p-8, 251914},
508+
/* f64 */ {6.969224e-06, nan, 0},
509+
/* p */ {0, 1, 0x1.0057bbd29fd1ep-1},
510+
},
511+
{ /* Polynomial degree 3: 1*x + 0.5034739414620*x^2 + 0.1676710752100*x^3 */
512+
/* f16 */ {0.000000e+00, nan, 0},
513+
/* f32 */ {3.367883e-09, 0x1.86dp-13, 6263},
514+
/* f64 */ {3.367884e-09, nan, 0},
515+
/* p */ {0, 1, 0x1.01c75621ef769p-1, 0x1.5763eec418d18p-3},
516+
},
517+
{ /* Polynomial degree 4: 1*x + 0.4999934522294*x^2 + 0.1674641440143*x^3 + 0.0418883769826*x^4 */
518+
/* f16 */ {0.000000e+00, nan, 0},
519+
/* f32 */ {7.937537e-12, 0x1.22p-17, 290},
520+
/* f64 */ {7.937461e-12, nan, 0},
521+
/* p */ {0, 1, 0x1.fffe4896282b8p-2, 0x1.56f770ee59ccdp-3, 0x1.57264b2721b28p-5},
522+
},
523+
{ /* Polynomial degree 5: 1*x + 0.4999948095067*x^2 + 0.1666705913520*x^3 + 0.0418641947519*x^4 + 0.0083245399856*x^5 */
524+
/* f16 */ {0.000000e+00, nan, 0},
525+
/* f32 */ {5.121846e-15, 0x1p-22, 9},
526+
/* f64 */ {5.032477e-15, nan, 0},
527+
/* p */ {0, 1, 0x1.fffea3ac00fecp-2, 0x1.555764187ec0cp-3, 0x1.56f3946aa5fddp-5, 0x1.10c74d7f0b9e3p-7},
528+
},
529+
{ /* Polynomial degree 6: 1*x + 0.4999999783332*x^2 + 0.1666655167631*x^3 + 0.0416674530503*x^4 + 0.0083656894489*x^5 + 0.0013868266193*x^6 */
530+
/* f16 */ {0.000000e+00, nan, 0},
531+
/* f32 */ {9.151552e-17, 0x1p-24, 3},
532+
/* f64 */ {3.980170e-18, nan, 0},
533+
/* p */ {0, 1, 0x1.fffffe8bc45fdp-2, 0x1.5554bafef2a4cp-3, 0x1.5556fb851488cp-5, 0x1.12207d4bbd602p-7, 0x1.6b8c5be658778p-10},
534+
},
535+
{ /* Polynomial degree 7: 1*x + 0.5000000039620*x^2 + 0.1666666668832*x^3 + 0.0416663782542*x^4 + 0.0083333114192*x^5 + 0.0013939439655*x^6 + 0.0001989114932*x^7 */
536+
/* f16 */ {0.000000e+00, nan, 0},
537+
/* f32 */ {8.791334e-17, 0x1p-24, 3},
538+
/* f64 */ {1.261949e-21, nan, 0},
539+
/* p */ {0, 1, 0x1.00000022086cdp-1, 0x1.5555555cc5f6bp-3, 0x1.5554ba7e3b3ap-5, 0x1.1110e201a0746p-7, 0x1.6d69fefa37758p-10, 0x1.a125cb74c2fdcp-13},
540+
},
541+
{ /* Polynomial degree 8: 1*x + 0.5000000000002*x^2 + 0.1666666674457*x^3 + 0.0416666667550*x^4 + 0.0083332919144*x^5 + 0.0013888838822*x^6 + 0.0001990314010*x^7 + 0.0000248701821*x^8 */
542+
/* f16 */ {0.000000e+00, nan, 0},
543+
/* f32 */ {8.794097e-17, 0x1p-24, 3},
544+
/* f64 */ {6.327484e-25, nan, 0},
545+
/* p */ {0, 1, 0x1.0000000000618p-1, 0x1.5555557019e1dp-3, 0x1.5555556177a9cp-5, 0x1.1110b81eca4bdp-7, 0x1.6c166b6843098p-10, 0x1.a1662b74ce94ap-13, 0x1.a1409e6521e4p-16},
546+
},
547+
{ /* Polynomial degree 9: 1*x + 0.4999999999985*x^2 + 0.1666666666682*x^3 + 0.0416666668663*x^4 + 0.0083333332671*x^5 + 0.0013888825262*x^6 + 0.0001984132091*x^7 + 0.0000248745945*x^8 + 0.0000027582234*x^9 */
548+
/* f16 */ {0.000000e+00, nan, 0},
549+
/* f32 */ {8.793395e-17, 0x1p-24, 3},
550+
/* f64 */ {1.531604e-28, nan, 0},
551+
/* p */ {0, 1, 0x1.fffffffff940fp-2, 0x1.555555556268ap-3, 0x1.55555570c649p-5, 0x1.111110ecaa65p-7, 0x1.6c16541ce2eep-10, 0x1.a01a47d13935p-13, 0x1.a15391e6e2bcp-16, 0x1.7233d57b06acp-19},
552+
},
553+
554+
/* MAE optimized */
555+
{ /* Polynomial degree 2: 1*x + 0.5050242124682*x^2 */
556+
/* f16 */ {6.973743e-06, nan, 0},
557+
/* f32 */ {6.950645e-06, 0x1.c96fp-8, 276101},
558+
/* f64 */ {6.950646e-06, nan, 0},
559+
/* p */ {0, 1, 0x1.029288987a54cp-1},
560+
},
561+
{ /* Polynomial degree 3: 1*x + 0.5041221231243*x^2 + 0.1676698092003*x^3 */
562+
/* f16 */ {0.000000e+00, nan, 0},
563+
/* f32 */ {4.160910e-09, 0x1.c7p-14, 7815},
564+
/* f64 */ {4.160914e-09, nan, 0},
565+
/* p */ {0, 1, 0x1.021c4b8004a3ap-1, 0x1.576344d85599fp-3},
566+
},
567+
{ /* Polynomial degree 4: 1*x + 0.4999895150973*x^2 + 0.1675387336054*x^3 + 0.0419211379777*x^4 */
568+
/* f16 */ {0.000000e+00, nan, 0},
569+
/* f32 */ {9.945929e-12, 0x1.72p-18, 370},
570+
/* f64 */ {9.945737e-12, nan, 0},
571+
/* p */ {0, 1, 0x1.fffd405ebe74bp-2, 0x1.571e8c2d2f987p-3, 0x1.576aff9401dcp-5},
572+
},
573+
{ /* Polynomial degree 5: 1*x + 0.4999914702852*x^2 + 0.1666645763191*x^3 + 0.0418982706165*x^4 + 0.0083746050916*x^5 */
574+
/* f16 */ {0.000000e+00, nan, 0},
575+
/* f32 */ {3.805249e-15, 0x1.4p-23, 14},
576+
/* f64 */ {3.714810e-15, nan, 0},
577+
/* p */ {0, 1, 0x1.fffdc3949dcaep-2, 0x1.55543cc5899b8p-3, 0x1.573b0ac1d1b71p-5, 0x1.126b477e23ba6p-7},
578+
},
579+
{ /* Polynomial degree 6: 1*x + 0.5000000095104*x^2 + 0.1666651891580*x^3 + 0.0416662060631*x^4 + 0.0083688803426*x^5 + 0.0013950473985*x^6 */
580+
/* f16 */ {0.000000e+00, nan, 0},
581+
/* f32 */ {9.192510e-17, 0x1p-24, 3},
582+
/* f64 */ {3.769683e-18, nan, 0},
583+
/* p */ {0, 1, 0x1.00000051b18efp-1, 0x1.55548f06853e7p-3, 0x1.55545e0c74cfcp-5, 0x1.123b41b01319dp-7, 0x1.6db40bcfe61dp-10},
584+
},
585+
{ /* Polynomial degree 7: 1*x + 0.5000000077859*x^2 + 0.1666666686005*x^3 + 0.0416662701044*x^4 + 0.0083332644982*x^5 + 0.0013946061254*x^6 + 0.0001991830927*x^7 */
586+
/* f16 */ {0.000000e+00, nan, 0},
587+
/* f32 */ {8.790274e-17, 0x1p-24, 3},
588+
/* f64 */ {1.003267e-21, nan, 0},
589+
/* p */ {0, 1, 0x1.00000042e152ap-1, 0x1.55555597c7c4ap-3, 0x1.5554806e3a70cp-5, 0x1.11107d3e893fp-7, 0x1.6d966ecc0e888p-10, 0x1.a1b79bcd9bc7p-13},
590+
},
591+
{ /* Polynomial degree 8: 1*x + 0.4999999999952*x^2 + 0.1666666678656*x^3 + 0.0416666670540*x^4 + 0.0083332812914*x^5 + 0.0013888796454*x^6 + 0.0001990923050*x^7 + 0.0000248875972*x^8 */
592+
/* f16 */ {0.000000e+00, nan, 0},
593+
/* f32 */ {8.794057e-17, 0x1p-24, 3},
594+
/* f64 */ {5.533894e-25, nan, 0},
595+
/* p */ {0, 1, 0x1.ffffffffeae2bp-2, 0x1.5555557e86fd4p-3, 0x1.5555558a91454p-5, 0x1.1110a14eb4df8p-7, 0x1.6c16229ee20dp-10, 0x1.a186de09bce3fp-13, 0x1.a18b6a8cc4fp-16},
596+
},
597+
{ /* Polynomial degree 9: 1*x + 0.4999999999960*x^2 + 0.1666666666657*x^3 + 0.0416666669889*x^4 + 0.0083333333889*x^5 + 0.0013888807600*x^6 + 0.0001984116265*x^7 + 0.0000248822674*x^8 + 0.0000027643875*x^9 */
598+
/* f16 */ {0.000000e+00, nan, 0},
599+
/* f32 */ {8.793395e-17, 0x1p-24, 3},
600+
/* f64 */ {1.074717e-28, nan, 0},
601+
/* p */ {0, 1, 0x1.ffffffffee98ep-2, 0x1.555555554c93dp-3, 0x1.555555819f9cp-5, 0x1.1111112fa1c6p-7, 0x1.6c1635c4da36p-10, 0x1.a0196e4f3bb98p-13, 0x1.a1748651dec8p-16, 0x1.7307a199bd04p-19},
602+
},
603+
};
604+
503605
const std::vector<Approximation> table_exp = {
504606
/* MULPE optimized (with fixed x⁰ and x¹ coefficients 1 and 1). */
505607
{ /* Polynomial degree 1: 1 + 1*x */
@@ -905,6 +1007,10 @@ const Approximation *best_tan_approximation(Halide::ApproximationPrecision preci
9051007
return find_best_approximation("tan", table_tan, precision, type);
9061008
}
9071009

1010+
// Select the polynomial from table_expm1 that best satisfies the requested
// approximation precision for the given type.
const Approximation *best_expm1_approximation(Halide::ApproximationPrecision precision, Type type) {
    return find_best_approximation("expm1", table_expm1, precision, type);
}
1013+
9081014
const Approximation *best_exp_approximation(Halide::ApproximationPrecision precision, Type type) {
9091015
return find_best_approximation("exp", table_exp, precision, type);
9101016
}

src/ApproximationTables.h

+2
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ extern const std::vector<Approximation> table_atan;
3636
extern const std::vector<Approximation> table_sin;
3737
extern const std::vector<Approximation> table_cos;
3838
extern const std::vector<Approximation> table_tan;
39+
extern const std::vector<Approximation> table_expm1;
3940
extern const std::vector<Approximation> table_exp;
4041
extern const std::vector<Approximation> table_log;
4142

@@ -45,6 +46,7 @@ const Approximation *best_cos_approximation(Halide::ApproximationPrecision preci
4546
const Approximation *best_tan_approximation(Halide::ApproximationPrecision precision, Type type);
4647
const Approximation *best_log_approximation(Halide::ApproximationPrecision precision, Type type);
4748
const Approximation *best_exp_approximation(Halide::ApproximationPrecision precision, Type type);
49+
const Approximation *best_expm1_approximation(Halide::ApproximationPrecision precision, Type type);
4850
} // namespace ApproximationTables
4951

5052
} // namespace Internal

src/Derivative.cpp

+3
Original file line numberDiff line numberDiff line change
@@ -1070,6 +1070,9 @@ void ReverseAccumulationVisitor::visit(const Call *op) {
10701070
if (is_math_func(op, "exp", Call::fast_exp)) {
10711071
// d/dx exp(x) = exp(x)
10721072
accumulate(op->args[0], adjoint * exp(op->args[0]));
1073+
} else if (is_math_func(op, "expm1", Call::fast_expm1)) {
1074+
// d/dx (exp(x) - 1) = exp(x)
1075+
accumulate(op->args[0], adjoint * exp(op->args[0]));
10731076
} else if (is_math_func(op, "log", Call::fast_log)) {
10741077
// d/dx log(x) = 1 / x
10751078
accumulate(op->args[0], adjoint / op->args[0]);

src/FastMathFunctions.cpp

+80-32
Original file line numberDiff line numberDiff line change
@@ -343,8 +343,35 @@ Expr fast_exp(const Expr &x_full, ApproximationPrecision prec) {
343343

344344
// Shift the bits up into the exponent field and reinterpret this
345345
// thing as float.
346-
Expr two_to_the_n = reinterpret<float>(biased << 23);
347-
result *= two_to_the_n;
346+
Expr two_to_the_k = reinterpret<float>(biased << 23);
347+
result *= two_to_the_k;
348+
result = common_subexpression_elimination(result, true);
349+
return result;
350+
}
351+
352+
// Polynomial approximation of expm1(x) = exp(x) - 1, accurate near zero
// where exp(x) - 1 would suffer catastrophic cancellation.
Expr fast_expm1(const Expr &x_full, ApproximationPrecision prec) {
    Type type = x_full.type();
    // BUGFIX: the assertion message previously said "fast_exp" (copy-paste
    // from the function above); report the correct function name.
    user_assert(x_full.type() == Float(32)) << "fast_expm1 only works for Float(32)";

    Expr log2 = make_const(type, std::log(2.0));

    // Range-reduce: x_full = k * ln(2) + x. Here we round instead of floor,
    // to reduce the argument to [-log(2)/2, log(2)/2], centered on zero where
    // the expm1 polynomial is most accurate.
    Expr scaled = x_full / log2;
    Expr k_real = round(scaled);
    Expr k = cast<int>(k_real);
    Expr x = x_full - k_real * log2;

    const Internal::Approximation *approx = Internal::ApproximationTables::best_expm1_approximation(prec, type);
    Expr result = eval_approx(approx, x);

    // Compute 2^k by constructing the IEEE-754 bit pattern directly.
    int fpbias = 127;
    Expr biased = clamp(k + fpbias, 0, 255);

    // Shift the bits up into the exponent field and reinterpret this
    // thing as float.
    Expr two_to_the_k = reinterpret<float>(biased << 23);

    // Recombine: expm1(x_full) = (expm1(x) + 1) * 2^k - 1. When k == 0, use
    // the polynomial result directly to preserve full precision near zero.
    result = select(k == 0, result, (result + 1) * two_to_the_k - 1);
    result = common_subexpression_elimination(result, true);
    return result;
}
@@ -370,26 +397,37 @@ Expr fast_tanh(const Expr &x, ApproximationPrecision prec) {
370397
// Rewrite with definition:
371398
// tanh(x) = (exp(2x) - 1) / (exp(2x) + 1)
372399
// = (1 - exp(-2x)) / (1 + exp(-2x))
400+
// = (expm1(2x)) / (expm1(2x) + 2)
373401
// We take abs() of the argument, and flip the sign of the result when negative.
374402
Type type = x.type();
375403
Expr abs_x = abs(x);
376404
Expr flip_sign = x < 0;
377405
if (prec.optimized_for == ApproximationPrecision::MULPE) {
406+
#if 0
378407
// Positive arguments to exp() have preciser ULP.
379408
// So, we will rewrite the expression to always use exp(2*x)
380409
// instead of exp(-2*x) when we are close to zero.
381410
// Rewriting it like this is slightly more expensive, hence the branch
382411
// to only pay this extra cost in case we need MULPE-optimized approximations.
383412
Expr flip_exp = abs_x > make_const(type, 4);
384413
Expr arg_exp = select(flip_exp, -abs_x, abs_x);
385-
Expr exp2x = Halide::fast_exp(2 * arg_exp, prec);
386-
Expr tanh = (exp2x - make_const(type, 1.0)) / (exp2x + make_const(type, 1));
414+
Expr exp2xm1 = Halide::fast_expm1(2 * arg_exp, prec);
415+
Expr tanh = (exp2xm1) / (exp2xm1 + make_const(type, 2));
387416
tanh = select(flip_exp ^ flip_sign, -tanh, tanh);
388417
return common_subexpression_elimination(tanh, true);
418+
#else
419+
// expm1 is developed around 0 and is ULP accurate in [-ln(2)/2, ln(2)/2].
420+
Expr exp2xm1 = Halide::fast_expm1(-2 * abs_x, prec);
421+
Expr tanh = (exp2xm1) / (exp2xm1 + make_const(type, 2));
422+
tanh = select(flip_sign, tanh, -tanh);
423+
return common_subexpression_elimination(tanh, true);
424+
#endif
389425
} else {
390426
// Even if we are optimizing for MAE, the nested call to exp()
391427
// should be MULPE optimized for accuracy, as we are taking ratios.
392-
prec.optimized_for = ApproximationPrecision::MULPE;
428+
if (prec.optimized_for == ApproximationPrecision::MAE) {
429+
prec.optimized_for = ApproximationPrecision::MULPE;
430+
} // else it's on AUTO, and we want to keep that (AUTO tanh uses AUTO exp).
393431
Expr exp2x = Halide::fast_exp(-2 * abs_x, prec);
394432
Expr tanh = (make_const(type, 1) - exp2x) / (make_const(type, 1) + exp2x);
395433
tanh = select(flip_sign, -tanh, tanh);
@@ -466,6 +504,10 @@ IntrinsicsInfoPerDeviceAPI ii_tan{
466504
{DeviceAPI::OpenCL, {false}, {OO::MAE, 2e-6f, 1'000'000}},
467505
}};
468506

507+
IntrinsicsInfoPerDeviceAPI ii_expm1{
508+
OO::MULPE, 0.0f, 50, { /* No intrinsics on any backend. */
509+
}};
510+
469511
IntrinsicsInfoPerDeviceAPI ii_exp{
470512
OO::MULPE, 0.0f, 50, {
471513
{DeviceAPI::Vulkan, {true}, {}},
@@ -478,10 +520,10 @@ IntrinsicsInfoPerDeviceAPI ii_exp{
478520
IntrinsicsInfoPerDeviceAPI ii_log{
479521
OO::MAE, 1e-5f, 1000, {
480522
{DeviceAPI::Vulkan, {true}, {}},
481-
{DeviceAPI::CUDA, {false}, {OO::MULPE, 0.0f, 3'800'000}},
523+
{DeviceAPI::CUDA, {false}, {OO::MAE, 0.0f, 3'800'000}},
482524
{DeviceAPI::Metal, {false}, {OO::MAE, 0.0f, 3'800'000}}, // slow log() on metal
483525
{DeviceAPI::WebGPU, {true}, {}},
484-
{DeviceAPI::OpenCL, {true}, {OO::MULPE, 0.0f, 3'800'000}},
526+
{DeviceAPI::OpenCL, {true}, {OO::MAE, 0.0f, 3'800'000}},
485527
}};
486528

487529
IntrinsicsInfoPerDeviceAPI ii_pow{
@@ -519,6 +561,9 @@ bool fast_math_func_has_intrinsic_based_implementation(Call::IntrinsicOp op, Dev
519561
case Call::fast_cos:
520562
iipda = &ii_cos;
521563
break;
564+
case Call::fast_expm1:
565+
iipda = &ii_expm1;
566+
break;
522567
case Call::fast_exp:
523568
iipda = &ii_exp;
524569
break;
@@ -563,14 +608,17 @@ bool fast_math_func_has_intrinsic_based_implementation(Call::IntrinsicOp op, Dev
563608
return false;
564609
}
565610

566-
IntrinsicsInfo resolve_precision(ApproximationPrecision &prec, const IntrinsicsInfoPerDeviceAPI &iida, DeviceAPI api) {
567-
IntrinsicsInfo ii{};
611+
// Look up the backend-specific IntrinsicsInfo entry for the given device
// API. Returns a default-constructed IntrinsicsInfo when this device API
// has no dedicated entry in the table.
IntrinsicsInfo find_intrinsics_info_for_device_api(const IntrinsicsInfoPerDeviceAPI &iida, DeviceAPI api) {
    for (size_t idx = 0; idx < iida.device_apis.size(); ++idx) {
        const auto &candidate = iida.device_apis[idx];
        if (candidate.device_api == api) {
            return candidate;
        }
    }
    return {};
}
619+
620+
IntrinsicsInfo resolve_precision(ApproximationPrecision &prec, const IntrinsicsInfoPerDeviceAPI &iida, DeviceAPI api) {
621+
IntrinsicsInfo ii = find_intrinsics_info_for_device_api(iida, api);
574622

575623
if (prec.optimized_for == ApproximationPrecision::AUTO) {
576624
if (!ii.intrinsic.defined()) {
@@ -690,18 +738,6 @@ class LowerFastMathFunctions : public IRMutator {
690738
return for_device_api == DeviceAPI::CUDA && target.get_cuda_capability_lower_bound() >= 75;
691739
}
692740

693-
void adjust_precision_for_target(ApproximationPrecision &prec) {
694-
if (for_device_api == DeviceAPI::None) {
695-
if (target.arch == Target::Arch::X86) {
696-
// If we do not have fused-multiply-add, we lose some precision.
697-
if (target.bits == 32 || !target.has_feature(Target::Feature::FMA)) {
698-
prec.constraint_max_absolute_error *= 0.5f;
699-
prec.constraint_max_ulp_error /= 2;
700-
}
701-
}
702-
}
703-
}
704-
705741
/** Strips the fast_ prefix, appends the type suffix, and
706742
* drops the precision argument from the end. */
707743
Expr to_native_func(const Call *op) {
@@ -720,7 +756,7 @@ class LowerFastMathFunctions : public IRMutator {
720756
std::vector<Expr> args;
721757
for (size_t i = 0; i < op->args.size() - 1; ++i) {
722758
const Expr &arg = op->args[i];
723-
args.push_back(IRMutator::mutate(arg));
759+
args.push_back(mutate(arg));
724760
}
725761
return Call::make(op->type, new_name, args, Call::PureExtern);
726762
}
@@ -738,7 +774,7 @@ class LowerFastMathFunctions : public IRMutator {
738774
std::vector<Expr> args;
739775
for (size_t i = 0; i < op->args.size() - 1; ++i) {
740776
const Expr &arg = op->args[i];
741-
args.push_back(IRMutator::mutate(arg));
777+
args.push_back(mutate(arg));
742778
}
743779
return Call::make(op->type, new_name, args, Call::PureExtern);
744780
}
@@ -792,7 +828,6 @@ class LowerFastMathFunctions : public IRMutator {
792828
}
793829

794830
// No known fast version available, we will expand our own approximation.
795-
adjust_precision_for_target(prec);
796831
return ApproxImpl::fast_sin(mutate(op->args[0]), prec);
797832
} else if (op->is_intrinsic(Call::fast_cos)) {
798833
ApproximationPrecision prec = extract_approximation_precision(op);
@@ -805,7 +840,6 @@ class LowerFastMathFunctions : public IRMutator {
805840
}
806841

807842
// No known fast version available, we will expand our own approximation.
808-
adjust_precision_for_target(prec);
809843
return ApproxImpl::fast_cos(mutate(op->args[0]), prec);
810844
} else if (op->is_intrinsic(Call::fast_atan) || op->is_intrinsic(Call::fast_atan2)) {
811845
// Handle fast_atan and fast_atan2 together!
@@ -816,7 +850,6 @@ class LowerFastMathFunctions : public IRMutator {
816850
return to_native_func(op);
817851
}
818852

819-
adjust_precision_for_target(prec);
820853
if (op->is_intrinsic(Call::fast_atan)) {
821854
return ApproxImpl::fast_atan(mutate(op->args[0]), prec);
822855
} else {
@@ -841,10 +874,12 @@ class LowerFastMathFunctions : public IRMutator {
841874
return to_native_func(op);
842875
}
843876

844-
adjust_precision_for_target(prec);
845877
return ApproxImpl::fast_tan(mutate(op->args[0]), prec);
878+
} else if (op->is_intrinsic(Call::fast_expm1)) {
879+
ApproximationPrecision prec = extract_approximation_precision(op);
880+
resolve_precision(prec, ii_expm1, for_device_api);
881+
return ApproxImpl::fast_expm1(mutate(op->args[0]), prec);
846882
} else if (op->is_intrinsic(Call::fast_exp)) {
847-
// Handle fast_exp and fast_log together!
848883
ApproximationPrecision prec = extract_approximation_precision(op);
849884
IntrinsicsInfo ii = resolve_precision(prec, ii_exp, for_device_api);
850885
if (op->type == Float(32) && is_cuda_cc20() && intrinsic_satisfies_precision(ii, prec)) {
@@ -865,7 +900,6 @@ class LowerFastMathFunctions : public IRMutator {
865900
return to_native_func(op);
866901
}
867902

868-
adjust_precision_for_target(prec);
869903
return ApproxImpl::fast_exp(mutate(op->args[0]), prec);
870904
} else if (op->is_intrinsic(Call::fast_log)) {
871905
// Handle fast_exp and fast_log together!
@@ -887,10 +921,24 @@ class LowerFastMathFunctions : public IRMutator {
887921
return to_native_func(op);
888922
}
889923

890-
adjust_precision_for_target(prec);
891924
return ApproxImpl::fast_log(mutate(op->args[0]), prec);
892925
} else if (op->is_intrinsic(Call::fast_tanh)) {
893926
ApproximationPrecision prec = extract_approximation_precision(op);
927+
// Here is a little special treatment. tanh() on cuda can be rewritten to exp(), but
928+
// that would behave MAE, instead of MULPE. MULPE is the default behavior for the
929+
// tanh.approx.f32 intrinsic. So resolve_precision() would set it to MULPE to be able
930+
// to use that intrinsic, but that is dependent on CC7.5. So we will instead first
931+
// check if we are on CC <7.5 and are on AUTO, no precision requirements.
932+
// If that's the case, we leave the objective on AUTO, and immediately rewrite.
933+
if (op->type == Float(32) && is_cuda_cc20() && !is_cuda_cc75()) {
934+
if (prec.optimized_for == ApproximationPrecision::AUTO &&
935+
prec.constraint_max_absolute_error == 0 &&
936+
prec.constraint_max_ulp_error == 0 &&
937+
prec.force_halide_polynomial == 0) {
938+
return mutate(ApproxImpl::fast_tanh(op->args[0], prec));
939+
}
940+
}
941+
// Now we know we're not in that case, proceed like usually.
894942
IntrinsicsInfo ii = resolve_precision(prec, ii_tanh, for_device_api);
895943
// We have a fast version on PTX with CC7.5
896944
if (op->type == Float(32) && is_cuda_cc75() && intrinsic_satisfies_precision(ii, prec)) {

src/IR.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -635,6 +635,7 @@ const char *const intrinsic_op_names[] = {
635635
"fast_atan2",
636636
"fast_cos",
637637
"fast_exp",
638+
"fast_expm1",
638639
"fast_log",
639640
"fast_pow",
640641
"fast_sin",

0 commit comments

Comments
 (0)