Skip to content

Commit 8880ad4

Browse files
committed
tests pass up to random now.
1 parent 23eaa50 commit 8880ad4

File tree

6 files changed

+22
-13
lines changed

6 files changed

+22
-13
lines changed

lib/JLArrays/src/JLArrays.jl

+6
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,12 @@ import KernelAbstractions: Adapt, StaticArrays, Backend, Kernel, StaticSize, Dyn
1313

1414
export JLArray, JLVector, JLMatrix, jl, JLBackend
1515

16+
#
17+
# Device functionality
18+
#
19+
20+
const MAXTHREADS = 256
21+
1622
struct JLBackend <: KernelAbstractions.GPU
1723
static::Bool
1824
JLBackend(;static::Bool=false) = new(static)

src/host/broadcast.jl

+2-1
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,8 @@ end
4747
@inbounds dest[I] = bc′[I]
4848
end
4949

50-
broadcast_kernel(get_backend(dest))(dest, bc′, ndrange = size(dest))
50+
# ndrange set for a possible 0D evaluation
51+
broadcast_kernel(get_backend(dest))(dest, bc′, ndrange = length(size(dest)) > 0 ? size(dest) : (1,))
5152

5253
return dest
5354
end

src/host/construction.jl

+4-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,10 @@ function Base.fill!(A::AnyGPUArray{T}, x) where T
1515
idx = @index(Global, Linear)
1616
@inbounds a[idx] = val
1717
end
18-
fill_kernel!(get_backend(A))(A, x, ndrange = size(A))
18+
19+
# ndrange set for a possible 0D evaluation
20+
fill_kernel!(get_backend(A))(A, x,
21+
ndrange = length(size(A)) > 0 ? size(A) : (1,))
1922
A
2023
end
2124

src/host/linalg.jl

+7-7
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ function LinearAlgebra.copytri!(A::AbstractGPUMatrix{T}, uplo::AbstractChar, con
9292
@inbounds _A[j,i] = conj(_A[i,j])
9393
end
9494
end
95-
U_conj!(get_backend(_A), ndrange = size(_A))
95+
U_conj!(get_backend(A))(A, ndrange = size(A))
9696
elseif uplo == 'U' && !conjugate
9797
@kernel function U_noconj!(_A)
9898
I = @index(Global, Cartesian)
@@ -101,7 +101,7 @@ function LinearAlgebra.copytri!(A::AbstractGPUMatrix{T}, uplo::AbstractChar, con
101101
@inbounds _A[j,i] = _A[i,j]
102102
end
103103
end
104-
U_noconj!(get_backend(_A))(_A, ndrange=size(_A))
104+
U_noconj!(get_backend(A))(A, ndrange = size(A))
105105
elseif uplo == 'L' && conjugate
106106
@kernel function L_conj!(_A)
107107
I = @index(Global, Cartesian)
@@ -110,7 +110,7 @@ function LinearAlgebra.copytri!(A::AbstractGPUMatrix{T}, uplo::AbstractChar, con
110110
@inbounds _A[i,j] = conj(_A[j,i])
111111
end
112112
end
113-
L_conj!(get_backend(_A))(_A, ndrange = size(_A))
113+
L_conj!(get_backend(A))(A, ndrange = size(A))
114114
elseif uplo == 'L' && !conjugate
115115
@kernel function L_noconj!(_A)
116116
I = @index(Global, Cartesian)
@@ -119,7 +119,7 @@ function LinearAlgebra.copytri!(A::AbstractGPUMatrix{T}, uplo::AbstractChar, con
119119
@inbounds _A[i,j] = _A[j,i]
120120
end
121121
end
122-
L_noconj!(get_backend(_A))(_A, ndrange = size(_A))
122+
L_noconj!(get_backend(A))(A, ndrange = size(A))
123123
else
124124
throw(ArgumentError("uplo argument must be 'U' (upper) or 'L' (lower), got $uplo"))
125125
end
@@ -178,7 +178,7 @@ function LinearAlgebra.tril!(A::AbstractGPUMatrix{T}, d::Integer = 0) where T
178178
@inbounds _A[i, j] = zero(T)
179179
end
180180
end
181-
tril_kernel!(get_backend(_A))(_A, _d, ndrange = size(_A))
181+
tril_kernel!(get_backend(A))(A, d, ndrange = size(A))
182182
return A
183183
end
184184

@@ -190,7 +190,7 @@ function LinearAlgebra.triu!(A::AbstractGPUMatrix{T}, d::Integer = 0) where T
190190
@inbounds _A[i, j] = zero(T)
191191
end
192192
end
193-
triu_kernel!(get_backend(_A))(_A, _d, ndrange = length(_A))
193+
triu_kernel!(get_backend(A))(A, d, ndrange = size(A))
194194
return A
195195
end
196196

@@ -423,7 +423,7 @@ LinearAlgebra.rmul!(A::AbstractGPUArray, b::Number) = generic_rmul!(A, b)
423423

424424
function generic_lmul!(s::Number, X::AbstractArray)
425425
@kernel function lmul_kernel!(X, s)
426-
i = @index(Global, linear)
426+
i = @index(Global, Linear)
427427
@inbounds X[i] = s*X[i]
428428
end
429429
lmul_kernel!(get_backend(X))(X, s, ndrange = size(X))

src/host/math.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Base mathematical operations
22

33
function Base.clamp!(A::AnyGPUArray, low, high)
4-
@kernel function clamp_kernel!(A::AnyGPUArray, low, high)
4+
@kernel function clamp_kernel!(A, low, high)
55
I = @index(Global, Cartesian)
66
A[I] = clamp(A[I], low, high)
77
end

src/host/random.jl

+2-3
Original file line numberDiff line numberDiff line change
@@ -86,10 +86,9 @@ end
8686
function Random.rand!(rng::RNG, A::AnyGPUArray{T}) where T <: Number
8787
@kernel function rand!(a, randstate)
8888
idx = @index(Global, Linear)
89-
@inbounds a[idx] = gpu_rand(T, idx, randstates)
89+
@inbounds a[idx] = gpu_rand(T, idx, randstate)
9090
end
91-
kernel = rand!(get_backend(A))
92-
kernel(A, rng.state)
91+
rand!(get_backend(A))(A, rng.state, ndrange = size(A))
9392
A
9493
end
9594

0 commit comments

Comments
 (0)