diff --git a/src/triangular.jl b/src/triangular.jl index 26ad4204..64bf85e4 100644 --- a/src/triangular.jl +++ b/src/triangular.jl @@ -238,10 +238,22 @@ Base.isassigned(A::UpperOrLowerTriangular, i::Int, j::Int) = Base.isstored(A::UpperOrLowerTriangular, i::Int, j::Int) = _shouldforwardindex(A, i, j) ? Base.isstored(A.data, i, j) : false -@propagate_inbounds getindex(A::Union{UnitLowerTriangular{T}, UnitUpperTriangular{T}}, i::Int, j::Int) where {T} = - _shouldforwardindex(A, i, j) ? A.data[i,j] : ifelse(i == j, oneunit(T), zero(T)) -@propagate_inbounds getindex(A::Union{LowerTriangular, UpperTriangular}, i::Int, j::Int) = - _shouldforwardindex(A, i, j) ? A.data[i,j] : diagzero(A,i,j) +@propagate_inbounds function getindex(A::Union{UnitLowerTriangular{T}, UnitUpperTriangular{T}}, i::Int, j::Int) where {T} + if _shouldforwardindex(A, i, j) + A.data[i,j] + else + @boundscheck checkbounds(A, i, j) + ifelse(i == j, oneunit(T), zero(T)) + end +end +@propagate_inbounds function getindex(A::Union{LowerTriangular, UpperTriangular}, i::Int, j::Int) + if _shouldforwardindex(A, i, j) + A.data[i,j] + else + @boundscheck checkbounds(A, i, j) + @inbounds diagzero(A,i,j) + end +end _shouldforwardindex(U::UpperTriangular, b::BandIndex) = b.band >= 0 _shouldforwardindex(U::LowerTriangular, b::BandIndex) = b.band <= 0 @@ -250,63 +262,97 @@ _shouldforwardindex(U::UnitLowerTriangular, b::BandIndex) = b.band < 0 # these specialized getindex methods enable constant-propagation of the band Base.@constprop :aggressive @propagate_inbounds function getindex(A::Union{UnitLowerTriangular{T}, UnitUpperTriangular{T}}, b::BandIndex) where {T} - _shouldforwardindex(A, b) ? A.data[b] : ifelse(b.band == 0, oneunit(T), zero(T)) + if _shouldforwardindex(A, b) + A.data[b] + else + @boundscheck checkbounds(A, b) + ifelse(b.band == 0, oneunit(T), zero(T)) + end end Base.@constprop :aggressive @propagate_inbounds function getindex(A::Union{LowerTriangular, UpperTriangular}, b::BandIndex) - _shouldforwardindex(A, b) ? A.data[b] : diagzero(A.data, b) + if _shouldforwardindex(A, b) + A.data[b] + else + @boundscheck checkbounds(A, b) + @inbounds diagzero(A, b) + end end -_zero_triangular_half_str(::Type{<:UpperOrUnitUpperTriangular}) = "lower" -_zero_triangular_half_str(::Type{<:LowerOrUnitLowerTriangular}) = "upper" - -@noinline function throw_nonzeroerror(T, @nospecialize(x), i, j) - Ts = _zero_triangular_half_str(T) - Tn = nameof(T) +@noinline function throw_nonzeroerror(Tn::Symbol, @nospecialize(x), i, j) + zero_half = Tn in (:UpperTriangular, :UnitUpperTriangular) ? "lower" : "upper" + nstr = Tn === :UpperTriangular ? "n" : "" throw(ArgumentError( - lazy"cannot set index in the $Ts triangular part ($i, $j) of an $Tn matrix to a nonzero value ($x)")) + LazyString( + lazy"cannot set index ($i, $j) in the $zero_half triangular part ", + lazy"of a$nstr $Tn matrix to a nonzero value ($x)") + ) + ) end -@noinline function throw_nononeerror(T, @nospecialize(x), i, j) - Tn = nameof(T) +@noinline function throw_nonuniterror(Tn::Symbol, @nospecialize(x), i, j) throw(ArgumentError( - lazy"cannot set index on the diagonal ($i, $j) of an $Tn matrix to a non-unit value ($x)")) + lazy"cannot set index ($i, $j) on the diagonal of a $Tn matrix to a non-unit value ($x)")) end @propagate_inbounds function setindex!(A::UpperTriangular, x, i::Integer, j::Integer) - if i > j - iszero(x) || throw_nonzeroerror(typeof(A), x, i, j) - else + if _shouldforwardindex(A, i, j) A.data[i,j] = x + else + @boundscheck checkbounds(A, i, j) + # the value must be convertible to the eltype for setindex! to be meaningful + # however, the converted value is unused, and the compiler is free to remove + # the conversion if the call is guaranteed to succeed + convert(eltype(A), x) + iszero(x) || throw_nonzeroerror(nameof(typeof(A)), x, i, j) end return A end @propagate_inbounds function setindex!(A::UnitUpperTriangular, x, i::Integer, j::Integer) - if i > j - iszero(x) || throw_nonzeroerror(typeof(A), x, i, j) - elseif i == j - x == oneunit(x) || throw_nononeerror(typeof(A), x, i, j) - else + if _shouldforwardindex(A, i, j) A.data[i,j] = x + else + @boundscheck checkbounds(A, i, j) + # the value must be convertible to the eltype for setindex! to be meaningful + # however, the converted value is unused, and the compiler is free to remove + # the conversion if the call is guaranteed to succeed + convert(eltype(A), x) + if i == j # diagonal + x == oneunit(eltype(A)) || throw_nonuniterror(nameof(typeof(A)), x, i, j) + else + iszero(x) || throw_nonzeroerror(nameof(typeof(A)), x, i, j) + end end return A end @propagate_inbounds function setindex!(A::LowerTriangular, x, i::Integer, j::Integer) - if i < j - iszero(x) || throw_nonzeroerror(typeof(A), x, i, j) - else + if _shouldforwardindex(A, i, j) A.data[i,j] = x + else + @boundscheck checkbounds(A, i, j) + # the value must be convertible to the eltype for setindex! to be meaningful + # however, the converted value is unused, and the compiler is free to remove + # the conversion if the call is guaranteed to succeed + convert(eltype(A), x) + iszero(x) || throw_nonzeroerror(nameof(typeof(A)), x, i, j) end return A end @propagate_inbounds function setindex!(A::UnitLowerTriangular, x, i::Integer, j::Integer) - if i < j - iszero(x) || throw_nonzeroerror(typeof(A), x, i, j) - elseif i == j - x == oneunit(x) || throw_nononeerror(typeof(A), x, i, j) - else + if _shouldforwardindex(A, i, j) A.data[i,j] = x + else + @boundscheck checkbounds(A, i, j) + # the value must be convertible to the eltype for setindex! to be meaningful + # however, the converted value is unused, and the compiler is free to remove + # the conversion if the call is guaranteed to succeed + convert(eltype(A), x) + if i == j # diagonal + x == oneunit(eltype(A)) || throw_nonuniterror(nameof(typeof(A)), x, i, j) + else + iszero(x) || throw_nonzeroerror(nameof(typeof(A)), x, i, j) + end end return A end @@ -560,7 +606,7 @@ for (T, UT) in ((:UpperTriangular, :UnitUpperTriangular), (:LowerTriangular, :Un @eval @inline function _copy!(A::$UT, B::$T) for dind in diagind(A, IndexStyle(A)) if A[dind] != B[dind] - throw_nononeerror(typeof(A), B[dind], Tuple(dind)...) + throw_nonuniterror(nameof(typeof(A)), B[dind], Tuple(dind)...) end end _copy!($T(parent(A)), B) @@ -741,7 +787,7 @@ function _triscale!(A::UpperOrUnitUpperTriangular, B::UnitUpperTriangular, c::Nu checksize1(A, B) _iszero_alpha(_add) && return _rmul_or_fill!(A, _add.beta) for j in axes(B.data,2) - @inbounds _modify!(_add, c, A, (j,j)) + @inbounds _modify!(_add, B[BandIndex(0,j)] * c, A, (j,j)) for i in firstindex(B.data,1):(j - 1) @inbounds _modify!(_add, B.data[i,j] * c, A.data, (i,j)) end @@ -752,7 +798,7 @@ function _triscale!(A::UpperOrUnitUpperTriangular, c::Number, B::UnitUpperTriang checksize1(A, B) _iszero_alpha(_add) && return _rmul_or_fill!(A, _add.beta) for j in axes(B.data,2) - @inbounds _modify!(_add, c, A, (j,j)) + @inbounds _modify!(_add, c * B[BandIndex(0,j)], A, (j,j)) for i in firstindex(B.data,1):(j - 1) @inbounds _modify!(_add, c * B.data[i,j], A.data, (i,j)) end @@ -783,7 +829,7 @@ function _triscale!(A::LowerOrUnitLowerTriangular, B::UnitLowerTriangular, c::Nu checksize1(A, B) _iszero_alpha(_add) && return _rmul_or_fill!(A, _add.beta) for j in axes(B.data,2) - @inbounds _modify!(_add, c, A, (j,j)) + @inbounds _modify!(_add, B[BandIndex(0,j)] * c, A, (j,j)) for i in (j + 1):lastindex(B.data,1) @inbounds _modify!(_add, B.data[i,j] * c, A.data, (i,j)) end @@ -794,7 +840,7 @@ function _triscale!(A::LowerOrUnitLowerTriangular, c::Number, B::UnitLowerTriang checksize1(A, B) _iszero_alpha(_add) && return _rmul_or_fill!(A, _add.beta) for j in axes(B.data,2) - @inbounds _modify!(_add, c, A, (j,j)) + @inbounds _modify!(_add, c * B[BandIndex(0,j)], A, (j,j)) for i in (j + 1):lastindex(B.data,1) @inbounds _modify!(_add, c * B.data[i,j], A.data, (i,j)) end diff --git a/test/triangular.jl b/test/triangular.jl index 21fc2d69..0ad4f521 100644 --- a/test/triangular.jl +++ b/test/triangular.jl @@ -641,11 +641,11 @@ end @testset "error message" begin A = UpperTriangular(Ap) B = UpperTriangular(Bp) - @test_throws "cannot set index in the lower triangular part" copyto!(A, B) + @test_throws "cannot set index (3, 1) in the lower triangular part" copyto!(A, B) A = LowerTriangular(Ap) B = LowerTriangular(Bp) - @test_throws "cannot set index in the upper triangular part" copyto!(A, B) + @test_throws "cannot set index (1, 2) in the upper triangular part" copyto!(A, B) end end @@ -950,6 +950,10 @@ end @test 2\U == 2\M @test U*2 == M*2 @test 2*U == 2*M + + U2 = copy(U) + @test rmul!(U, 1) == U2 + @test lmul!(1, U) == U2 end @testset "scaling partly initialized unit triangular" begin @@ -966,4 +970,75 @@ end end end +@testset "indexing checks" begin + P = [1 2; 3 4] + @testset "getindex" begin + U = UnitUpperTriangular(P) + @test_throws BoundsError U[0,0] + @test_throws BoundsError U[1,0] + @test_throws BoundsError U[BandIndex(0,0)] + @test_throws BoundsError U[BandIndex(-1,0)] + + U = UpperTriangular(P) + @test_throws BoundsError U[1,0] + @test_throws BoundsError U[BandIndex(-1,0)] + + L = UnitLowerTriangular(P) + @test_throws BoundsError L[0,0] + @test_throws BoundsError L[0,1] + @test_throws BoundsError U[BandIndex(0,0)] + @test_throws BoundsError U[BandIndex(1,0)] + + L = LowerTriangular(P) + @test_throws BoundsError L[0,1] + @test_throws BoundsError L[BandIndex(1,0)] + end + @testset "setindex!" begin + A = SizedArrays.SizedArray{(2,2)}(P) + M = fill(A, 2, 2) + U = UnitUpperTriangular(M) + @test_throws "Cannot `convert` an object of type $Int" U[1,1] = 1 + non_unit_msg = "cannot set index $((1,1)) on the diagonal of a UnitUpperTriangular matrix to a non-unit value" + @test_throws non_unit_msg U[1,1] = A + L = UnitLowerTriangular(M) + @test_throws "Cannot `convert` an object of type $Int" L[1,1] = 1 + non_unit_msg = "cannot set index $((1,1)) on the diagonal of a UnitLowerTriangular matrix to a non-unit value" + @test_throws non_unit_msg L[1,1] = A + + for UT in (UnitUpperTriangular, UpperTriangular) + U = UT(M) + @test_throws "Cannot `convert` an object of type $Int" U[2,1] = 0 + end + for LT in (UnitLowerTriangular, LowerTriangular) + L = LT(M) + @test_throws "Cannot `convert` an object of type $Int" L[1,2] = 0 + end + + U = UnitUpperTriangular(P) + @test_throws BoundsError U[0,0] = 1 + @test_throws BoundsError U[1,0] = 0 + + U = UpperTriangular(P) + @test_throws BoundsError U[1,0] = 0 + + L = UnitLowerTriangular(P) + @test_throws BoundsError L[0,0] = 1 + @test_throws BoundsError L[0,1] = 0 + + L = LowerTriangular(P) + @test_throws BoundsError L[0,1] = 0 + end +end + +@testset "unit triangular l/rdiv!" begin + A = rand(3,3) + @testset for (UT,T) in ((UnitUpperTriangular, UpperTriangular), + (UnitLowerTriangular, LowerTriangular)) + UnitTri = UT(A) + Tri = T(LinearAlgebra.full(UnitTri)) + @test 2 \ UnitTri ≈ 2 \ Tri + @test UnitTri / 2 ≈ Tri / 2 + end +end + end # module TestTriangular