Skip to content

Commit f4ebe09

Browse files
committed
Split tables for sin and cos, as metal has odd precision for sin. Add support for fast_tanh on all backends.
1 parent 5256b79 commit f4ebe09

File tree

3 files changed

+63
-16
lines changed

3 files changed

+63
-16
lines changed

src/FastMathFunctions.cpp

+53-12
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,32 @@ Expr fast_log(const Expr &x, ApproximationPrecision prec) {
307307
return result;
308308
}
309309

310+
Expr fast_tanh(const Expr &x, ApproximationPrecision prec) {
311+
// Rewrite with definition:
312+
// tanh(x) = (exp(2x) - 1) / (exp(2x) + 1)
313+
// = (1 - exp(-2x)) / (1 + exp(-2x))
314+
// But abs(x) the argument, and flip when negative.
315+
Type type = x.type();
316+
Expr abs_x = abs(x);
317+
Expr flip_sign = x < 0;
318+
if (prec.optimized_for == ApproximationPrecision::MULPE) {
319+
// Positive arguments to exp() have preciser ULP.
320+
// So, we will rewrite the expression to always use exp(2*x)
321+
// instead of exp(-2*x) when we are close to zero.
322+
Expr flip_exp = abs_x > constant(type, 4);
323+
Expr arg_exp = select(flip_exp, -abs_x, abs_x);
324+
Expr exp2x = Halide::fast_exp(2 * arg_exp, prec);
325+
Expr tanh = (exp2x - constant(type, 1.0)) / (exp2x + constant(type, 1));
326+
tanh = select(flip_exp ^ flip_sign, -tanh, tanh);
327+
return common_subexpression_elimination(tanh, true);
328+
} else {
329+
Expr exp2x = Halide::fast_exp(-2 * abs_x, prec);
330+
Expr tanh = (constant(type, 1) - exp2x) / (constant(type, 1) + exp2x);
331+
tanh = select(flip_sign, -tanh, tanh);
332+
return common_subexpression_elimination(tanh, true);
333+
}
334+
}
335+
310336
} // namespace ApproxImpl
311337

312338
using OO = ApproximationPrecision::OptimizationObjective;
@@ -341,11 +367,20 @@ struct IntrinsicsInfoPerDeviceAPI {
341367
};
342368

