Skip to content

Commit 58e4a7f

Browse files
committed
Test vectorized support for math functions in correctness/math.cpp
1 parent f107f2d commit 58e4a7f

7 files changed

+68
-58
lines changed

src/CodeGen_D3D12Compute_Dev.cpp

+37-38
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,40 @@ class CodeGen_D3D12Compute_Dev : public CodeGen_GPU_Dev {
6767
CodeGen_D3D12Compute_C(std::ostream &s, const Target &t)
6868
: CodeGen_GPU_C(s, t) {
6969
integer_suffix_style = IntegerSuffixStyle::HLSL;
70+
71+
#define alias(x, y) \
72+
extern_function_name_map[x "_f16"] = y; \
73+
extern_function_name_map[x "_f32"] = y; \
74+
extern_function_name_map[x "_f64"] = y
75+
alias("sqrt", "sqrt");
76+
alias("sin", "sin");
77+
alias("cos", "cos");
78+
alias("exp", "exp");
79+
alias("log", "log");
80+
alias("abs", "abs");
81+
alias("floor", "floor");
82+
alias("ceil", "ceil");
83+
alias("trunc", "trunc");
84+
alias("pow", "pow");
85+
alias("asin", "asin");
86+
alias("acos", "acos");
87+
alias("tan", "tan");
88+
alias("atan", "atan");
89+
alias("atan2", "atan2");
90+
alias("sinh", "sinh");
91+
alias("asinh", "asinh");
92+
alias("cosh", "cosh");
93+
alias("acosh", "acosh");
94+
alias("tanh", "tanh");
95+
alias("atanh", "atanh");
96+
97+
alias("is_nan", "isnan");
98+
alias("is_inf", "isinf");
99+
alias("is_finite", "isfinite");
100+
101+
alias("fast_inverse", "rcp");
102+
alias("fast_inverse_sqrt", "rsqrt");
103+
#undef alias
70104
}
71105
void add_kernel(Stmt stmt,
72106
const std::string &name,
@@ -79,7 +113,6 @@ class CodeGen_D3D12Compute_Dev : public CodeGen_GPU_Dev {
79113
std::string print_storage_type(Type type);
80114
std::string print_type_maybe_storage(Type type, bool storage, AppendSpaceIfNeeded space);
81115
std::string print_reinterpret(Type type, const Expr &e) override;
82-
std::string print_extern_call(const Call *op) override;
83116

84117
std::string print_vanilla_cast(Type type, const std::string &value_expr);
85118
std::string print_reinforced_cast(Type type, const std::string &value_expr);
@@ -247,18 +280,6 @@ void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::visit(const Evaluate *op)
247280
print_expr(op->value);
248281
}
249282

250-
// (Removed by this commit.) D3D12-specific printer for extern calls.
// Asserts that the callee does not take a user context, prints every
// argument with print_expr, and returns the text `name(args...)` using
// op->name as-is; no name-rewriting table is consulted here.
string CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::print_extern_call(const Call *op) {
251-
internal_assert(!function_takes_user_context(op->name)) << op->name;
252-
253-
vector<string> args(op->args.size());
254-
for (size_t i = 0; i < op->args.size(); i++) {
255-
args[i] = print_expr(op->args[i]);
256-
}
257-
ostringstream rhs;
258-
// with_commas presumably joins the printed arguments with ", " -- its
// definition is not visible here.
rhs << op->name << "(" << with_commas(args) << ")";
259-
return rhs.str();
260-
}
261-
262283
// Prints a Max node by re-expressing it as an extern call to a function
// named "max" with op->a and op->b as arguments, then printing that call.
void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::visit(const Max *op) {
263284
print_expr(Call::make(op->type, "max", {op->a, op->b}, Call::Extern));
264285
}
@@ -1290,19 +1311,7 @@ void CodeGen_D3D12Compute_Dev::init_module() {
12901311
<< "float nan_f32() { return 1.#IND; } \n" // Quiet NaN with minimum fractional value.
12911312
<< "float neg_inf_f32() { return -1.#INF; } \n"
12921313
<< "float inf_f32() { return +1.#INF; } \n"
1293-
<< "#define is_inf_f32 isinf \n"
1294-
<< "#define is_finite_f32 isfinite \n"
1295-
<< "#define is_nan_f32 isnan \n"
12961314
<< "#define float_from_bits asfloat \n"
1297-
<< "#define sqrt_f32 sqrt \n"
1298-
<< "#define sin_f32 sin \n"
1299-
<< "#define cos_f32 cos \n"
1300-
<< "#define exp_f32 exp \n"
1301-
<< "#define log_f32 log \n"
1302-
<< "#define abs_f32 abs \n"
1303-
<< "#define floor_f32 floor \n"
1304-
<< "#define ceil_f32 ceil \n"
1305-
<< "#define trunc_f32 trunc \n"
13061315
// pow() in HLSL has the same semantics as C if
13071316
// x > 0. Otherwise, we need to emulate C
13081317
// behavior.
@@ -1322,19 +1331,9 @@ void CodeGen_D3D12Compute_Dev::init_module() {
13221331
<< " return nan_f32(); \n"
13231332
<< " } \n"
13241333
<< "} \n"
1325-
<< "#define asin_f32 asin \n"
1326-
<< "#define acos_f32 acos \n"
1327-
<< "#define tan_f32 tan \n"
1328-
<< "#define atan_f32 atan \n"
1329-
<< "#define atan2_f32 atan2 \n"
1330-
<< "#define sinh_f32 sinh \n"
1331-
<< "#define cosh_f32 cosh \n"
1332-
<< "#define tanh_f32 tanh \n"
1333-
<< "#define asinh_f32(x) (log_f32(x + sqrt_f32(x*x + 1))) \n"
1334-
<< "#define acosh_f32(x) (log_f32(x + sqrt_f32(x*x - 1))) \n"
1335-
<< "#define atanh_f32(x) (log_f32((1+x)/(1-x))/2) \n"
1336-
<< "#define fast_inverse_f32 rcp \n"
1337-
<< "#define fast_inverse_sqrt_f32 rsqrt \n"
1334+
<< "#define asinh(x) (log(x + sqrt(x*x + 1))) \n"
1335+
<< "#define acosh(x) (log(x + sqrt(x*x - 1))) \n"
1336+
<< "#define atanh(x) (log((1+x)/(1-x))/2) \n"
13381337
<< "\n";
13391338
//<< "}\n"; // close namespace
13401339

src/CodeGen_GPU_Dev.cpp

+27
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include "CodeGen_GPU_Dev.h"
2+
#include "CodeGen_Internal.h"
23
#include "CanonicalizeGPUVars.h"
34
#include "Deinterleave.h"
45
#include "ExprUsesVar.h"
@@ -252,5 +253,31 @@ void CodeGen_GPU_C::visit(const Call *op) {
252253
}
253254
}
254255

256+
257+
std::string CodeGen_GPU_C::print_extern_call(const Call *op) {
258+
internal_assert(!function_takes_user_context(op->name)) << op->name;
259+
260+
// Here we do not scalarize function calls with vector arguments.
261+
// Backends should provide those functions, and if not available,
262+
// we could compose them by writing out a call element by element,
263+
// but that's never happened until 2025, so I guess we can leave
264+
// this to be an error for now, just like it was.
265+
266+
std::ostringstream rhs;
267+
std::vector<std::string> args(op->args.size());
268+
for (size_t i = 0; i < op->args.size(); i++) {
269+
args[i] = print_expr(op->args[i]);
270+
}
271+
std::string name = op->name;
272+
auto it = extern_function_name_map.find(name);
273+
if (it != extern_function_name_map.end()) {
274+
name = it->second;
275+
debug(3) << "Rewriting " << op->name << " as " << name << "\n";
276+
}
277+
debug(3) << "Writing out call to " << name << "\n";
278+
rhs << name << "(" << with_commas(args) << ")";
279+
return rhs.str();
280+
}
281+
255282
} // namespace Internal
256283
} // namespace Halide

src/CodeGen_GPU_Dev.h

+2
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,8 @@ class CodeGen_GPU_C : public CodeGen_C {
100100
void visit(const Shuffle *op) override;
101101
void visit(const Call *op) override;
102102

103+
std::string print_extern_call(const Call *op) override;
104+
103105
VectorDeclarationStyle vector_declaration_style = VectorDeclarationStyle::CLikeSyntax;
104106
};
105107

src/CodeGen_Metal_Dev.cpp

-6
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,6 @@ class CodeGen_Metal_Dev : public CodeGen_GPU_Dev {
111111
std::string print_storage_type(Type type);
112112
std::string print_type_maybe_storage(Type type, bool storage, AppendSpaceIfNeeded space);
113113
std::string print_reinterpret(Type type, const Expr &e) override;
114-
std::string print_extern_call(const Call *op) override;
115114

116115
std::string get_memory_space(const std::string &);
117116

@@ -242,11 +241,6 @@ string simt_intrinsic(const string &name) {
242241
}
243242
} // namespace
244243

245-
// (Removed by this commit.) Metal override that only asserted the callee
// does not take a user context and then delegated to
// CodeGen_GPU_C::print_extern_call.
string CodeGen_Metal_Dev::CodeGen_Metal_C::print_extern_call(const Call *op) {
246-
internal_assert(!function_takes_user_context(op->name)) << op->name;
247-
return CodeGen_GPU_C::print_extern_call(op);
248-
}
249-
250244
// Prints a Max node by re-expressing it as an extern call to a function
// named "max" with op->a and op->b as arguments, then printing that call.
void CodeGen_Metal_Dev::CodeGen_Metal_C::visit(const Max *op) {
251245
print_expr(Call::make(op->type, "max", {op->a, op->b}, Call::Extern));
252246
}

src/CodeGen_OpenCL_Dev.cpp

-6
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,6 @@ class CodeGen_OpenCL_Dev : public CodeGen_GPU_Dev {
105105
using CodeGen_GPU_C::visit;
106106
std::string print_type(Type type, AppendSpaceIfNeeded append_space = DoNotAppendSpace) override;
107107
std::string print_reinterpret(Type type, const Expr &e) override;
108-
std::string print_extern_call(const Call *op) override;
109108
std::string print_array_access(const std::string &name,
110109
const Type &type,
111110
const std::string &id_index);
@@ -488,11 +487,6 @@ void CodeGen_OpenCL_Dev::CodeGen_OpenCL_C::visit(const Call *op) {
488487
}
489488
}
490489

491-
// (Removed by this commit.) OpenCL override that only asserted the callee
// does not take a user context and then delegated to
// CodeGen_GPU_C::print_extern_call.
string CodeGen_OpenCL_Dev::CodeGen_OpenCL_C::print_extern_call(const Call *op) {
492-
internal_assert(!function_takes_user_context(op->name)) << op->name;
493-
return CodeGen_GPU_C::print_extern_call(op);
494-
}
495-
496490
string CodeGen_OpenCL_Dev::CodeGen_OpenCL_C::print_array_access(const string &name,
497491
const Type &type,
498492
const string &id_index) {

src/CodeGen_WebGPU_Dev.cpp

-6
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,6 @@ class CodeGen_WebGPU_Dev : public CodeGen_GPU_Dev {
102102
AppendSpaceIfNeeded append_space =
103103
DoNotAppendSpace) override;
104104
std::string print_reinterpret(Type type, const Expr &e) override;
105-
std::string print_extern_call(const Call *op) override;
106105
std::string print_assignment(Type t, const std::string &rhs) override;
107106
std::string print_const(Type t, const std::string &rhs);
108107
std::string print_assignment_or_const(Type t, const std::string &rhs,
@@ -299,11 +298,6 @@ string CodeGen_WebGPU_Dev::CodeGen_WGSL::print_reinterpret(Type type,
299298
return oss.str();
300299
}
301300

302-
// (Removed by this commit.) WGSL override that only asserted the callee
// does not take a user context and then delegated to
// CodeGen_GPU_C::print_extern_call.
string CodeGen_WebGPU_Dev::CodeGen_WGSL::print_extern_call(const Call *op) {
303-
internal_assert(!function_takes_user_context(op->name)) << op->name;
304-
return CodeGen_GPU_C::print_extern_call(op);
305-
}
306-
307301
void CodeGen_WebGPU_Dev::CodeGen_WGSL::add_kernel(
308302
const Stmt &s, const string &name, const vector<DeviceArgument> &args) {
309303
debug(2) << "Adding WGSL shader " << name << "\n";

test/correctness/math.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ struct TestArgs {
137137
Var x("x"), xi("xi"); \
138138
test_##name(x) = name(in(x)); \
139139
if (target.has_gpu_feature()) { \
140-
test_##name.gpu_tile(x, xi, 8); \
140+
test_##name.gpu_tile(x, xi, 16).vectorize(xi, 2); \
141141
} else if (target.has_feature(Target::HVX)) { \
142142
test_##name.hexagon(); \
143143
} \
@@ -168,7 +168,7 @@ struct TestArgs {
168168
Var x("x"), xi("xi"); \
169169
test_##name(x) = name(in(0, x), in(1, x)); \
170170
if (target.has_gpu_feature()) { \
171-
test_##name.gpu_tile(x, xi, 8); \
171+
test_##name.gpu_tile(x, xi, 16).vectorize(xi, 2); \
172172
} else if (target.has_feature(Target::HVX)) { \
173173
test_##name.hexagon(); \
174174
} \

0 commit comments

Comments
 (0)