@@ -307,6 +307,32 @@ Expr fast_log(const Expr &x, ApproximationPrecision prec) {
307
307
return result;
308
308
}
309
309
310
+ Expr fast_tanh (const Expr &x, ApproximationPrecision prec) {
311
+ // Rewrite with definition:
312
+ // tanh(x) = (exp(2x) - 1) / (exp(2x) + 1)
313
+ // = (1 - exp(-2x)) / (1 + exp(-2x))
314
+ // But abs(x) the argument, and flip when negative.
315
+ Type type = x.type ();
316
+ Expr abs_x = abs (x);
317
+ Expr flip_sign = x < 0 ;
318
+ if (prec.optimized_for == ApproximationPrecision::MULPE) {
319
+ // Positive arguments to exp() have preciser ULP.
320
+ // So, we will rewrite the expression to always use exp(2*x)
321
+ // instead of exp(-2*x) when we are close to zero.
322
+ Expr flip_exp = abs_x > constant (type, 4 );
323
+ Expr arg_exp = select (flip_exp, -abs_x, abs_x);
324
+ Expr exp2x = Halide::fast_exp (2 * arg_exp, prec);
325
+ Expr tanh = (exp2x - constant (type, 1.0 )) / (exp2x + constant (type, 1 ));
326
+ tanh = select (flip_exp ^ flip_sign, -tanh , tanh );
327
+ return common_subexpression_elimination (tanh , true );
328
+ } else {
329
+ Expr exp2x = Halide::fast_exp (-2 * abs_x, prec);
330
+ Expr tanh = (constant (type, 1 ) - exp2x) / (constant (type, 1 ) + exp2x);
331
+ tanh = select (flip_sign, -tanh , tanh );
332
+ return common_subexpression_elimination (tanh , true );
333
+ }
334
+ }
335
+
310
336
} // namespace ApproxImpl
311
337
312
338
using OO = ApproximationPrecision::OptimizationObjective;
@@ -341,11 +367,20 @@ struct IntrinsicsInfoPerDeviceAPI {
341
367
};
342
368
343
369
// clang-format off
344
- IntrinsicsInfoPerDeviceAPI ii_sin_cos{
370
+ IntrinsicsInfoPerDeviceAPI ii_sin{
371
+ OO::MAE, 1e-5f , 0 , {
372
+ {DeviceAPI::Vulkan, {true }, {}},
373
+ {DeviceAPI::CUDA, {false }, {OO::MAE, 5e-7f , 1'000'000 }},
374
+ {DeviceAPI::Metal, {true }, {OO::MAE, 6e-5f , 400'000 }},
375
+ {DeviceAPI::WebGPU, {true }, {}},
376
+ {DeviceAPI::OpenCL, {false }, {OO::MAE, 5e-7f , 1'000'000 }},
377
+ }};
378
+
379
+ IntrinsicsInfoPerDeviceAPI ii_cos{
345
380
OO::MAE, 1e-5f , 0 , {
346
381
{DeviceAPI::Vulkan, {true }, {}},
347
382
{DeviceAPI::CUDA, {false }, {OO::MAE, 5e-7f , 1'000'000 }},
348
- {DeviceAPI::Metal, {true }, {OO::MAE, 5e -7f , 1'000 '000 }},
383
+ {DeviceAPI::Metal, {true }, {OO::MAE, 7e -7f , 5 '000 }},
349
384
{DeviceAPI::WebGPU, {true }, {}},
350
385
{DeviceAPI::OpenCL, {false }, {OO::MAE, 5e-7f , 1'000'000 }},
351
386
}};
@@ -622,24 +657,30 @@ class LowerFastMathFunctions : public IRMutator {
622
657
}
623
658
624
659
Expr visit (const Call *op) override {
625
- if (op->is_intrinsic (Call::fast_sin) || op->is_intrinsic (Call::fast_cos)) {
626
- // Handle fast_sin and fast_cos together!
660
+ if (op->is_intrinsic (Call::fast_sin)) {
627
661
ApproximationPrecision prec = extract_approximation_precision (op);
628
- IntrinsicsInfo ii = resolve_precision (prec, ii_sin_cos , for_device_api);
662
+ IntrinsicsInfo ii = resolve_precision (prec, ii_sin , for_device_api);
629
663
if (op->type == Float (32 ) && intrinsic_satisfies_precision (ii, prec)) {
630
664
return append_type_suffix (op);
631
665
}
632
666
if (ii.native_func .is_fast && native_func_satisfies_precision (ii, prec)) {
633
- // The native sine and cosine are fast: fall back to native and continue lowering.
634
667
return to_native_func (op);
635
668
}
636
669
637
670
// No known fast version available, we will expand our own approximation.
638
- if (op->is_intrinsic (Call::fast_sin)) {
639
- return ApproxImpl::fast_sin (mutate (op->args [0 ]), prec);
640
- } else {
641
- return ApproxImpl::fast_cos (mutate (op->args [0 ]), prec);
671
+ return ApproxImpl::fast_sin (mutate (op->args [0 ]), prec);
672
+ } else if (op->is_intrinsic (Call::fast_cos)) {
673
+ ApproximationPrecision prec = extract_approximation_precision (op);
674
+ IntrinsicsInfo ii = resolve_precision (prec, ii_cos, for_device_api);
675
+ if (op->type == Float (32 ) && intrinsic_satisfies_precision (ii, prec)) {
676
+ return append_type_suffix (op);
642
677
}
678
+ if (ii.native_func .is_fast && native_func_satisfies_precision (ii, prec)) {
679
+ return to_native_func (op);
680
+ }
681
+
682
+ // No known fast version available, we will expand our own approximation.
683
+ return ApproxImpl::fast_cos (mutate (op->args [0 ]), prec);
643
684
} else if (op->is_intrinsic (Call::fast_atan) || op->is_intrinsic (Call::fast_atan2)) {
644
685
// Handle fast_atan and fast_atan2 together!
645
686
ApproximationPrecision prec = extract_approximation_precision (op);
@@ -722,8 +763,8 @@ class LowerFastMathFunctions : public IRMutator {
722
763
return append_type_suffix (op);
723
764
}
724
765
725
- // Unfortunately, no fast_tanh approximation implemented yet!
726
- return to_native_func (op );
766
+ // Expand using defintion in terms of exp(2x), and recurse.
767
+ return mutate ( ApproxImpl::fast_tanh (op-> args [ 0 ], prec) );
727
768
} else if (op->is_intrinsic (Call::fast_pow)) {
728
769
ApproximationPrecision prec = extract_approximation_precision (op);
729
770
IntrinsicsInfo ii = resolve_precision (prec, ii_pow, for_device_api);
0 commit comments