343369
// clang-format off
344-
IntrinsicsInfoPerDeviceAPI ii_sin_cos{
370+
IntrinsicsInfoPerDeviceAPI ii_sin{
371+
OO::MAE, 1e-5f, 0, {
372+
{DeviceAPI::Vulkan, {true}, {}},
373+
{DeviceAPI::CUDA, {false}, {OO::MAE, 5e-7f, 1'000'000}},
374+
{DeviceAPI::Metal, {true}, {OO::MAE, 6e-5f, 400'000}},
375+
{DeviceAPI::WebGPU, {true}, {}},
376+
{DeviceAPI::OpenCL, {false}, {OO::MAE, 5e-7f, 1'000'000}},
377+
}};
378+
379+
IntrinsicsInfoPerDeviceAPI ii_cos{
345380
OO::MAE, 1e-5f, 0, {
346381
{DeviceAPI::Vulkan, {true}, {}},
347382
{DeviceAPI::CUDA, {false}, {OO::MAE, 5e-7f, 1'000'000}},
348-
{DeviceAPI::Metal, {true}, {OO::MAE, 5e-7f, 1'000'000}},
383+
{DeviceAPI::Metal, {true}, {OO::MAE, 7e-7f, 5'000}},
349384
{DeviceAPI::WebGPU, {true}, {}},
350385
{DeviceAPI::OpenCL, {false}, {OO::MAE, 5e-7f, 1'000'000}},
351386
}};
@@ -622,24 +657,30 @@ class LowerFastMathFunctions : public IRMutator {
622657
}
623658

624659
Expr visit(const Call *op) override {
625-
if (op->is_intrinsic(Call::fast_sin) || op->is_intrinsic(Call::fast_cos)) {
626-
// Handle fast_sin and fast_cos together!
660+
if (op->is_intrinsic(Call::fast_sin)) {
627661
ApproximationPrecision prec = extract_approximation_precision(op);
628-
IntrinsicsInfo ii = resolve_precision(prec, ii_sin_cos, for_device_api);
662+
IntrinsicsInfo ii = resolve_precision(prec, ii_sin, for_device_api);
629663
if (op->type == Float(32) && intrinsic_satisfies_precision(ii, prec)) {
630664
return append_type_suffix(op);
631665
}
632666
if (ii.native_func.is_fast && native_func_satisfies_precision(ii, prec)) {
633-
// The native sine and cosine are fast: fall back to native and continue lowering.
634667
return to_native_func(op);
635668
}
636669

637670
// No known fast version available, we will expand our own approximation.
638-
if (op->is_intrinsic(Call::fast_sin)) {
639-
return ApproxImpl::fast_sin(mutate(op->args[0]), prec);
640-
} else {
641-
return ApproxImpl::fast_cos(mutate(op->args[0]), prec);
671+
return ApproxImpl::fast_sin(mutate(op->args[0]), prec);
672+
} else if (op->is_intrinsic(Call::fast_cos)) {
673+
ApproximationPrecision prec = extract_approximation_precision(op);
674+
IntrinsicsInfo ii = resolve_precision(prec, ii_cos, for_device_api);
675+
if (op->type == Float(32) && intrinsic_satisfies_precision(ii, prec)) {
676+
return append_type_suffix(op);
642677
}
678+
if (ii.native_func.is_fast && native_func_satisfies_precision(ii, prec)) {
679+
return to_native_func(op);
680+
}
681+
682+
// No known fast version available, we will expand our own approximation.
683+
return ApproxImpl::fast_cos(mutate(op->args[0]), prec);
643684
} else if (op->is_intrinsic(Call::fast_atan) || op->is_intrinsic(Call::fast_atan2)) {
644685
// Handle fast_atan and fast_atan2 together!
645686
ApproximationPrecision prec = extract_approximation_precision(op);
@@ -722,8 +763,8 @@ class LowerFastMathFunctions : public IRMutator {
722763
return append_type_suffix(op);
723764
}
724765

725-
// Unfortunately, no fast_tanh approximation implemented yet!
726-
return to_native_func(op);
766+
// Expand using defintion in terms of exp(2x), and recurse.
767+
return mutate(ApproxImpl::fast_tanh(op->args[0], prec));
727768
} else if (op->is_intrinsic(Call::fast_pow)) {
728769
ApproximationPrecision prec = extract_approximation_precision(op);
729770
IntrinsicsInfo ii = resolve_precision(prec, ii_pow, for_device_api);

test/correctness/fast_function_approximations.cpp

+5-4
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ struct FunctionToTest {
8787
{
8888
{ "-pi/3 to pi/3", {{-pi * 0.333f, pi * 0.333f}}, true, 40, 0 },
8989
{ "-pi/2 to pi/2", {{-pi * 0.5f, pi * 0.5f}}, true, 0, 0 },
90-
{ "-3pi to 3pi", {{-pi * 3.0f, pi * 3.0f}}, false, 0, 0 },
90+
{ "-3pi to 3pi", {{-pi * 3.0f, pi * 3.0f}}, true, 0, 0 },
9191
}
9292
},
9393
{
@@ -133,8 +133,8 @@ struct FunctionToTest {
133133
[](Expr x, Expr y) { return Halide::tanh(x); },
134134
[](Expr x, Expr y, Halide::ApproximationPrecision prec) { return Halide::fast_tanh(x, prec); },
135135
{
136-
{ "precise" , {{ -10.0f , 10.0f }}, true, 70, 20 },
137-
{ "extended" , {{ -100.0f, 100.0f}}, true, 70, 20 },
136+
{ "precise" , {{ -8.0f , 8.0f }}, true, 2500, 20 },
137+
{ "extended" , {{ -100.0f, 100.0f}}, true, 2500, 20 },
138138
}
139139
},
140140
// clang-format on
@@ -372,7 +372,8 @@ int main(int argc, char **argv) {
372372
if (&rat == &ftt.ranged_tests[0]) {
373373
// On the first (typically precise) range.
374374
num_tests++;
375-
if (em.max_abs_error < 1e-5 || em.max_ulp_error < 20'000 || em.max_rel_error < 1e-2) {
375+
if ((em.max_abs_error < 1e-5 || em.max_ulp_error < 20'000 || em.max_rel_error < 1e-2) ||
376+
(em.max_abs_error < 1e-4 && em.mean_abs_error < 1e-5 && em.mean_ulp_error < 400)) {
376377
num_tests_passed++;
377378
print_ok();
378379
} else {

tools/polynomial_optimizer.py

+5
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,11 @@ def optimize_approximation(loss, order):
106106
func = lambda x: np.log(x + 1.0)
107107
exponents = np.arange(1, order + 1)
108108
lower, upper = -0.25, 0.5
109+
elif args.func == "tanh":
110+
func_fixed_part = lambda x: x
111+
func = lambda x: np.tanh(x)
112+
exponents = np.arange(1, order + 1)
113+
lower, upper = 0.0, 4.0
109114
else:
110115
print("Unknown function:", args.func)
111116
exit(1)

0 commit comments

Comments
 (0)