Skip to content

Bounds-checking in triangular indexing branches #1305

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
120 changes: 83 additions & 37 deletions src/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -238,10 +238,22 @@ Base.isassigned(A::UpperOrLowerTriangular, i::Int, j::Int) =
Base.isstored(A::UpperOrLowerTriangular, i::Int, j::Int) =
_shouldforwardindex(A, i, j) ? Base.isstored(A.data, i, j) : false

@propagate_inbounds getindex(A::Union{UnitLowerTriangular{T}, UnitUpperTriangular{T}}, i::Int, j::Int) where {T} =
_shouldforwardindex(A, i, j) ? A.data[i,j] : ifelse(i == j, oneunit(T), zero(T))
@propagate_inbounds getindex(A::Union{LowerTriangular, UpperTriangular}, i::Int, j::Int) =
_shouldforwardindex(A, i, j) ? A.data[i,j] : diagzero(A,i,j)
@propagate_inbounds function getindex(A::Union{UnitLowerTriangular{T}, UnitUpperTriangular{T}}, i::Int, j::Int) where {T}
if _shouldforwardindex(A, i, j)
A.data[i,j]
else
@boundscheck checkbounds(A, i, j)
ifelse(i == j, oneunit(T), zero(T))
end
end
@propagate_inbounds function getindex(A::Union{LowerTriangular, UpperTriangular}, i::Int, j::Int)
if _shouldforwardindex(A, i, j)
A.data[i,j]
else
@boundscheck checkbounds(A, i, j)
@inbounds diagzero(A,i,j)
end
end

_shouldforwardindex(U::UpperTriangular, b::BandIndex) = b.band >= 0
_shouldforwardindex(U::LowerTriangular, b::BandIndex) = b.band <= 0
Expand All @@ -250,63 +262,97 @@ _shouldforwardindex(U::UnitLowerTriangular, b::BandIndex) = b.band < 0

# these specialized getindex methods enable constant-propagation of the band
Base.@constprop :aggressive @propagate_inbounds function getindex(A::Union{UnitLowerTriangular{T}, UnitUpperTriangular{T}}, b::BandIndex) where {T}
_shouldforwardindex(A, b) ? A.data[b] : ifelse(b.band == 0, oneunit(T), zero(T))
if _shouldforwardindex(A, b)
A.data[b]
else
@boundscheck checkbounds(A, b)
ifelse(b.band == 0, oneunit(T), zero(T))
end
end
Base.@constprop :aggressive @propagate_inbounds function getindex(A::Union{LowerTriangular, UpperTriangular}, b::BandIndex)
_shouldforwardindex(A, b) ? A.data[b] : diagzero(A.data, b)
if _shouldforwardindex(A, b)
A.data[b]
else
@boundscheck checkbounds(A, b)
@inbounds diagzero(A, b)
end
end

_zero_triangular_half_str(::Type{<:UpperOrUnitUpperTriangular}) = "lower"
_zero_triangular_half_str(::Type{<:LowerOrUnitLowerTriangular}) = "upper"

@noinline function throw_nonzeroerror(T, @nospecialize(x), i, j)
Ts = _zero_triangular_half_str(T)
Tn = nameof(T)
@noinline function throw_nonzeroerror(Tn::Symbol, @nospecialize(x), i, j)
zero_half = Tn in (:UpperTriangular, :UnitUpperTriangular) ? "lower" : "upper"
nstr = Tn === :UpperTriangular ? "n" : ""
throw(ArgumentError(
lazy"cannot set index in the $Ts triangular part ($i, $j) of an $Tn matrix to a nonzero value ($x)"))
LazyString(
lazy"cannot set index ($i, $j) in the $zero_half triangular part ",
lazy"of a$nstr $Tn matrix to a nonzero value ($x)")
)
)
end
@noinline function throw_nononeerror(T, @nospecialize(x), i, j)
Tn = nameof(T)
@noinline function throw_nonuniterror(Tn::Symbol, @nospecialize(x), i, j)
throw(ArgumentError(
lazy"cannot set index on the diagonal ($i, $j) of an $Tn matrix to a non-unit value ($x)"))
lazy"cannot set index ($i, $j) on the diagonal of a $Tn matrix to a non-unit value ($x)"))
end

@propagate_inbounds function setindex!(A::UpperTriangular, x, i::Integer, j::Integer)
if i > j
iszero(x) || throw_nonzeroerror(typeof(A), x, i, j)
else
if _shouldforwardindex(A, i, j)
A.data[i,j] = x
else
@boundscheck checkbounds(A, i, j)
# the value must be convertible to the eltype for setindex! to be meaningful
# however, the converted value is unused, and the compiler is free to remove
# the conversion if the call is guaranteed to succeed
convert(eltype(A), x)
iszero(x) || throw_nonzeroerror(nameof(typeof(A)), x, i, j)
end
return A
end

