|
| 1 | +# Contiguous on-device arrays |
| 2 | + |
| 3 | +export CLDeviceArray, CLDeviceVector, CLDeviceMatrix, CLLocalArray |
| 4 | + |
| 5 | + |
| 6 | +## construction |
| 7 | + |
| 8 | +# NOTE: we can't support the typical `tuple or series of integer` style construction, |
| 9 | +# because we're currently requiring a trailing pointer argument. |
| 10 | + |
# Dense, contiguous array view over device memory.
#
# Fields:
# - `ptr`:     device pointer to the first element (address space `A`)
# - `maxsize`: size value stored next to the pointer; defaults to `prod(dims)*sizeof(T)`
# - `dims`:    array dimensions
# - `len`:     cached `prod(dims)`, so `length` does not have to recompute it
struct CLDeviceArray{T,N,A} <: DenseArray{T,N}
    ptr::LLVMPtr{T,A}
    maxsize::Int

    dims::Dims{N}
    len::Int

    # inner constructor: fully parameterized, exact argument types (ie. Int not <:Integer)
    # TODO: deprecate; put `ptr` first like oneArray
    function CLDeviceArray{T,N,A}(dims::Dims{N}, ptr::LLVMPtr{T,A},
                                  maxsize::Int=prod(dims)*sizeof(T)) where {T,A,N}
        return new(ptr, maxsize, dims, prod(dims))
    end
end
| 24 | + |
# one- and two-dimensional aliases
const CLDeviceVector{T,A} = CLDeviceArray{T,1,A}
const CLDeviceMatrix{T,A} = CLDeviceArray{T,2,A}

# outer constructors, non-parameterized: everything is taken from the arguments
function CLDeviceArray(dims::NTuple{N,<:Integer}, ptr::LLVMPtr{T,A}) where {T,A,N}
    return CLDeviceArray{T,N,A}(dims, ptr)
end
function CLDeviceArray(len::Integer, ptr::LLVMPtr{T,A}) where {T,A}
    return CLDeviceVector{T,A}((len,), ptr)
end

# outer constructors, partially parameterized
function CLDeviceArray{T}(dims::NTuple{N,<:Integer}, ptr::LLVMPtr{T,A}) where {T,A,N}
    return CLDeviceArray{T,N,A}(dims, ptr)
end
function CLDeviceArray{T}(len::Integer, ptr::LLVMPtr{T,A}) where {T,A}
    return CLDeviceVector{T,A}((len,), ptr)
end
function CLDeviceArray{T,N}(dims::NTuple{N,<:Integer}, ptr::LLVMPtr{T,A}) where {T,A,N}
    return CLDeviceArray{T,N,A}(dims, ptr)
end
function CLDeviceVector{T}(len::Integer, ptr::LLVMPtr{T,A}) where {T,A}
    return CLDeviceVector{T,A}((len,), ptr)
end

# outer constructors, fully parameterized: normalize the sizes to `Int` so that the
# inner constructor (which only accepts `Dims{N}`) is the one that gets called next
function CLDeviceArray{T,N,A}(dims::NTuple{N,<:Integer}, ptr::LLVMPtr{T,A}) where {T,A,N}
    return CLDeviceArray{T,N,A}(map(Int, dims), ptr)
end
function CLDeviceVector{T,A}(len::Integer, ptr::LLVMPtr{T,A}) where {T,A}
    return CLDeviceVector{T,A}((Int(len),), ptr)
end
| 41 | + |
| 42 | + |
| 43 | +## array interface |
| 44 | + |
# size in bytes of a single element
function Base.elsize(::Type{<:CLDeviceArray{T}}) where {T}
    return sizeof(T)
end

# dimensions, as stored in the array object
function Base.size(arr::CLDeviceArray)
    return arr.dims
end

# total payload size: element size times number of elements
function Base.sizeof(arr::CLDeviceArray)
    return Base.elsize(arr) * length(arr)
