Skip to content

Commit f107f2d

Browse files
committed
Rewrite function calls to math functions to the native built-in API function for GPU backends.
1 parent 813920f commit f107f2d

7 files changed

+138
-187
lines changed

src/CodeGen_C.cpp

+8-1
Original file line numberDiff line numberDiff line change
@@ -1902,7 +1902,14 @@ string CodeGen_C::print_extern_call(const Call *op) {
19021902
if (function_takes_user_context(op->name)) {
19031903
args.insert(args.begin(), "_ucon");
19041904
}
1905-
rhs << op->name << "(" << with_commas(args) << ")";
1905+
std::string name = op->name;
1906+
auto it = extern_function_name_map.find(name);
1907+
if (it != extern_function_name_map.end()) {
1908+
name = it->second;
1909+
debug(3) << "Rewriting " << op->name << " as " << name << "\n";
1910+
}
1911+
debug(3) << "Writing out call to " << name << "\n";
1912+
rhs << name << "(" << with_commas(args) << ")";
19061913
return rhs.str();
19071914
}
19081915

src/CodeGen_C.h

+2
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,8 @@ class CodeGen_C : public IRPrinter {
134134
* use different syntax for other C-like languages. */
135135
virtual void add_vector_typedefs(const std::set<Type> &vector_types);
136136

137+
std::unordered_map<std::string, std::string> extern_function_name_map;
138+
137139
/** Bottleneck to allow customization of calls to generic Extern/PureExtern calls. */
138140
virtual std::string print_extern_call(const Call *op);
139141

src/CodeGen_GPU_Dev.cpp

+10-4
Original file line numberDiff line numberDiff line change
@@ -237,10 +237,16 @@ void CodeGen_GPU_C::visit(const Shuffle *op) {
237237
}
238238

239239
void CodeGen_GPU_C::visit(const Call *op) {
240-
// In metal and opencl, "rint" is a polymorphic function that matches our
241-
// rounding semantics. GLSL handles it separately using "roundEven".
242-
if (op->is_intrinsic(Call::round)) {
243-
print_assignment(op->type, "rint(" + print_expr(op->args[0]) + ")");
240+
if (op->is_intrinsic(Call::abs)) {
241+
internal_assert(op->args.size() == 1);
242+
std::stringstream fn;
243+
if (op->type.is_float()) {
244+
fn << "abs_f" << op->type.bits();
245+
} else {
246+
fn << "abs";
247+
}
248+
Expr equiv = Call::make(op->type, fn.str(), op->args, Call::PureExtern);
249+
equiv.accept(this);
244250
} else {
245251
CodeGen_C::visit(op);
246252
}

src/CodeGen_Metal_Dev.cpp

+43-57
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,41 @@ class CodeGen_Metal_Dev : public CodeGen_GPU_Dev {
5858
public:
5959
CodeGen_Metal_C(std::ostream &s, const Target &t)
6060
: CodeGen_GPU_C(s, t) {
61+
62+
#define alias(x, y) \
63+
extern_function_name_map[x "_f16"] = y; \
64+
extern_function_name_map[x "_f32"] = y
65+
alias("sqrt", "sqrt");
66+
alias("sin", "sin");
67+
alias("cos", "cos");
68+
alias("exp", "exp");
69+
alias("log", "log");
70+
alias("abs", "fabs"); // f-prefix!
71+
alias("floor", "floor");
72+
alias("ceil", "ceil");
73+
alias("trunc", "trunc");
74+
alias("pow", "pow");
75+
alias("asin", "asin");
76+
alias("acos", "acos");
77+
alias("tan", "tan");
78+
alias("atan", "atan");
79+
alias("atan2", "atan2");
80+
alias("sinh", "sinh");
81+
alias("asinh", "asinh");
82+
alias("cosh", "cosh");
83+
alias("acosh", "acosh");
84+
alias("tanh", "tanh");
85+
alias("atanh", "atanh");
86+
87+
alias("is_nan", "isnan");
88+
alias("is_inf", "isinf");
89+
alias("is_finite", "isfinite");
90+
91+
alias("fast_inverse", "native_recip");
92+
alias("fast_inverse_sqrt", "native_rsqrt");
93+
#undef alias
94+
95+
6196
}
6297
void add_kernel(const Stmt &stmt,
6398
const std::string &name,
@@ -209,13 +244,7 @@ string simt_intrinsic(const string &name) {
209244

210245
string CodeGen_Metal_Dev::CodeGen_Metal_C::print_extern_call(const Call *op) {
211246
internal_assert(!function_takes_user_context(op->name)) << op->name;
212-
vector<string> args(op->args.size());
213-
for (size_t i = 0; i < op->args.size(); i++) {
214-
args[i] = print_expr(op->args[i]);
215-
}
216-
ostringstream rhs;
217-
rhs << op->name << "(" << with_commas(args) << ")";
218-
return rhs.str();
247+
return CodeGen_GPU_C::print_extern_call(op);
219248
}
220249

221250
void CodeGen_Metal_Dev::CodeGen_Metal_C::visit(const Max *op) {
@@ -331,6 +360,13 @@ void CodeGen_Metal_Dev::CodeGen_Metal_C::visit(const Call *op) {
331360
}
332361
stream << ");\n";
333362
print_assignment(op->type, "0");
363+
} else if (op->is_intrinsic(Call::absd)) {
364+
Expr equiv = Call::make(op->type, "absdiff", op->args, Call::PureExtern);
365+
equiv.accept(this);
366+
} else if (op->is_intrinsic(Call::round)) {
367+
// In Metal, rint matches our rounding semantics
368+
Expr equiv = Call::make(op->type, "rint", op->args, Call::PureExtern);
369+
equiv.accept(this);
334370
} else {
335371
CodeGen_GPU_C::visit(op);
336372
}
@@ -809,56 +845,6 @@ void CodeGen_Metal_Dev::init_module() {
809845
<< "constexpr float neg_inf_f32() { return float_from_bits(0xff800000); }\n"
810846
<< "constexpr float inf_f32() { return float_from_bits(0x7f800000); }\n"
811847
<< "float fast_inverse_f32(float x) { return 1.0f / x; }\n"
812-
<< "#define is_nan_f32 isnan\n"
813-
<< "#define is_inf_f32 isinf\n"
814-
<< "#define is_finite_f32 isfinite\n"
815-
<< "#define sqrt_f32 sqrt\n"
816-
<< "#define sin_f32 sin\n"
817-
<< "#define cos_f32 cos\n"
818-
<< "#define exp_f32 exp\n"
819-
<< "#define log_f32 log\n"
820-
<< "#define abs_f32 fabs\n"
821-
<< "#define floor_f32 floor\n"
822-
<< "#define ceil_f32 ceil\n"
823-
<< "#define trunc_f32 trunc\n"
824-
<< "#define pow_f32 pow\n"
825-
<< "#define asin_f32 asin\n"
826-
<< "#define acos_f32 acos\n"
827-
<< "#define tan_f32 tan\n"
828-
<< "#define atan_f32 atan\n"
829-
<< "#define atan2_f32 atan2\n"
830-
<< "#define sinh_f32 sinh\n"
831-
<< "#define asinh_f32 asinh\n"
832-
<< "#define cosh_f32 cosh\n"
833-
<< "#define acosh_f32 acosh\n"
834-
<< "#define tanh_f32 tanh\n"
835-
<< "#define atanh_f32 atanh\n"
836-
<< "#define fast_inverse_sqrt_f32 rsqrt\n"
837-
<< "#define is_nan_f16 isnan\n"
838-
<< "#define is_inf_f16 isinf\n"
839-
<< "#define is_finite_f16 isfinite\n"
840-
<< "#define sqrt_f16 sqrt\n"
841-
<< "#define sin_f16 sin\n"
842-
<< "#define cos_f16 cos\n"
843-
<< "#define exp_f16 exp\n"
844-
<< "#define log_f16 log\n"
845-
<< "#define abs_f16 fabs\n"
846-
<< "#define floor_f16 floor\n"
847-
<< "#define ceil_f16 ceil\n"
848-
<< "#define trunc_f16 trunc\n"
849-
<< "#define pow_f16 pow\n"
850-
<< "#define asin_f16 asin\n"
851-
<< "#define acos_f16 acos\n"
852-
<< "#define tan_f16 tan\n"
853-
<< "#define atan_f16 atan\n"
854-
<< "#define atan2_f16 atan2\n"
855-
<< "#define sinh_f16 sinh\n"
856-
<< "#define asinh_f16 asinh\n"
857-
<< "#define cosh_f16 cosh\n"
858-
<< "#define acosh_f16 acosh\n"
859-
<< "#define tanh_f16 tanh\n"
860-
<< "#define atanh_f16 atanh\n"
861-
<< "#define fast_inverse_sqrt_f16 rsqrt\n"
862848
<< "constexpr half half_from_bits(unsigned short x) {return as_type<half>(x);}\n"
863849
<< "constexpr half nan_f16() { return half_from_bits(32767); }\n"
864850
<< "constexpr half neg_inf_f16() { return half_from_bits(64512); }\n"

src/CodeGen_OpenCL_Dev.cpp

+38-94
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,40 @@ class CodeGen_OpenCL_Dev : public CodeGen_GPU_Dev {
6262
: CodeGen_GPU_C(s, t) {
6363
integer_suffix_style = IntegerSuffixStyle::OpenCL;
6464
vector_declaration_style = VectorDeclarationStyle::OpenCLSyntax;
65+
66+
#define alias(x, y) \
67+
extern_function_name_map[x "_f16"] = y; \
68+
extern_function_name_map[x "_f32"] = y; \
69+
extern_function_name_map[x "_f64"] = y
70+
alias("sqrt", "sqrt");
71+
alias("sin", "sin");
72+
alias("cos", "cos");
73+
alias("exp", "exp");
74+
alias("log", "log");
75+
alias("abs", "fabs"); // f-prefix! (although it's handled as an intrinsic).
76+
alias("floor", "floor");
77+
alias("ceil", "ceil");
78+
alias("trunc", "trunc");
79+
alias("pow", "pow");
80+
alias("asin", "asin");
81+
alias("acos", "acos");
82+
alias("tan", "tan");
83+
alias("atan", "atan");
84+
alias("atan2", "atan2");
85+
alias("sinh", "sinh");
86+
alias("asinh", "asinh");
87+
alias("cosh", "cosh");
88+
alias("acosh", "acosh");
89+
alias("tanh", "tanh");
90+
alias("atanh", "atanh");
91+
92+
alias("is_nan", "isnan");
93+
alias("is_inf", "isinf");
94+
alias("is_finite", "isfinite");
95+
96+
alias("fast_inverse", "native_recip");
97+
alias("fast_inverse_sqrt", "native_rsqrt");
98+
#undef alias
6599
}
66100
void add_kernel(Stmt stmt,
67101
const std::string &name,
@@ -300,16 +334,6 @@ void CodeGen_OpenCL_Dev::CodeGen_OpenCL_C::visit(const Call *op) {
300334
ostringstream rhs;
301335
rhs << "select(" << false_val << ", " << true_val << ", " << cond << ")";
302336
print_assignment(op->type, rhs.str());
303-
} else if (op->is_intrinsic(Call::abs)) {
304-
if (op->type.is_float()) {
305-
ostringstream rhs;
306-
rhs << "abs_f" << op->type.bits() << "(" << print_expr(op->args[0]) << ")";
307-
print_assignment(op->type, rhs.str());
308-
} else {
309-
ostringstream rhs;
310-
rhs << "abs(" << print_expr(op->args[0]) << ")";
311-
print_assignment(op->type, rhs.str());
312-
}
313337
} else if (op->is_intrinsic(Call::absd)) {
314338
ostringstream rhs;
315339
rhs << "abs_diff(" << print_expr(op->args[0]) << ", " << print_expr(op->args[1]) << ")";
@@ -466,13 +490,7 @@ void CodeGen_OpenCL_Dev::CodeGen_OpenCL_C::visit(const Call *op) {
466490

467491
string CodeGen_OpenCL_Dev::CodeGen_OpenCL_C::print_extern_call(const Call *op) {
468492
internal_assert(!function_takes_user_context(op->name)) << op->name;
469-
vector<string> args(op->args.size());
470-
for (size_t i = 0; i < op->args.size(); i++) {
471-
args[i] = print_expr(op->args[i]);
472-
}
473-
ostringstream rhs;
474-
rhs << op->name << "(" << with_commas(args) << ")";
475-
return rhs.str();
493+
return CodeGen_GPU_C::print_extern_call(op);
476494
}
477495

478496
string CodeGen_OpenCL_Dev::CodeGen_OpenCL_C::print_array_access(const string &name,
@@ -1123,64 +1141,14 @@ void CodeGen_OpenCL_Dev::init_module() {
11231141
src_stream << "inline float float_from_bits(unsigned int x) {return as_float(x);}\n"
11241142
<< "inline float nan_f32() { return NAN; }\n"
11251143
<< "inline float neg_inf_f32() { return -INFINITY; }\n"
1126-
<< "inline float inf_f32() { return INFINITY; }\n"
1127-
<< "inline bool is_nan_f32(float x) {return isnan(x); }\n"
1128-
<< "inline bool is_inf_f32(float x) {return isinf(x); }\n"
1129-
<< "inline bool is_finite_f32(float x) {return isfinite(x); }\n"
1130-
<< "#define sqrt_f32 sqrt \n"
1131-
<< "#define sin_f32 sin \n"
1132-
<< "#define cos_f32 cos \n"
1133-
<< "#define exp_f32 exp \n"
1134-
<< "#define log_f32 log \n"
1135-
<< "#define abs_f32 fabs \n"
1136-
<< "#define floor_f32 floor \n"
1137-
<< "#define ceil_f32 ceil \n"
1138-
<< "#define trunc_f32 trunc \n"
1139-
<< "#define pow_f32 pow\n"
1140-
<< "#define asin_f32 asin \n"
1141-
<< "#define acos_f32 acos \n"
1142-
<< "#define tan_f32 tan \n"
1143-
<< "#define atan_f32 atan \n"
1144-
<< "#define atan2_f32 atan2\n"
1145-
<< "#define sinh_f32 sinh \n"
1146-
<< "#define asinh_f32 asinh \n"
1147-
<< "#define cosh_f32 cosh \n"
1148-
<< "#define acosh_f32 acosh \n"
1149-
<< "#define tanh_f32 tanh \n"
1150-
<< "#define atanh_f32 atanh \n"
1151-
<< "#define fast_inverse_f32 native_recip \n"
1152-
<< "#define fast_inverse_sqrt_f32 native_rsqrt \n";
1144+
<< "inline float inf_f32() { return INFINITY; }\n";
11531145

11541146
// There does not appear to be a reliable way to safely ignore unused
11551147
// variables in OpenCL C. See https://github.com/halide/Halide/issues/4918.
11561148
src_stream << "#define halide_maybe_unused(x)\n";
11571149

11581150
if (target.has_feature(Target::CLDoubles)) {
1159-
src_stream << "#pragma OPENCL EXTENSION cl_khr_fp64 : enable\n"
1160-
<< "inline bool is_nan_f64(double x) {return isnan(x); }\n"
1161-
<< "inline bool is_inf_f64(double x) {return isinf(x); }\n"
1162-
<< "inline bool is_finite_f64(double x) {return isfinite(x); }\n"
1163-
<< "#define sqrt_f64 sqrt\n"
1164-
<< "#define sin_f64 sin\n"
1165-
<< "#define cos_f64 cos\n"
1166-
<< "#define exp_f64 exp\n"
1167-
<< "#define log_f64 log\n"
1168-
<< "#define abs_f64 fabs\n"
1169-
<< "#define floor_f64 floor\n"
1170-
<< "#define ceil_f64 ceil\n"
1171-
<< "#define trunc_f64 trunc\n"
1172-
<< "#define pow_f64 pow\n"
1173-
<< "#define asin_f64 asin\n"
1174-
<< "#define acos_f64 acos\n"
1175-
<< "#define tan_f64 tan\n"
1176-
<< "#define atan_f64 atan\n"
1177-
<< "#define atan2_f64 atan2\n"
1178-
<< "#define sinh_f64 sinh\n"
1179-
<< "#define asinh_f64 asinh\n"
1180-
<< "#define cosh_f64 cosh\n"
1181-
<< "#define acosh_f64 acosh\n"
1182-
<< "#define tanh_f64 tanh\n"
1183-
<< "#define atanh_f64 atanh\n";
1151+
src_stream << "#pragma OPENCL EXTENSION cl_khr_fp64 : enable\n";
11841152
}
11851153

11861154
if (target.has_feature(Target::CLHalf)) {
@@ -1192,31 +1160,7 @@ void CodeGen_OpenCL_Dev::init_module() {
11921160
<< "inline half half_from_bits(unsigned short x) {return __builtin_astype(x, half);}\n"
11931161
<< "inline half nan_f16() { return half_from_bits(" << nan_f16 << "); }\n"
11941162
<< "inline half neg_inf_f16() { return half_from_bits(" << neg_inf_f16 << "); }\n"
1195-
<< "inline half inf_f16() { return half_from_bits(" << inf_f16 << "); }\n"
1196-
<< "inline bool is_nan_f16(half x) {return isnan(x); }\n"
1197-
<< "inline bool is_inf_f16(half x) {return isinf(x); }\n"
1198-
<< "inline bool is_finite_f16(half x) {return isfinite(x); }\n"
1199-
<< "#define sqrt_f16 sqrt\n"
1200-
<< "#define sin_f16 sin\n"
1201-
<< "#define cos_f16 cos\n"
1202-
<< "#define exp_f16 exp\n"
1203-
<< "#define log_f16 log\n"
1204-
<< "#define abs_f16 fabs\n"
1205-
<< "#define floor_f16 floor\n"
1206-
<< "#define ceil_f16 ceil\n"
1207-
<< "#define trunc_f16 trunc\n"
1208-
<< "#define pow_f16 pow\n"
1209-
<< "#define asin_f16 asin\n"
1210-
<< "#define acos_f16 acos\n"
1211-
<< "#define tan_f16 tan\n"
1212-
<< "#define atan_f16 atan\n"
1213-
<< "#define atan2_f16 atan2\n"
1214-
<< "#define sinh_f16 sinh\n"
1215-
<< "#define asinh_f16 asinh\n"
1216-
<< "#define cosh_f16 cosh\n"
1217-
<< "#define acosh_f16 acosh\n"
1218-
<< "#define tanh_f16 tanh\n"
1219-
<< "#define atanh_f16 atanh\n";
1163+
<< "inline half inf_f16() { return half_from_bits(" << inf_f16 << "); }\n";
12201164
}
12211165

12221166
if (target.has_feature(Target::CLAtomics64)) {

0 commit comments

Comments
 (0)