Skip to content

Commit 5342857

Browse files
N5N3LilithHafner
authored andcommitted
Fix stride(A, i) for 0-dim inputs (JuliaLang#44090)
Fixes JuliaLang#44087
1 parent 560eccf commit 5342857

File tree

3 files changed

+22
-1
lines changed

3 files changed

+22
-1
lines changed

base/abstractarray.jl

+7-1
Original file line numberDiff line numberDiff line change
@@ -546,7 +546,13 @@ julia> stride(A,3)
546546
function stride(A::AbstractArray, k::Integer)
547547
st = strides(A)
548548
k ndims(A) && return st[k]
549-
return sum(st .* size(A))
549+
ndims(A) == 0 && return 1
550+
sz = size(A)
551+
s = st[1] * sz[1]
552+
for i in 2:ndims(A)
553+
s += st[i] * sz[i]
554+
end
555+
return s
550556
end
551557

552558
@inline size_to_strides(s, d, sz...) = (s, size_to_strides(s * d, sz...)...)

base/reinterpretarray.jl

+2
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,8 @@ StridedMatrix{T} = StridedArray{T,2}
149149
StridedVecOrMat{T} = Union{StridedVector{T}, StridedMatrix{T}}
150150

151151
strides(a::Union{DenseArray,StridedReshapedArray,StridedReinterpretArray}) = size_to_strides(1, size(a)...)
152+
stride(A::Union{DenseArray,StridedReshapedArray,StridedReinterpretArray}, k::Integer) =
153+
k ndims(A) ? strides(A)[k] : length(A)
152154

153155
function strides(a::ReshapedReinterpretArray)
154156
ap = parent(a)

test/abstractarray.jl

+13
Original file line numberDiff line numberDiff line change
@@ -1584,6 +1584,19 @@ end
15841584
end
15851585
end
15861586

1587+
@testset "stride for 0 dims array #44087" begin
1588+
struct Fill44087 <: AbstractArray{Int,0}
1589+
a::Int
1590+
end
1591+
# `stride` shouldn't work if `strides` is not defined.
1592+
@test_throws MethodError stride(Fill44087(1), 1)
1593+
# It is intentionally to only check the return type. (The value is somehow arbitrary)
1594+
@test stride(fill(1), 1) isa Int
1595+
@test stride(reinterpret(Float64, fill(Int64(1))), 1) isa Int
1596+
@test stride(reinterpret(reshape, Float64, fill(Int64(1))), 1) isa Int
1597+
@test stride(Base.ReshapedArray(fill(1), (), ()), 1) isa Int
1598+
end
1599+
15871600
@testset "to_indices inference (issue #42001 #44059)" begin
15881601
@test (@inferred to_indices([], ntuple(Returns(CartesianIndex(1)), 32))) == ntuple(Returns(1), 32)
15891602
@test (@inferred to_indices([], ntuple(Returns(CartesianIndices(1:1)), 32))) == ntuple(Returns(Base.OneTo(1)), 32)

0 commit comments

Comments
 (0)