1- import Base . Broadcast : Broadcasted, Extruded, BroadcastStyle, ArrayStyle
1+ # broadcasting
22
3- BroadcastStyle ( :: Type{<:CuArray} ) = ArrayStyle {CuArray} ()
3+ using Base . Broadcast : BroadcastStyle, Broadcasted
44
5- function Base. similar (bc:: Broadcasted{ArrayStyle{CuArray}} , :: Type{T} ) where T
# Broadcast style tag for `CuArray`s, parameterised by the dimensionality `N`.
# NOTE(review): `AbstractGPUArrayStyle` is declared outside this file; presumably it
# subtypes `Base.Broadcast.AbstractArrayStyle` -- confirm.
struct CuArrayStyle{N} <: AbstractGPUArrayStyle{N} end

# Build the style for dimensionality `N` from a `Val{N}`.
CuArrayStyle(::Val{N}) where {N} = CuArrayStyle{N}()

# Re-dimension an existing style: the type parameter `M` is ignored and the result
# always carries the `N` taken from the `Val` argument.
CuArrayStyle{M}(::Val{N}) where {N,M} = CuArrayStyle{N}()
8+
# Select `CuArrayStyle{N}` as the broadcast style for any `CuArray{T,N}` (and subtypes),
# so that the style carries the array's dimensionality `N`.
BroadcastStyle(::Type{<:CuArray{T,N}}) where {T,N} = CuArrayStyle{N}()
10+
# Output container for a `CuArrayStyle` broadcast when only the element type is given:
# delegates to `similar(CuArray{T}, axes(bc))`, i.e. the shape comes from the
# broadcast's own axes.
Base.similar(bc::Broadcasted{CuArrayStyle{N}}, ::Type{T}) where {N,T} =
    similar(CuArray{T}, axes(bc))
7- end
813
9- function Base. similar (bc:: Broadcasted{ArrayStyle{CuArray }} , :: Type{T} , dims... ) where {T}
10- similar ( CuArray{T}, dims... )
11- end
# Output container for a `CuArrayStyle` broadcast with explicit `dims`: constructs an
# uninitialised (`undef`) `CuArray{T}` of the requested dimensions. `bc` is used for
# dispatch only.
Base.similar(bc::Broadcasted{CuArrayStyle{N}}, ::Type{T}, dims...) where {N,T} =
    CuArray{T}(undef, dims...)
16+
1217
13- # replace base functions with libdevice alternatives
14- # TODO : do this with Cassette.jl
18+ # # replace base functions with libdevice alternatives
1519
# Map a broadcasted callable to the callable that should actually be applied.
# Default: leave the function unchanged. More specific methods may be added elsewhere
# to substitute alternatives.
cufunc(f) = f

# Type constructors are wrapped in an anonymous function that forwards all arguments to
# `T`, because broadcasting type constructors directly isn't GPU compatible.
cufunc(::Type{T}) where {T} = (x...) -> T(x...)
1822
19- Broadcast. broadcasted (:: ArrayStyle{CuArray } , f, args... ) =
20- Broadcasted {ArrayStyle{CuArray }} (cufunc (f), args, nothing )
# Build the lazy `Broadcasted` object for a `CuArrayStyle{N}` broadcast, passing `f`
# through `cufunc` first so any registered replacement is applied. The axes slot is
# left as `nothing`.
Broadcast.broadcasted(::CuArrayStyle{N}, f, args...) where {N} =
    Broadcasted{CuArrayStyle{N}}(cufunc(f), args, nothing)
2125
22- libdevice = :[
26+ const libdevice = :[
2327 cos, cospi, sin, sinpi, tan, acos, asin, atan,
2428 cosh, sinh, tanh, acosh, asinh, atanh,
2529 log, log10, log1p, log2, logb, ilogb,
@@ -40,7 +44,8 @@ for f in libdevice
4044 @eval cufunc (:: typeof (Base.$ f)) = CUDAnative.$ f
4145end
4246
43- # broadcast ^
47+ # broadcast ^
48+
# Literal-exponent powers of real numbers, dispatched on the exponent lifted to the type
# domain via `Val`:
#   x^0 -> one(x)   (multiplicative identity of x's type)
#   x^1 -> x
#   x^2 -> x * x
culiteral_pow(::typeof(^), x::T, ::Val{0}) where {T<:Real} = one(x)
culiteral_pow(::typeof(^), x::T, ::Val{1}) where {T<:Real} = x
culiteral_pow(::typeof(^), x::T, ::Val{2}) where {T<:Real} = x * x
0 commit comments