end

# the element count is cached in the object; computing prod(size) is expensive
function Base.length(arr::CLDeviceArray)
    return arr.len
end
| 52 | + |
# pointer to the start of the array
function Base.pointer(x::CLDeviceArray{T,<:Any,A}) where {T,A}
    return Base.unsafe_convert(LLVMPtr{T,A}, x)
end

# pointer to element `i`: the start pointer advanced by `Base._memory_offset(x, i)`
@inline function Base.pointer(x::CLDeviceArray{T,<:Any,A}, i::Integer) where {T,A}
    start = Base.unsafe_convert(LLVMPtr{T,A}, x)
    return start + Base._memory_offset(x, i)
end
| 57 | + |
# Pointer to the `UInt8` selector for element `i` (default: the first one). The selector
# region starts at `a.ptr + a.maxsize`, reinterpreted as a `UInt8` pointer.
function typetagdata(a::CLDeviceArray{<:Any,<:Any,A}, i=1) where {A}
    selectors = reinterpret(LLVMPtr{UInt8,A}, a.ptr + a.maxsize)
    return selectors + i - one(i)
end
| 60 | + |
| 61 | + |
| 62 | +## conversions |
| 63 | + |
# Converting a device array to a pointer of matching element type and address space
# simply hands out the stored pointer.
function Base.unsafe_convert(::Type{LLVMPtr{T,A}}, x::CLDeviceArray{T,<:Any,A}) where {T,A}
    return x.ptr
end
| 66 | + |
| 67 | + |
| 68 | +## indexing intrinsics |
| 69 | + |
| 70 | +# TODO: how are allocations aligned by the OpenCL runtime? keep track of this |
| 71 | +# because it enables optimizations like Load Store Vectorization |
| 72 | +# (cfr. shared memory and its wider-than-datatype alignment) |
| 73 | + |
# Alignment used for loads and stores of elements of type `T`. Evaluated in the generator,
# so the result is a constant for each element type.
@generated function alignment(::CLDeviceArray{T}) where {T}
    # plain element types: use the datatype's own alignment
    Base.isbitsunion(T) || return Base.datatype_alignment(T)

    # bits unions: the alignment is the third entry of the layout tuple
    return Base.uniontype_layout(T)[3]
end
| 82 | + |
# Bounds-checked element load. `isbitstype` element types take the direct path; anything
# else is handled by the union path (the intended case being bits unions).
@device_function @inline function arrayref(A::CLDeviceArray{T}, index::Integer) where {T}
    @boundscheck checkbounds(A, index)
    return isbitstype(T) ? arrayref_bits(A, index) : arrayref_union(A, index)
end
| 91 | + |
# Load element `index` straight from the array's pointer, passing the element alignment
# along as a `Val`.
@inline function arrayref_bits(A::CLDeviceArray{T}, index::Integer) where {T}
    return unsafe_load(pointer(A), index, Val(alignment(A)))
end
| 96 | + |
# Load element `index` of an array whose element type is a union. The generator emits one
# branch per member of the union; at run time the element's selector picks the branch.
@inline @generated function arrayref_union(A::CLDeviceArray{T,<:Any,AS}, index::Integer) where {T,AS}
    variants = Base.uniontypes(T)

    # Build the if/else chain from the innermost `else` outwards, hence the reversed
    # iteration. The final fallback is `unreachable`; lacking noreturn, it is declared to
    # return T so that inference does not think this can return Nothing.
    chain = :(Base.llvmcall("unreachable", $T, Tuple{}))
    for (position, variant) in Iterators.reverse(enumerate(variants))
        chain = quote
            # selectors are zero-based, positions in `variants` are one-based
            if selector == $(position - 1)
                unsafe_load(reinterpret(LLVMPtr{$variant,AS}, data_ptr), 1, Val(align))
            else
                $chain
            end
        end
    end

    return quote
        selector = unsafe_load(typetagdata(A, index))

        align = alignment(A)
        data_ptr = pointer(A, index)

        return $chain
    end