@propagate_inbounds function setindex!(A::UnitUpperTriangular, x, i::Integer, j::Integer)
if i > j
iszero(x) || throw_nonzeroerror(typeof(A), x, i, j)
elseif i == j
x == oneunit(x) || throw_nononeerror(typeof(A), x, i, j)
else
if _shouldforwardindex(A, i, j)
A.data[i,j] = x
else
@boundscheck checkbounds(A, i, j)
# the value must be convertible to the eltype for setindex! to be meaningful
# however, the converted value is unused, and the compiler is free to remove
# the conversion if the call is guaranteed to succeed
convert(eltype(A), x)
if i == j # diagonal
x == oneunit(eltype(A)) || throw_nonuniterror(nameof(typeof(A)), x, i, j)
else
iszero(x) || throw_nonzeroerror(nameof(typeof(A)), x, i, j)
end
end
return A
end

@propagate_inbounds function setindex!(A::LowerTriangular, x, i::Integer, j::Integer)
if i < j
iszero(x) || throw_nonzeroerror(typeof(A), x, i, j)
else
if _shouldforwardindex(A, i, j)
A.data[i,j] = x
else
@boundscheck checkbounds(A, i, j)
# the value must be convertible to the eltype for setindex! to be meaningful
# however, the converted value is unused, and the compiler is free to remove
# the conversion if the call is guaranteed to succeed
convert(eltype(A), x)
iszero(x) || throw_nonzeroerror(nameof(typeof(A)), x, i, j)
end
return A
end

