Skip to content

Commit 55138e8

Browse files
committed
Use grid-stride loop for fill!
1 parent 26237ff commit 55138e8

File tree

1 file changed

+12
-2
lines changed

1 file changed

+12
-2
lines changed

src/host/construction.jl

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,22 @@ function Base.fill!(A::AnyGPUArray{T}, x) where T
1414

1515
@kernel function fill_kernel!(a, val)
1616
idx = @index(Global, Linear)
17-
@inbounds a[idx] = val
17+
stride = prod(@ndrange())
18+
while idx <= length(a)
19+
@inbounds a[idx] = val
20+
idx += stride
21+
end
1822
end
1923

2024
# ndims check for 0D support
2125
kernel = fill_kernel!(get_backend(A))
22-
kernel(A, x; ndrange = length(A))
26+
27+
# Calculate ndrange to ensure that a total grid size >typemax(UInt32) is never
28+
# chosen. Grid stride to accomodate grid size limitations on AMD and Metal backends
29+
len = length(A)
30+
ndrange = cld(len, cld(len, typemax(UInt32) - 1024))
31+
32+
kernel(A, x; ndrange)
2333
A
2434
end
2535

0 commit comments

Comments
 (0)