end
| 124 | + |
# Bounds-checked element store; returns the array. `isbitstype` element types take the
# direct path; anything else is handled by the union path.
@device_function @inline function arrayset(A::CLDeviceArray{T}, x::T, index::Integer) where {T}
    @boundscheck checkbounds(A, index)
    isbitstype(T) ? arrayset_bits(A, x, index) : arrayset_union(A, x, index)
    return A
end
| 134 | + |
# Store `x` at position `index` through the array's pointer, passing the element alignment
# along as a `Val`.
@inline function arrayset_bits(A::CLDeviceArray{T}, x::T, index::Integer) where {T}
    return unsafe_store!(pointer(A), x, index, Val(alignment(A)))
end
| 139 | + |
# Store `x` into an array whose element type is a union: first write the selector that
# identifies `x`'s type, then write the value itself through a pointer of that type.
@inline @generated function arrayset_union(A::CLDeviceArray{T,<:Any,AS}, x::T, index::Integer) where {T,AS}
    # inside the generator `x` is the type of the stored value, not the value; its
    # zero-based position among the union members is the selector
    variants = Base.uniontypes(T)
    tag = UInt8(findfirst(isequal(x), variants) - 1)

    return quote
        unsafe_store!(typetagdata(A, index), $tag)

        align = alignment(A)
        data_ptr = pointer(A, index)

        unsafe_store!(reinterpret(LLVMPtr{$x,AS}, data_ptr), x, 1, Val(align))
        return
    end
end
| 155 | + |
# Bounds-checked element load that goes through `unsafe_cached_load` instead of
# `unsafe_load`; used for arrays wrapped in `Const`.
@device_function @inline function const_arrayref(A::CLDeviceArray{T}, index::Integer) where {T}
    @boundscheck checkbounds(A, index)
    return unsafe_cached_load(pointer(A), index, Val(alignment(A)))
end
| 161 | + |
| 162 | + |
| 163 | +## indexing |
| 164 | + |
# device arrays are indexed linearly
Base.IndexStyle(::Type{<:CLDeviceArray}) = Base.IndexLinear()

# scalar load: forwards to the `arrayref` intrinsic
Base.@propagate_inbounds function Base.getindex(A::CLDeviceArray{T}, i1::Integer) where {T}
    return arrayref(A, i1)
end

# scalar store: the value is converted to the element type before being written
Base.@propagate_inbounds function Base.setindex!(A::CLDeviceArray{T}, x, i1::Integer) where {T}
    return arrayset(A, convert(T, x)::T, i1)
end
| 171 | + |
# Integer indices are passed through unchanged: preserving the specific integer type when
# indexing device arrays avoids extending 32-bit hardware indices to 64-bit.
function Base.to_index(::CLDeviceArray, i::Integer)
    return i
end
| 175 | + |
# Base doesn't like Integer indices, so we need our own ND get and setindex! routines.
# See also: https://github.com/JuliaLang/julia/pull/42289
#
# Both methods fold the indices into a single linear index and then defer to the scalar
# linear-index methods.
Base.@propagate_inbounds function Base.getindex(A::CLDeviceArray,
                                                I::Union{Integer, CartesianIndex}...)
    linear = Base._to_linear_index(A, to_indices(A, I)...)
    return A[linear]
end
Base.@propagate_inbounds function Base.setindex!(A::CLDeviceArray, x,
                                                 I::Union{Integer, CartesianIndex}...)
    linear = Base._to_linear_index(A, to_indices(A, I)...)
    A[linear] = x
    return x