@propagate_inbounds function setindex!(A::UnitLowerTriangular, x, i::Integer, j::Integer)
if i < j
iszero(x) || throw_nonzeroerror(typeof(A), x, i, j)
elseif i == j
x == oneunit(x) || throw_nononeerror(typeof(A), x, i, j)
else
if _shouldforwardindex(A, i, j)
A.data[i,j] = x
else
@boundscheck checkbounds(A, i, j)
# the value must be convertible to the eltype for setindex! to be meaningful
# however, the converted value is unused, and the compiler is free to remove
# the conversion if the call is guaranteed to succeed
convert(eltype(A), x)
if i == j # diagonal
x == oneunit(eltype(A)) || throw_nonuniterror(nameof(typeof(A)), x, i, j)
else
iszero(x) || throw_nonzeroerror(nameof(typeof(A)), x, i, j)
end
end
return A
end
Expand Down Expand Up @@ -560,7 +606,7 @@ for (T, UT) in ((:UpperTriangular, :UnitUpperTriangular), (:LowerTriangular, :Un
@eval @inline function _copy!(A::$UT, B::$T)
for dind in diagind(A, IndexStyle(A))
if A[dind] != B[dind]
throw_nononeerror(typeof(A), B[dind], Tuple(dind)...)
throw_nonuniterror(nameof(typeof(A)), B[dind], Tuple(dind)...)
end
end
_copy!($T(parent(A)), B)
Expand Down Expand Up @@ -741,7 +787,7 @@ function _triscale!(A::UpperOrUnitUpperTriangular, B::UnitUpperTriangular, c::Nu
checksize1(A, B)
_iszero_alpha(_add) && return _rmul_or_fill!(A, _add.beta)
for j in axes(B.data,2)
@inbounds _modify!(_add, c, A, (j,j))
@inbounds _modify!(_add, B[BandIndex(0,j)] * c, A, (j,j))
for i in firstindex(B.data,1):(j - 1)
@inbounds _modify!(_add, B.data[i,j] * c, A.data, (i,j))
end
Expand All @@ -752,7 +798,7 @@ function _triscale!(A::UpperOrUnitUpperTriangular, c::Number, B::UnitUpperTriang
checksize1(A, B)
_iszero_alpha(_add) && return _rmul_or_fill!(A, _add.beta)
for j in axes(B.data,2)
@inbounds _modify!(_add, c, A, (j,j))
@inbounds _modify!(_add, c * B[BandIndex(0,j)], A, (j,j))
for i in firstindex(B.data,1):(j - 1)
@inbounds _modify!(_add, c * B.data[i,j], A.data, (i,j))
end
Expand Down Expand Up @@ -783,7 +829,7 @@ function _triscale!(A::LowerOrUnitLowerTriangular, B::UnitLowerTriangular, c::Nu
checksize1(A, B)
_iszero_alpha(_add) && return _rmul_or_fill!(A, _add.beta)
for j in axes(B.data,2)
@inbounds _modify!(_add, c, A, (j,j))
@inbounds _modify!(_add, B[BandIndex(0,j)] * c, A, (j,j))
for i in (j + 1):lastindex(B.data,1)
@inbounds _modify!(_add, B.data[i,j] * c, A.data, (i,j))
end
Expand All @@ -794,7 +840,7 @@ function _triscale!(A::LowerOrUnitLowerTriangular, c::Number, B::UnitLowerTriang
checksize1(A, B)
_iszero_alpha(_add) && return _rmul_or_fill!(A, _add.beta)
for j in axes(B.data,2)
@inbounds _modify!(_add, c, A, (j,j))
@inbounds _modify!(_add, c * B[BandIndex(0,j)], A, (j,j))
for i in (j + 1):lastindex(B.data,1)
@inbounds _modify!(_add, c * B.data[i,j], A.data, (i,j))
end
Expand Down
79 changes: 77 additions & 2 deletions test/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -641,11 +641,11 @@ end
@testset "error message" begin
A = UpperTriangular(Ap)
B = UpperTriangular(Bp)
@test_throws "cannot set index in the lower triangular part" copyto!(A, B)
@test_throws "cannot set index (3, 1) in the lower triangular part" copyto!(A, B)

A = LowerTriangular(Ap)
B = LowerTriangular(Bp)
@test_throws "cannot set index in the upper triangular part" copyto!(A, B)
@test_throws "cannot set index (1, 2) in the upper triangular part" copyto!(A, B)
end
end

Expand Down Expand Up @@ -950,6 +950,10 @@ end
@test 2\U == 2\M
@test U*2 == M*2
@test 2*U == 2*M

U2 = copy(U)
@test rmul!(U, 1) == U2
@test lmul!(1, U) == U2
end

@testset "scaling partly initialized unit triangular" begin
Expand All @@ -966,4 +970,75 @@ end
end
end

@testset "indexing checks" begin
P = [1 2; 3 4]
@testset "getindex" begin
U = UnitUpperTriangular(P)
@test_throws BoundsError U[0,0]
@test_throws BoundsError U[1,0]
@test_throws BoundsError U[BandIndex(0,0)]
@test_throws BoundsError U[BandIndex(-1,0)]

U = UpperTriangular(P)
@test_throws BoundsError U[1,0]
@test_throws BoundsError U[BandIndex(-1,0)]

L = UnitLowerTriangular(P)
@test_throws BoundsError L[0,0]
@test_throws BoundsError L[0,1]
@test_throws BoundsError U[BandIndex(0,0)]
@test_throws BoundsError U[BandIndex(1,0)]

L = LowerTriangular(P)
@test_throws BoundsError L[0,1]
@test_throws BoundsError L[BandIndex(1,0)]
end
@testset "setindex!" begin
A = SizedArrays.SizedArray{(2,2)}(P)
M = fill(A, 2, 2)
U = UnitUpperTriangular(M)
@test_throws "Cannot `convert` an object of type $Int" U[1,1] = 1
non_unit_msg = "cannot set index $((1,1)) on the diagonal of a UnitUpperTriangular matrix to a non-unit value"
@test_throws non_unit_msg U[1,1] = A
L = UnitLowerTriangular(M)
@test_throws "Cannot `convert` an object of type $Int" L[1,1] = 1
non_unit_msg = "cannot set index $((1,1)) on the diagonal of a UnitLowerTriangular matrix to a non-unit value"
@test_throws non_unit_msg L[1,1] = A

for UT in (UnitUpperTriangular, UpperTriangular)
U = UT(M)
@test_throws "Cannot `convert` an object of type $Int" U[2,1] = 0
end
for LT in (UnitLowerTriangular, LowerTriangular)
L = LT(M)
@test_throws "Cannot `convert` an object of type $Int" L[1,2] = 0
end

U = UnitUpperTriangular(P)
@test_throws BoundsError U[0,0] = 1
@test_throws BoundsError U[1,0] = 0

U = UpperTriangular(P)
@test_throws BoundsError U[1,0] = 0

L = UnitLowerTriangular(P)
@test_throws BoundsError L[0,0] = 1
@test_throws BoundsError L[0,1] = 0

L = LowerTriangular(P)
@test_throws BoundsError L[0,1] = 0
end
end

@testset "unit triangular l/rdiv!" begin
A = rand(3,3)
@testset for (UT,T) in ((UnitUpperTriangular, UpperTriangular),
(UnitLowerTriangular, LowerTriangular))
UnitTri = UT(A)
Tri = T(LinearAlgebra.full(UnitTri))
@test 2 \ UnitTri ≈ 2 \ Tri
@test UnitTri / 2 ≈ Tri / 2
end
end

end # module TestTriangular