@@ -67,6 +67,40 @@ class CodeGen_D3D12Compute_Dev : public CodeGen_GPU_Dev {
67
67
CodeGen_D3D12Compute_C (std::ostream &s, const Target &t)
68
68
: CodeGen_GPU_C(s, t) {
69
69
integer_suffix_style = IntegerSuffixStyle::HLSL;
70
+
71
+ #define alias (x, y ) \
72
+ extern_function_name_map[x " _f16" ] = y; \
73
+ extern_function_name_map[x " _f32" ] = y; \
74
+ extern_function_name_map[x " _f64" ] = y
75
+ alias (" sqrt" , " sqrt" );
76
+ alias (" sin" , " sin" );
77
+ alias (" cos" , " cos" );
78
+ alias (" exp" , " exp" );
79
+ alias (" log" , " log" );
80
+ alias (" abs" , " abs" );
81
+ alias (" floor" , " floor" );
82
+ alias (" ceil" , " ceil" );
83
+ alias (" trunc" , " trunc" );
84
+ alias (" pow" , " pow" );
85
+ alias (" asin" , " asin" );
86
+ alias (" acos" , " acos" );
87
+ alias (" tan" , " tan" );
88
+ alias (" atan" , " atan" );
89
+ alias (" atan2" , " atan2" );
90
+ alias (" sinh" , " sinh" );
91
+ alias (" asinh" , " asinh" );
92
+ alias (" cosh" , " cosh" );
93
+ alias (" acosh" , " acosh" );
94
+ alias (" tanh" , " tanh" );
95
+ alias (" atanh" , " atanh" );
96
+
97
+ alias (" is_nan" , " isnan" );
98
+ alias (" is_inf" , " isinf" );
99
+ alias (" is_finite" , " isfinite" );
100
+
101
+ alias (" fast_inverse" , " rcp" );
102
+ alias (" fast_inverse_sqrt" , " rsqrt" );
103
+ #undef alias
70
104
}
71
105
void add_kernel (Stmt stmt,
72
106
const std::string &name,
@@ -79,7 +113,6 @@ class CodeGen_D3D12Compute_Dev : public CodeGen_GPU_Dev {
79
113
std::string print_storage_type (Type type);
80
114
std::string print_type_maybe_storage (Type type, bool storage, AppendSpaceIfNeeded space);
81
115
std::string print_reinterpret (Type type, const Expr &e) override ;
82
- std::string print_extern_call (const Call *op) override ;
83
116
84
117
std::string print_vanilla_cast (Type type, const std::string &value_expr);
85
118
std::string print_reinforced_cast (Type type, const std::string &value_expr);
@@ -247,18 +280,6 @@ void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::visit(const Evaluate *op)
247
280
print_expr (op->value );
248
281
}
249
282
250
- string CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::print_extern_call (const Call *op) {
251
- internal_assert (!function_takes_user_context (op->name )) << op->name ;
252
-
253
- vector<string> args (op->args .size ());
254
- for (size_t i = 0 ; i < op->args .size (); i++) {
255
- args[i] = print_expr (op->args [i]);
256
- }
257
- ostringstream rhs;
258
- rhs << op->name << " (" << with_commas (args) << " )" ;
259
- return rhs.str ();
260
- }
261
-
262
283
void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::visit (const Max *op) {
263
284
print_expr (Call::make (op->type , " max" , {op->a , op->b }, Call::Extern));
264
285
}
@@ -1290,19 +1311,7 @@ void CodeGen_D3D12Compute_Dev::init_module() {
1290
1311
<< " float nan_f32() { return 1.#IND; } \n " // Quiet NaN with minimum fractional value.
1291
1312
<< " float neg_inf_f32() { return -1.#INF; } \n "
1292
1313
<< " float inf_f32() { return +1.#INF; } \n "
1293
- << " #define is_inf_f32 isinf \n "
1294
- << " #define is_finite_f32 isfinite \n "
1295
- << " #define is_nan_f32 isnan \n "
1296
1314
<< " #define float_from_bits asfloat \n "
1297
- << " #define sqrt_f32 sqrt \n "
1298
- << " #define sin_f32 sin \n "
1299
- << " #define cos_f32 cos \n "
1300
- << " #define exp_f32 exp \n "
1301
- << " #define log_f32 log \n "
1302
- << " #define abs_f32 abs \n "
1303
- << " #define floor_f32 floor \n "
1304
- << " #define ceil_f32 ceil \n "
1305
- << " #define trunc_f32 trunc \n "
1306
1315
// pow() in HLSL has the same semantics as C if
1307
1316
// x > 0. Otherwise, we need to emulate C
1308
1317
// behavior.
@@ -1322,19 +1331,9 @@ void CodeGen_D3D12Compute_Dev::init_module() {
1322
1331
<< " return nan_f32(); \n "
1323
1332
<< " } \n "
1324
1333
<< " } \n "
1325
- << " #define asin_f32 asin \n "
1326
- << " #define acos_f32 acos \n "
1327
- << " #define tan_f32 tan \n "
1328
- << " #define atan_f32 atan \n "
1329
- << " #define atan2_f32 atan2 \n "
1330
- << " #define sinh_f32 sinh \n "
1331
- << " #define cosh_f32 cosh \n "
1332
- << " #define tanh_f32 tanh \n "
1333
- << " #define asinh_f32(x) (log_f32(x + sqrt_f32(x*x + 1))) \n "
1334
- << " #define acosh_f32(x) (log_f32(x + sqrt_f32(x*x - 1))) \n "
1335
- << " #define atanh_f32(x) (log_f32((1+x)/(1-x))/2) \n "
1336
- << " #define fast_inverse_f32 rcp \n "
1337
- << " #define fast_inverse_sqrt_f32 rsqrt \n "
1334
+ << " #define asinh(x) (log(x + sqrt(x*x + 1))) \n "
1335
+ << " #define acosh(x) (log(x + sqrt(x*x - 1))) \n "
1336
+ << " #define atanh(x) (log((1+x)/(1-x))/2) \n "
1338
1337
<< " \n " ;
1339
1338
// << "}\n"; // close namespace
1340
1339
0 commit comments