Skip to content

Commit 6cebc56

Browse files
committed
Move fast function calls to extern table for Metal.
1 parent 32f98e2 commit 6cebc56

File tree

4 files changed

+20
-12
lines changed

4 files changed

+20
-12
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,9 @@ xcuserdata
240240
# NeoVim + clangd
241241
.cache
242242

243+
# CCLS
244+
.ccls-cache
245+
243246
# Emacs
244247
tags
245248
TAGS

src/CodeGen_Metal_Dev.cpp

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,13 @@ class CodeGen_Metal_Dev : public CodeGen_GPU_Dev {
8989
alias("is_inf", "isinf");
9090
alias("is_finite", "isfinite");
9191

92+
alias("fast_sin", "fast::sin");
93+
alias("fast_cos", "fast::cos");
94+
alias("fast_tan", "fast::tan");
95+
alias("fast_exp", "fast::exp");
96+
alias("fast_log", "fast::log");
97+
alias("fast_pow", "fast::pow");
98+
alias("fast_tanh", "fast::tanh");
9299
alias("fast_inverse_sqrt", "fast::rsqrt");
93100
#undef alias
94101
}
@@ -837,14 +844,6 @@ void CodeGen_Metal_Dev::init_module() {
837844
<< "constexpr float neg_inf_f32() { return float_from_bits(0xff800000); }\n"
838845
<< "constexpr float inf_f32() { return float_from_bits(0x7f800000); }\n"
839846
<< "float fast_inverse_f32(float x) { return 1.0f / x; }\n"
840-
<< "#define fast_sin_f32 fast::sin \n"
841-
<< "#define fast_cos_f32 fast::cos \n"
842-
<< "#define fast_tan_f32 fast::tan \n"
843-
<< "#define fast_exp_f32 fast::exp \n"
844-
<< "#define fast_log_f32 fast::log \n"
845-
<< "#define fast_pow_f32 fast::pow \n"
846-
<< "#define fast_tanh_f32 fast::tanh \n"
847-
<< "#define fast_inverse_sqrt_f16 rsqrt\n"
848847
<< "constexpr half half_from_bits(unsigned short x) {return as_type<half>(x);}\n"
849848
<< "constexpr half nan_f16() { return half_from_bits(32767); }\n"
850849
<< "constexpr half neg_inf_f16() { return half_from_bits(64512); }\n"

test/correctness/fast_function_approximations.cpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -479,7 +479,10 @@ int main(int argc, char **argv) {
479479
ref_func_gpu(i) = ftt.make_reference(arg_x, arg_y);
480480
ref_func_gpu.never_partition_all();
481481
// also vectorize to make sure that works on GPU as well...
482-
ref_func_gpu.gpu_tile(i, io, ii, 256, TailStrategy::ShiftInwards).vectorize(ii, 2);
482+
ref_func_gpu
483+
.gpu_tile(i, io, ii, 512, TailStrategy::ShiftInwards)
484+
.vectorize(ii, 4);
485+
// TODO(mcourteaux): When vector legalization lowering pass is in, increase vectorize for testing purposes!
483486
ref_func_gpu.realize(out_approx);
484487
out_approx.copy_to_host();
485488

@@ -519,8 +522,11 @@ int main(int argc, char **argv) {
519522
approx_func.align_bounds(i, 8);
520523
if (target.has_gpu_feature()) {
521524
Var io, ii;
522-
approx_func.never_partition_all();
523-
approx_func.gpu_tile(i, io, ii, 256, TailStrategy::ShiftInwards);
525+
approx_func
526+
.never_partition_all()
527+
.gpu_tile(i, io, ii, 256, TailStrategy::ShiftInwards)
528+
.vectorize(ii, 4);
529+
// TODO(mcourteaux): When vector legalization lowering pass is in, increase vectorize for testing.
524530
} else {
525531
approx_func.vectorize(i, 8);
526532
}

test/performance/fast_function_approximations.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ int main(int argc, char **argv) {
179179
std::function<void(Func &)> schedule = [&](Func &f) {
180180
if (target.has_gpu_feature()) {
181181
f.never_partition_all();
182-
f.gpu_tile(x, y, xo, yo, xi, yi, 16, 16, TailStrategy::ShiftInwards);
182+
f.gpu_tile(x, y, xo, yo, xi, yi, 64, 16, TailStrategy::ShiftInwards).vectorize(xi, 4);
183183
} else {
184184
f.vectorize(x, 8);
185185
}

0 commit comments

Comments
 (0)