Skip to content

Commit 735c431

Browse files
committed
Fix incorrect approximation selection when required precision is not available.
1 parent 58cb828 commit 735c431

File tree

4 files changed

+82
-48
lines changed

4 files changed

+82
-48
lines changed

src/ApproximationTables.cpp

+40-37
Original file line numberDiff line numberDiff line change
@@ -5,44 +5,46 @@ namespace Internal {
55

66
namespace {
77

8+
using OO = ApproximationPrecision::OptimizationObjective;
9+
810
// Generate this table with:
911
// python3 src/polynomial_optimizer.py atan --order 1 2 3 4 5 6 7 8 --loss mse mae mulpe mulpe_mae --no-gui --format table
1012
std::vector<Approximation> table_atan = {
11-
{ApproximationPrecision::MSE, 9.249650e-04, 7.078984e-02, 2.411547e+06, {+8.56188008e-01}},
12-
{ApproximationPrecision::MSE, 1.026356e-05, 9.214909e-03, 3.985505e+05, {+9.76213454e-01, -2.00030200e-01}},
13-
{ApproximationPrecision::MSE, 1.577588e-07, 1.323851e-03, 6.724566e+04, {+9.95982073e-01, -2.92278128e-01, +8.30180680e-02}},
14-
{ApproximationPrecision::MSE, 2.849011e-09, 1.992218e-04, 1.142204e+04, {+9.99316541e-01, -3.22286501e-01, +1.49032461e-01, -4.08635592e-02}},
15-
{ApproximationPrecision::MSE, 5.667504e-11, 3.080100e-05, 1.945614e+03, {+9.99883373e-01, -3.30599535e-01, +1.81451316e-01, -8.71733830e-02, +2.18671936e-02}},
16-
{ApproximationPrecision::MSE, 1.202662e-12, 4.846916e-06, 3.318677e+02, {+9.99980065e-01, -3.32694393e-01, +1.94019697e-01, -1.17694732e-01, +5.40822080e-02, -1.22995279e-02}},
17-
{ApproximationPrecision::MSE, 2.672889e-14, 7.722732e-07, 5.664632e+01, {+9.99996589e-01, -3.33190090e-01, +1.98232868e-01, -1.32941469e-01, +8.07623712e-02, -3.46124853e-02, +7.15115276e-03}},
18-
{ApproximationPrecision::MSE, 6.147315e-16, 1.245768e-07, 9.764224e+00, {+9.99999416e-01, -3.33302229e-01, +1.99511173e-01, -1.39332647e-01, +9.70944891e-02, -5.68823386e-02, +2.25679012e-02, -4.25772648e-03}},
19-
20-
{ApproximationPrecision::MAE, 1.097847e-03, 4.801638e-02, 2.793645e+06, {+8.33414544e-01}},
21-
{ApproximationPrecision::MAE, 1.209593e-05, 4.968992e-03, 4.623251e+05, {+9.72410454e-01, -1.91981283e-01}},
22-
{ApproximationPrecision::MAE, 1.839382e-07, 6.107084e-04, 7.766697e+04, {+9.95360080e-01, -2.88702052e-01, +7.93508437e-02}},
23-
{ApproximationPrecision::MAE, 3.296902e-09, 8.164167e-05, 1.313615e+04, {+9.99214108e-01, -3.21178073e-01, +1.46272006e-01, -3.89915187e-02}},
24-
{ApproximationPrecision::MAE, 6.523525e-11, 1.147459e-05, 2.229646e+03, {+9.99866373e-01, -3.30305517e-01, +1.80162434e-01, -8.51611537e-02, +2.08475020e-02}},
25-
{ApproximationPrecision::MAE, 1.378842e-12, 1.667328e-06, 3.792091e+02, {+9.99977226e-01, -3.32622991e-01, +1.93541452e-01, -1.16429278e-01, +5.26504600e-02, -1.17203722e-02}},
26-
{ApproximationPrecision::MAE, 3.055131e-14, 2.480947e-07, 6.457187e+01, {+9.99996113e-01, -3.33173716e-01, +1.98078484e-01, -1.32334692e-01, +7.96260166e-02, -3.36062649e-02, +6.81247117e-03}},
27-
{ApproximationPrecision::MAE, 7.013215e-16, 3.757868e-08, 1.102324e+01, {+9.99999336e-01, -3.33298615e-01, +1.99465749e-01, -1.39086791e-01, +9.64233077e-02, -5.59142254e-02, +2.18643190e-02, -4.05495427e-03}},
28-
29-
{ApproximationPrecision::MULPE, 1.355602e-03, 1.067325e-01, 1.808493e+06, {+8.92130617e-01}},
30-
{ApproximationPrecision::MULPE, 2.100588e-05, 1.075508e-02, 1.822095e+05, {+9.89111122e-01, -2.14468039e-01}},
31-
{ApproximationPrecision::MULPE, 3.573985e-07, 1.316370e-03, 2.227347e+04, {+9.98665077e-01, -3.02990987e-01, +9.10404434e-02}},
32-
{ApproximationPrecision::MULPE, 6.474958e-09, 1.548508e-04, 2.619892e+03, {+9.99842198e-01, -3.26272641e-01, +1.56294460e-01, -4.46207045e-02}},
33-
{ApproximationPrecision::MULPE, 1.313474e-10, 2.533532e-05, 4.294794e+02, {+9.99974110e-01, -3.31823782e-01, +1.85886095e-01, -9.30024008e-02, +2.43894760e-02}},
34-
{ApproximationPrecision::MULPE, 3.007880e-12, 3.530685e-06, 5.983830e+01, {+9.99996388e-01, -3.33036463e-01, +1.95959706e-01, -1.22068745e-01, +5.83403647e-02, -1.37966171e-02}},
35-
{ApproximationPrecision::MULPE, 6.348880e-14, 4.882649e-07, 8.276351e+00, {+9.99999499e-01, -3.33273408e-01, +1.98895454e-01, -1.35153794e-01, +8.43185278e-02, -3.73434598e-02, +7.95583230e-03}},
36-
{ApproximationPrecision::MULPE, 1.369569e-15, 7.585036e-08, 1.284979e+00, {+9.99999922e-01, -3.33320840e-01, +1.99708563e-01, -1.40257063e-01, +9.93094012e-02, -5.97138046e-02, +2.44056181e-02, -4.73371006e-03}},
37-
38-
{ApproximationPrecision::MULPE_MAE, 9.548909e-04, 6.131488e-02, 2.570520e+06, {+8.46713042e-01}},
39-
{ApproximationPrecision::MULPE_MAE, 1.159917e-05, 6.746680e-03, 3.778023e+05, {+9.77449762e-01, -1.98798279e-01}},
40-
{ApproximationPrecision::MULPE_MAE, 1.783646e-07, 8.575388e-04, 6.042236e+04, {+9.96388826e-01, -2.92591679e-01, +8.24585555e-02}},
41-
{ApproximationPrecision::MULPE_MAE, 3.265269e-09, 1.190548e-04, 9.505190e+03, {+9.99430906e-01, -3.22774535e-01, +1.49370817e-01, -4.07480795e-02}},
42-
{ApproximationPrecision::MULPE_MAE, 6.574962e-11, 1.684690e-05, 1.515116e+03, {+9.99909079e-01, -3.30795737e-01, +1.81810037e-01, -8.72860225e-02, +2.17776539e-02}},
43-
{ApproximationPrecision::MULPE_MAE, 1.380489e-12, 2.497538e-06, 2.510721e+02, {+9.99984893e-01, -3.32748885e-01, +1.94193211e-01, -1.17865932e-01, +5.40633775e-02, -1.22309990e-02}},
44-
{ApproximationPrecision::MULPE_MAE, 3.053218e-14, 3.784868e-07, 4.181995e+01, {+9.99997480e-01, -3.33205127e-01, +1.98309644e-01, -1.33094430e-01, +8.08643094e-02, -3.45859503e-02, +7.11261604e-03}},
45-
{ApproximationPrecision::MULPE_MAE, 7.018877e-16, 5.862915e-08, 6.942196e+00, {+9.99999581e-01, -3.33306326e-01, +1.99542180e-01, -1.39433369e-01, +9.72462857e-02, -5.69734398e-02, +2.25639390e-02, -4.24074590e-03}},
13+
{OO::MSE, 9.249650e-04, 7.078984e-02, 2.411e+06, {+8.56188008e-01}},
14+
{OO::MSE, 1.026356e-05, 9.214909e-03, 3.985e+05, {+9.76213454e-01, -2.00030200e-01}},
15+
{OO::MSE, 1.577588e-07, 1.323851e-03, 6.724e+04, {+9.95982073e-01, -2.92278128e-01, +8.30180680e-02}},
16+
{OO::MSE, 2.849011e-09, 1.992218e-04, 1.142e+04, {+9.99316541e-01, -3.22286501e-01, +1.49032461e-01, -4.08635592e-02}},
17+
{OO::MSE, 5.667504e-11, 3.080100e-05, 1.945e+03, {+9.99883373e-01, -3.30599535e-01, +1.81451316e-01, -8.71733830e-02, +2.18671936e-02}},
18+
{OO::MSE, 1.202662e-12, 4.846916e-06, 3.318e+02, {+9.99980065e-01, -3.32694393e-01, +1.94019697e-01, -1.17694732e-01, +5.40822080e-02, -1.22995279e-02}},
19+
{OO::MSE, 2.672889e-14, 7.722732e-07, 5.664e+01, {+9.99996589e-01, -3.33190090e-01, +1.98232868e-01, -1.32941469e-01, +8.07623712e-02, -3.46124853e-02, +7.15115276e-03}},
20+
{OO::MSE, 6.147315e-16, 1.245768e-07, 9.764e+00, {+9.99999416e-01, -3.33302229e-01, +1.99511173e-01, -1.39332647e-01, +9.70944891e-02, -5.68823386e-02, +2.25679012e-02, -4.25772648e-03}},
21+
22+
{OO::MAE, 1.097847e-03, 4.801638e-02, 2.793e+06, {+8.33414544e-01}},
23+
{OO::MAE, 1.209593e-05, 4.968992e-03, 4.623e+05, {+9.72410454e-01, -1.91981283e-01}},
24+
{OO::MAE, 1.839382e-07, 6.107084e-04, 7.766e+04, {+9.95360080e-01, -2.88702052e-01, +7.93508437e-02}},
25+
{OO::MAE, 3.296902e-09, 8.164167e-05, 1.313e+04, {+9.99214108e-01, -3.21178073e-01, +1.46272006e-01, -3.89915187e-02}},
26+
{OO::MAE, 6.523525e-11, 1.147459e-05, 2.229e+03, {+9.99866373e-01, -3.30305517e-01, +1.80162434e-01, -8.51611537e-02, +2.08475020e-02}},
27+
{OO::MAE, 1.378842e-12, 1.667328e-06, 3.792e+02, {+9.99977226e-01, -3.32622991e-01, +1.93541452e-01, -1.16429278e-01, +5.26504600e-02, -1.17203722e-02}},
28+
{OO::MAE, 3.055131e-14, 2.480947e-07, 6.457e+01, {+9.99996113e-01, -3.33173716e-01, +1.98078484e-01, -1.32334692e-01, +7.96260166e-02, -3.36062649e-02, +6.81247117e-03}},
29+
{OO::MAE, 7.013215e-16, 3.757868e-08, 1.102e+01, {+9.99999336e-01, -3.33298615e-01, +1.99465749e-01, -1.39086791e-01, +9.64233077e-02, -5.59142254e-02, +2.18643190e-02, -4.05495427e-03}},
30+
31+
{OO::MULPE, 1.355602e-03, 1.067325e-01, 1.808e+06, {+8.92130617e-01}},
32+
{OO::MULPE, 2.100588e-05, 1.075508e-02, 1.822e+05, {+9.89111122e-01, -2.14468039e-01}},
33+
{OO::MULPE, 3.573985e-07, 1.316370e-03, 2.227e+04, {+9.98665077e-01, -3.02990987e-01, +9.10404434e-02}},
34+
{OO::MULPE, 6.474958e-09, 1.548508e-04, 2.619e+03, {+9.99842198e-01, -3.26272641e-01, +1.56294460e-01, -4.46207045e-02}},
35+
{OO::MULPE, 1.313474e-10, 2.533532e-05, 4.294e+02, {+9.99974110e-01, -3.31823782e-01, +1.85886095e-01, -9.30024008e-02, +2.43894760e-02}},
36+
{OO::MULPE, 3.007880e-12, 3.530685e-06, 5.983e+01, {+9.99996388e-01, -3.33036463e-01, +1.95959706e-01, -1.22068745e-01, +5.83403647e-02, -1.37966171e-02}},
37+
{OO::MULPE, 6.348880e-14, 4.882649e-07, 8.276e+00, {+9.99999499e-01, -3.33273408e-01, +1.98895454e-01, -1.35153794e-01, +8.43185278e-02, -3.73434598e-02, +7.95583230e-03}},
38+
{OO::MULPE, 1.369569e-15, 7.585036e-08, 1.284e+00, {+9.99999922e-01, -3.33320840e-01, +1.99708563e-01, -1.40257063e-01, +9.93094012e-02, -5.97138046e-02, +2.44056181e-02, -4.73371006e-03}},
39+
40+
{OO::MULPE_MAE, 9.548909e-04, 6.131488e-02, 2.570e+06, {+8.46713042e-01}},
41+
{OO::MULPE_MAE, 1.159917e-05, 6.746680e-03, 3.778e+05, {+9.77449762e-01, -1.98798279e-01}},
42+
{OO::MULPE_MAE, 1.783646e-07, 8.575388e-04, 6.042e+04, {+9.96388826e-01, -2.92591679e-01, +8.24585555e-02}},
43+
{OO::MULPE_MAE, 3.265269e-09, 1.190548e-04, 9.505e+03, {+9.99430906e-01, -3.22774535e-01, +1.49370817e-01, -4.07480795e-02}},
44+
{OO::MULPE_MAE, 6.574962e-11, 1.684690e-05, 1.515e+03, {+9.99909079e-01, -3.30795737e-01, +1.81810037e-01, -8.72860225e-02, +2.17776539e-02}},
45+
{OO::MULPE_MAE, 1.380489e-12, 2.497538e-06, 2.510e+02, {+9.99984893e-01, -3.32748885e-01, +1.94193211e-01, -1.17865932e-01, +5.40633775e-02, -1.22309990e-02}},
46+
{OO::MULPE_MAE, 3.053218e-14, 3.784868e-07, 4.181e+01, {+9.99997480e-01, -3.33205127e-01, +1.98309644e-01, -1.33094430e-01, +8.08643094e-02, -3.45859503e-02, +7.11261604e-03}},
47+
{OO::MULPE_MAE, 7.018877e-16, 5.862915e-08, 6.942e+00, {+9.99999581e-01, -3.33306326e-01, +1.99542180e-01, -1.39433369e-01, +9.72462857e-02, -5.69734398e-02, +2.25639390e-02, -4.24074590e-03}},
4648
};
4749
} // namespace
4850

@@ -86,12 +88,13 @@ const Approximation *find_best_approximation(const std::vector<Approximation> &t
8688
}
8789

8890
if (precision.constraint_max_absolute_error > 0.0 && precision.constraint_max_absolute_error < e.mae) {
89-
penalty += 20 * extra_term_cost; // penalty for not getting the required precision.
91+
float error_ratio = e.mae / precision.constraint_max_absolute_error;
92+
penalty += 20 * error_ratio * extra_term_cost; // penalty for not getting the required precision.
9093
}
9194

9295
double score = obj_score + term_count_score + precision_score - penalty;
9396
// std::printf("Score for %zu (%zu terms): %f = %d + %d + %f - penalty %f\n", i, e.coefficients.size(), score, obj_score, term_count_score, precision_score, penalty);
94-
if (score > best_score) {
97+
if (score > best_score || best == nullptr) {
9598
best = &e;
9699
best_score = score;
97100
}

src/polynomial_optimizer.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,7 @@ def print_comment(indent=""):
281281
print()
282282

283283
if args.format in ["all", "table"]:
284-
print("{ApproximationPrecision::" + loss.upper() + f", {mean_squared_error:.6e}, {max_abs_error:.6e}, {max_ulp_error:.6e}, "
284+
print("{ApproximationPrecision::" + loss.upper() + f", {mean_squared_error:.6e}, {max_abs_error:.6e}, {max_ulp_error:.3e}, "
285285
+ "{" + ", ".join([f"{c:+.8e}" for c in coeffs]) + "}},")
286286
print()
287287

test/correctness/fast_arctan.cpp

+33-10
Original file line numberDiff line numberDiff line change
@@ -17,26 +17,43 @@ int bits_diff(float fa, float fb) {
1717
return count;
1818
}
1919

20+
int ulp_diff(float fa, float fb) {
21+
uint32_t a = Halide::Internal::reinterpret_bits<uint32_t>(fa);
22+
uint32_t b = Halide::Internal::reinterpret_bits<uint32_t>(fb);
23+
return std::abs(int64_t(a) - int64_t(b));
24+
}
25+
2026
int main(int argc, char **argv) {
2127
Target target = get_jit_target_from_environment();
2228

2329
struct Test {
2430
ApproximationPrecision precision;
2531
const char *objective;
32+
float expected_mae{0.0};
2633
} precisions_to_test[] = {
2734
// MAE
2835
{{ApproximationPrecision::MAE, 0, 1e-2}, "MAE"},
2936
{{ApproximationPrecision::MAE, 0, 1e-3}, "MAE"},
3037
{{ApproximationPrecision::MAE, 0, 1e-4}, "MAE"},
3138
{{ApproximationPrecision::MAE, 0, 1e-5}, "MAE"},
3239
{{ApproximationPrecision::MAE, 0, 1e-6}, "MAE"},
40+
{{ApproximationPrecision::MAE, 0, 1e-7}, "MAE", 5e-7f},
3341

3442
// MULPE
35-
{{ApproximationPrecision::MULPE, 0, 1e-2f}, "MULPE"},
36-
{{ApproximationPrecision::MULPE, 0, 1e-3f}, "MULPE"},
37-
{{ApproximationPrecision::MULPE, 0, 1e-4f}, "MULPE"},
38-
{{ApproximationPrecision::MULPE, 0, 1e-5f}, "MULPE"},
39-
{{ApproximationPrecision::MULPE, 0, 1e-6f}, "MULPE"},
43+
{{ApproximationPrecision::MULPE, 0, 1e-2}, "MULPE"},
44+
{{ApproximationPrecision::MULPE, 0, 1e-3}, "MULPE"},
45+
{{ApproximationPrecision::MULPE, 0, 1e-4}, "MULPE"},
46+
{{ApproximationPrecision::MULPE, 0, 1e-5}, "MULPE"},
47+
{{ApproximationPrecision::MULPE, 0, 1e-6}, "MULPE"},
48+
{{ApproximationPrecision::MULPE, 0, 1e-7}, "MULPE", 5e-7f},
49+
50+
// MULPE + MAE
51+
{{ApproximationPrecision::MULPE_MAE, 0, 1e-2}, "MULPE+MAE"},
52+
{{ApproximationPrecision::MULPE_MAE, 0, 1e-3}, "MULPE+MAE"},
53+
{{ApproximationPrecision::MULPE_MAE, 0, 1e-4}, "MULPE+MAE"},
54+
{{ApproximationPrecision::MULPE_MAE, 0, 1e-5}, "MULPE+MAE"},
55+
{{ApproximationPrecision::MULPE_MAE, 0, 1e-6}, "MULPE+MAE"},
56+
{{ApproximationPrecision::MULPE_MAE, 0, 1e-7}, "MULPE+MAE", 5e-7},
4057
};
4158

4259
for (Test test : precisions_to_test) {
@@ -57,24 +74,27 @@ int main(int argc, char **argv) {
5774
atan_f.vectorize(x, 8);
5875
}
5976

60-
printf(" Testing fast_atan() correctness... ");
77+
printf(" Testing fast_atan() correctness... ");
6178
Buffer<float> atan_result = atan_f.realize({steps});
6279
float max_error = 0.0f;
6380
int max_mantissa_error = 0;
81+
int max_ulp_error = 0;
6482
for (int i = 0; i < steps; ++i) {
6583
const float x = (i - steps / 2) / float(steps / 8);
6684
const float atan_x = atan_result(i);
6785
const float atan_x_ref = atan(x);
6886
float abs_error = std::abs(atan_x_ref - atan_x);
6987
int mantissa_error = bits_diff(atan_x, atan_x_ref);
88+
int ulp_error = ulp_diff(atan_x, atan_x_ref);
7089
max_error = std::max(max_error, abs_error);
7190
max_mantissa_error = std::max(max_mantissa_error, mantissa_error);
72-
if (abs_error > test.precision.constraint_max_absolute_error) {
91+
max_ulp_error = std::max(max_ulp_error, ulp_error);
92+
if (abs_error > std::max(test.precision.constraint_max_absolute_error, test.expected_mae)) {
7393
fprintf(stderr, "fast_atan(%.6f) = %.20f not equal to %.20f (error=%.5e)\n", x, atan_x, atan_x_ref, atan_x_ref - atan_x);
7494
exit(1);
7595
}
7696
}
77-
printf("Passed: max abs error: %.5e max mantissa bits wrong: %d\n", max_error, max_mantissa_error);
97+
printf("Passed: max abs error: %.5e max ULP error: %6d max mantissa bits wrong: %2d\n", max_error, max_ulp_error, max_mantissa_error);
7898

7999
atan2_f(x, y) = fast_atan2(vx, vy, test.precision);
80100
if (target.has_gpu_feature()) {
@@ -89,6 +109,7 @@ int main(int argc, char **argv) {
89109
Buffer<float> atan2_result = atan2_f.realize({steps, steps});
90110
max_error = 0.0f;
91111
max_mantissa_error = 0;
112+
max_ulp_error = 0;
92113
for (int i = 0; i < steps; ++i) {
93114
const float x = (i - steps / 2) / float(steps / 8);
94115
for (int j = 0; j < steps; ++j) {
@@ -97,15 +118,17 @@ int main(int argc, char **argv) {
97118
const float atan2_x_y_ref = atan2(x, y);
98119
float abs_error = std::abs(atan2_x_y_ref - atan2_x_y);
99120
int mantissa_error = bits_diff(atan2_x_y, atan2_x_y_ref);
121+
int ulp_error = ulp_diff(atan2_x_y, atan2_x_y_ref);
100122
max_error = std::max(max_error, abs_error);
101123
max_mantissa_error = std::max(max_mantissa_error, mantissa_error);
102-
if (abs_error > test.precision.constraint_max_absolute_error) {
124+
max_ulp_error = std::max(max_ulp_error, ulp_error);
125+
if (abs_error > std::max(test.precision.constraint_max_absolute_error, test.expected_mae)) {
103126
fprintf(stderr, "fast_atan2(%.6f, %.6f) = %.20f not equal to %.20f (error=%.5e)\n", x, y, atan2_x_y, atan2_x_y_ref, atan2_x_y_ref - atan2_x_y);
104127
exit(1);
105128
}
106129
}
107130
}
108-
printf("Passed: max abs error: %.5e max mantissa bits wrong: %d\n", max_error, max_mantissa_error);
131+
printf("Passed: max abs error: %.5e max ULP error: %6d max mantissa bits wrong: %2d\n", max_error, max_ulp_error, max_mantissa_error);
109132
}
110133

111134
printf("Success!\n");

test/performance/fast_arctan.cpp

+8
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,14 @@ int main(int argc, char **argv) {
7373
{{ApproximationPrecision::MULPE, 6}, "Poly6"},
7474
{{ApproximationPrecision::MULPE, 7}, "Poly7"},
7575
{{ApproximationPrecision::MULPE, 8}, "Poly8"},
76+
77+
{{ApproximationPrecision::MULPE, 0, 1e-2}, "MAE 1e-2"},
78+
{{ApproximationPrecision::MULPE, 0, 1e-3}, "MAE 1e-3"},
79+
{{ApproximationPrecision::MULPE, 0, 1e-4}, "MAE 1e-4"},
80+
{{ApproximationPrecision::MULPE, 0, 1e-5}, "MAE 1e-5"},
81+
{{ApproximationPrecision::MULPE, 0, 1e-6}, "MAE 1e-6"},
82+
{{ApproximationPrecision::MULPE, 0, 1e-7}, "MAE 1e-7"},
83+
{{ApproximationPrecision::MULPE, 0, 1e-8}, "MAE 1e-8"},
7684
};
7785

7886
for (Prec &precision : precisions_to_test) {

0 commit comments

Comments
 (0)