Skip to content

Commit 8a95e72

Browse files
authored
Merge pull request #67 from oschub/sincos
Add CUDA rewrites for sincos(x) and exp(y) for complex y
2 parents d8787e0 + 3e5a9f1 commit 8a95e72

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

src/backends/cuda.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ function (obj::Kernel{CUDA})(args...; ndrange=nothing, dependencies=nothing, wor
150150
workgroupsize = (256,)
151151
end
152152
# If the kernel is statically sized we can tell the compiler about that
153-
if KernelAbstractions.workgroupsize(obj) <: StaticSize
153+
if KernelAbstractions.workgroupsize(obj) <: StaticSize
154154
maxthreads = prod(get(KernelAbstractions.workgroupsize(obj)))
155155
else
156156
maxthreads = nothing
@@ -244,6 +244,9 @@ for f in cudafuns
244244
end
245245
end
246246

247+
# CUDA rewrite of `sincos` for `Float32`/`Float64` arguments: under `CUDACtx`, the
# `(sin, cos)` tuple is built by calling `CUDAnative.sin` and `CUDAnative.cos` on `x`.
@inline function Cassette.overdub(::CUDACtx, ::typeof(sincos), x::Union{Float32, Float64})
    s = CUDAnative.sin(x)
    c = CUDAnative.cos(x)
    return (s, c)
end
248+
# CUDA rewrite of `exp` for complex arguments: under `CUDACtx`, a call with a
# `ComplexF32` or `ComplexF64` value is forwarded to `CUDAnative.exp`.
@inline function Cassette.overdub(::CUDACtx, ::typeof(exp), x::Union{ComplexF32, ComplexF64})
    return CUDAnative.exp(x)
end
249+
247250

248251
###
249252
# GPU implementation of shared memory

0 commit comments

Comments
 (0)