end
| 184 | + |
| 185 | + |
| 186 | +## const indexing |
| 187 | + |
# NOTE(review): a device requirement inherited from the CUDA version of this docstring
# ("compute capability 3.5 or higher") was dropped because it is NVIDIA-specific
# terminology; confirm whether OpenCL imposes any requirement of its own.
"""
    Const(A::CLDeviceArray)

Mark a `CLDeviceArray` as constant/read-only. The invariant guaranteed is that you will not
modify the `CLDeviceArray` for the duration of the current kernel.

Indexing the wrapper with a single integer goes through `const_arrayref` on the wrapped
array.

!!! warning
    Experimental API. Subject to change without deprecation.
"""
struct Const{T,N,AS} <: DenseArray{T,N}
    a::CLDeviceArray{T,N,AS}
end

# make `Base.Experimental.Const(A)` produce this package's wrapper for device arrays
Base.Experimental.Const(A::CLDeviceArray) = Const(A)
| 203 | + |
# `Const` mirrors the shape of the wrapped array and is indexed linearly
Base.IndexStyle(::Type{<:Const}) = IndexLinear()

function Base.size(C::Const)
    return size(C.a)
end

function Base.axes(C::Const)
    return axes(C.a)
end

# reads on the wrapper go through `const_arrayref` on the wrapped array
Base.@propagate_inbounds function Base.getindex(A::Const, i1::Integer)
    return const_arrayref(A.a, i1)
end
| 208 | + |
| 209 | + |
| 210 | +## other |
| 211 | + |
# Compact textual representation: only the shape and the device pointer are printed,
# never the contents.

# vectors: "<length>-element device array at <pointer>"
Base.show(io::IO, a::CLDeviceVector) =
    print(io, "$(length(a))-element device array at $(pointer(a))")

# other ranks: "<d1>×<d2>×… device array at <pointer>"
# (the dimensions come from `size(a)`; the struct has no `shape` field)
Base.show(io::IO, a::CLDeviceArray) =
    print(io, "$(join(size(a), '×')) device array at $(pointer(a))")

# the multi-line (REPL) display uses the same compact form
Base.show(io::IO, mime::MIME"text/plain", a::CLDeviceArray) = show(io, a)
| 218 | + |
# Iteration protocol: `i` is the position of the next element to yield. Iteration ends
# (returns `nothing`) once `(i % UInt) - 1` is no longer below the array length.
@inline function Base.iterate(A::CLDeviceArray, i=1)
    (i % UInt) - 1 < length(A) || return nothing
    return (@inbounds A[i], i + 1)
end
| 226 | + |
# Reinterpret the elements of `a` as type `T` without touching the data: the result uses
# the same pointer (retyped) and the same `maxsize`. Throws whatever exception
# `_reinterpret_exception` reports for an invalid combination.
function Base.reinterpret(::Type{T}, a::CLDeviceArray{S,N,A}) where {T,S,N,A}
    problem = _reinterpret_exception(T, a)
    problem === nothing || throw(problem)

    newdims = if sizeof(T) == sizeof(S)
        # fast case: same element size, same shape
        size(a)
    else
        # different element size: rescale the first dimension, keep the others
        olddims = size(a)
        first_dim = div(olddims[1] * sizeof(S), sizeof(T))
        tuple(first_dim, Base.tail(olddims)...)
    end

    return CLDeviceArray{T,N,A}(newdims, reinterpret(LLVMPtr{T,A}, a.ptr), a.maxsize)
end
| 240 | + |
| 241 | + |
| 242 | +## local memory |
| 243 | + |
| 244 | +# XXX: use OpenCL-style local memory arguments instead? |
| 245 | + |
# Create a device array of element type `T` and shape `dims` on top of the pointer
# returned by `emit_localmemory`.
@inline function CLLocalArray(::Type{T}, dims) where {T}
    # NOTE: this relies on const-prop to forward the literal length to the generator.
    #       maybe we should include the size in the type, like StaticArrays does?
    localptr = emit_localmemory(T, Val(prod(dims)))
    return CLDeviceArray(dims, localptr)
end
0 commit comments