Skip to content

Commit addf9ab

Browse files
committed
Out-of-place triu/tril for Symmetric in each branch
1 parent 6e8f9a1 commit addf9ab

File tree

3 files changed

+41
-13
lines changed

3 files changed

+41
-13
lines changed

src/symmetric.jl

+13-13
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,7 @@ Base.dataids(A::HermOrSym) = Base.dataids(parent(A))
289289
Base.unaliascopy(A::Hermitian) = Hermitian(Base.unaliascopy(parent(A)), sym_uplo(A.uplo))
290290
Base.unaliascopy(A::Symmetric) = Symmetric(Base.unaliascopy(parent(A)), sym_uplo(A.uplo))
291291

292-
_conjugation(::Symmetric) = transpose
292+
_conjugation(::Union{Symmetric, Hermitian{<:Real}}) = transpose
293293
_conjugation(::Hermitian) = adjoint
294294

295295
diag(A::Symmetric) = symmetric.(diag(parent(A)), sym_uplo(A.uplo))
@@ -472,49 +472,49 @@ Base.conj!(A::HermOrSym) = typeof(A)(parentof_applytri(conj!, A), A.uplo)
472472
# tril/triu
473473
function tril(A::Hermitian, k::Integer=0)
474474
if A.uplo == 'U' && k <= 0
475-
return tril!(copy(A.data'),k)
475+
return tril_maybe_inplace(copy(A.data'),k)
476476
elseif A.uplo == 'U' && k > 0
477-
return tril!(copy(A.data'),-1) + tril!(triu(A.data),k)
477+
return tril_maybe_inplace(copy(A.data'),-1) + tril_maybe_inplace(triu(A.data),k)
478478
elseif A.uplo == 'L' && k <= 0
479479
return tril(A.data,k)
480480
else
481-
return tril(A.data,-1) + tril!(triu!(copy(A.data')),k)
481+
return tril(A.data,-1) + tril_maybe_inplace(triu_maybe_inplace(copy(A.data')),k)
482482
end
483483
end
484484

485485
function tril(A::Symmetric, k::Integer=0)
486486
if A.uplo == 'U' && k <= 0
487-
return tril!(copy(transpose(A.data)),k)
487+
return tril_maybe_inplace(copy(transpose(A.data)),k)
488488
elseif A.uplo == 'U' && k > 0
489-
return tril!(copy(transpose(A.data)),-1) + tril!(triu(A.data),k)
489+
return tril_maybe_inplace(copy(transpose(A.data)),-1) + tril_maybe_inplace(triu(A.data),k)
490490
elseif A.uplo == 'L' && k <= 0
491491
return tril(A.data,k)
492492
else
493-
return tril(A.data,-1) + tril!(triu!(copy(transpose(A.data))),k)
493+
return tril(A.data,-1) + tril_maybe_inplace(triu_maybe_inplace(copy(transpose(A.data))),k)
494494
end
495495
end
496496

497497
function triu(A::Hermitian, k::Integer=0)
498498
if A.uplo == 'U' && k >= 0
499499
return triu(A.data,k)
500500
elseif A.uplo == 'U' && k < 0
501-
return triu(A.data,1) + triu!(tril!(copy(A.data')),k)
501+
return triu(A.data,1) + triu_maybe_inplace(tril_maybe_inplace(copy(A.data')),k)
502502
elseif A.uplo == 'L' && k >= 0
503-
return triu!(copy(A.data'),k)
503+
return triu_maybe_inplace(copy(A.data'),k)
504504
else
505-
return triu!(copy(A.data'),1) + triu!(tril(A.data),k)
505+
return triu_maybe_inplace(copy(A.data'),1) + triu_maybe_inplace(tril(A.data),k)
506506
end
507507
end
508508

509509
function triu(A::Symmetric, k::Integer=0)
510510
if A.uplo == 'U' && k >= 0
511511
return triu(A.data,k)
512512
elseif A.uplo == 'U' && k < 0
513-
return triu(A.data,1) + triu!(tril!(copy(transpose(A.data))),k)
513+
return triu(A.data,1) + triu_maybe_inplace(tril_maybe_inplace(copy(transpose(A.data))),k)
514514
elseif A.uplo == 'L' && k >= 0
515-
return triu!(copy(transpose(A.data)),k)
515+
return triu_maybe_inplace(copy(transpose(A.data)),k)
516516
else
517-
return triu!(copy(transpose(A.data)),1) + triu!(tril(A.data),k)
517+
return triu_maybe_inplace(copy(transpose(A.data)),1) + triu_maybe_inplace(tril(A.data),k)
518518
end
519519
end
520520

src/triangular.jl

+5
Original file line numberDiff line numberDiff line change
@@ -484,6 +484,11 @@ function tril!(A::UnitLowerTriangular, k::Integer=0)
484484
return tril!(LowerTriangular(A.data), k)
485485
end
486486

487+
tril_maybe_inplace(A, k::Integer=0) = tril(A, k)
488+
triu_maybe_inplace(A, k::Integer=0) = triu(A, k)
489+
tril_maybe_inplace(A::StridedMatrix, k::Integer=0) = tril!(A, k)
490+
triu_maybe_inplace(A::StridedMatrix, k::Integer=0) = triu!(A, k)
491+
487492
adjoint(A::LowerTriangular) = UpperTriangular(adjoint(A.data))
488493
adjoint(A::UpperTriangular) = LowerTriangular(adjoint(A.data))
489494
adjoint(A::UnitLowerTriangular) = UnitUpperTriangular(adjoint(A.data))

test/symmetric.jl

+23
Original file line numberDiff line numberDiff line change
@@ -1191,4 +1191,27 @@ end
11911191
@test_throws s_msg S[1,1] = v
11921192
end
11931193

1194+
@testset "triu/tril with immutable arrays" begin
1195+
struct ImmutableMatrix{T,A<:AbstractMatrix{T}} <: AbstractMatrix{T}
1196+
a :: A
1197+
end
1198+
Base.size(A::ImmutableMatrix) = size(A.a)
1199+
Base.getindex(A::ImmutableMatrix, i::Int, j::Int) = getindex(A.a, i, j)
1200+
Base.copy(A::ImmutableMatrix) = A
1201+
LinearAlgebra.adjoint(A::ImmutableMatrix) = ImmutableMatrix(adjoint(A.a))
1202+
LinearAlgebra.transpose(A::ImmutableMatrix) = ImmutableMatrix(transpose(A.a))
1203+
1204+
A = ImmutableMatrix([1 2; 3 4])
1205+
for T in (Symmetric, Hermitian), uplo in (:U, :L)
1206+
H = T(A, uplo)
1207+
MH = Matrix(H)
1208+
@test triu(H,-1) == triu(MH,-1)
1209+
@test triu(H) == triu(MH)
1210+
@test triu(H,1) == triu(MH,1)
1211+
@test tril(H,1) == tril(MH,1)
1212+
@test tril(H) == tril(MH)
1213+
@test tril(H,-1) == tril(MH,-1)
1214+
end
1215+
end
1216+
11941217
end # module TestSymmetric

0 commit comments

Comments
 (0)