Skip to content

Commit 4a3acad

Browse files
Update src/KernelAbstractions.jl
Co-authored-by: Simon Byrne <[email protected]>
1 parent ca79220 commit 4a3acad

File tree

3 files changed

+24
-22
lines changed

3 files changed

+24
-22
lines changed

src/KernelAbstractions.jl

Lines changed: 13 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -238,30 +238,25 @@ macro print(items...)
238238
end
239239
end
240240

241-
@generated function promote_c_argument(arg)
242-
# > When a function with a variable-length argument list is called, the variable
243-
# > arguments are passed using C's old ``default argument promotions.'' These say that
244-
# > types char and short int are automatically promoted to int, and type float is
245-
# > automatically promoted to double. Therefore, varargs functions will never receive
246-
# > arguments of type char, short int, or float.
247-
248-
if arg == Cchar || arg == Cshort
249-
return :(Cint(arg))
250-
elseif arg == Cfloat
251-
return :(Cdouble(arg))
252-
else
253-
return :(arg)
254-
end
255-
end
241+
# When a function with a variable-length argument list is called, the variable
242+
# arguments are passed using C's old ``default argument promotions.'' These say that
243+
# types char and short int are automatically promoted to int, and type float is
244+
# automatically promoted to double. Therefore, varargs functions will never receive
245+
# arguments of type char, short int, or float.
246+
247+
promote_c_argument(arg) = arg
248+
promote_c_argument(arg::Cfloat) = Cdouble(arg)
249+
promote_c_argument(arg::Cchar) = Cint(arg)
250+
promote_c_argument(arg::Cshort) = Cint(arg)
256251

257252
"""
258253
@printf(fmt::String, args...)
259254
260-
This is a unified formatted print statement.
255+
This is a unified formatted printf statement.
261256
262257
# Platform differences
263258
- `GPU`: This will reorganize the items to print via @cuprintf
264-
- `CPU`: This will call `print(items...)`
259+
- `CPU`: This will call `sprintf(fmt, items...)`
265260
"""
266261
macro printf(fmt::String, args...)
267262
fmt_val = Val(Symbol(fmt))
@@ -551,9 +546,7 @@ end
551546
end
552547
sfmt = String(fmt)
553548
quote
554-
# @sprintf($sfmt, $(args...))
555-
@print(@sprintf($sfmt, $(args...)))
556-
# @print("test")
549+
Printf.@printf($sfmt, $(args...))
557550
end
558551
end
559552

src/backends/cuda.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,11 @@ end
320320
end
321321

322322
@inline function Cassette.overdub(ctx::CUDACtx, ::typeof(__printf), fmt, args...)
323-
CUDA._cuprintf(fmt, args...)
323+
CUDA._cuprintf(Val(fmt), args...)
324+
end
325+
326+
@inline function Cassette.overdub(ctx::CUDACtx, ::typeof(__printf), ::Val{fmt}, args...) where fmt
327+
CUDA._cuprintf(Val(fmt), args...)
324328
end
325329

326330
###

test/print_test.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ if has_cuda_gpu()
66
end
77

88
struct Foo{A,B} end
9+
get_name(::Type{T}) where T<:Foo = "Foo"
910

1011
@kernel function kernel_print()
1112
I = @index(Global)
@@ -14,7 +15,11 @@ end
1415

1516
@kernel function kernel_printf()
1617
I = @index(Global)
17-
@printf("Hello printf %s thread %d! type = %s.\n", "from", I, nameof(Foo))
18+
# @printf("Hello printf %s thread %d! type = %s.\n", "from", I, nameof(Foo))
19+
# @print("Hello printf from thread ", I, "!\n")
20+
# @printf("Hello printf %s thread %d! type = %s.\n", "from", I, string(nameof(Foo)))
21+
@printf("Hello printf %s thread %d! type = %s.\n", "from", I, "Foo")
22+
@printf("Hello printf %s thread %d! type = %s.\n", "from", I, get_name(Foo))
1823
end
1924

2025
function test_print(backend)

0 commit comments

Comments
 (0)