Skip to content

Commit 566b946

Browse files
committed
Out-of-place triu/tril for Symmetric in each branch
1 parent ea8e858 commit 566b946

File tree

3 files changed

+40
-12
lines changed

3 files changed

+40
-12
lines changed

src/symmetric.jl

+12-12
Original file line numberDiff line numberDiff line change
@@ -470,49 +470,49 @@ Base.conj!(A::HermOrSym) = typeof(A)(parentof_applytri(conj!, A), A.uplo)
470470
# tril/triu
471471
function tril(A::Hermitian, k::Integer=0)
472472
if A.uplo == 'U' && k <= 0
473-
return tril!(copy(A.data'),k)
473+
return tril_maybe_inplace(copy(A.data'),k)
474474
elseif A.uplo == 'U' && k > 0
475-
return tril!(copy(A.data'),-1) + tril!(triu(A.data),k)
475+
return tril_maybe_inplace(copy(A.data'),-1) + tril_maybe_inplace(triu(A.data),k)
476476
elseif A.uplo == 'L' && k <= 0
477477
return tril(A.data,k)
478478
else
479-
return tril(A.data,-1) + tril!(triu!(copy(A.data')),k)
479+
return tril(A.data,-1) + tril_maybe_inplace(triu_maybe_inplace(copy(A.data')),k)
480480
end
481481
end
482482

483483
function tril(A::Symmetric, k::Integer=0)
484484
if A.uplo == 'U' && k <= 0
485-
return tril!(copy(transpose(A.data)),k)
485+
return tril_maybe_inplace(copy(transpose(A.data)),k)
486486
elseif A.uplo == 'U' && k > 0
487-
return tril!(copy(transpose(A.data)),-1) + tril!(triu(A.data),k)
487+
return tril_maybe_inplace(copy(transpose(A.data)),-1) + tril_maybe_inplace(triu(A.data),k)
488488
elseif A.uplo == 'L' && k <= 0
489489
return tril(A.data,k)
490490
else
491-
return tril(A.data,-1) + tril!(triu!(copy(transpose(A.data))),k)
491+
return tril(A.data,-1) + tril_maybe_inplace(triu_maybe_inplace(copy(transpose(A.data))),k)
492492
end
493493
end
494494

495495
function triu(A::Hermitian, k::Integer=0)
496496
if A.uplo == 'U' && k >= 0
497497
return triu(A.data,k)
498498
elseif A.uplo == 'U' && k < 0
499-
return triu(A.data,1) + triu!(tril!(copy(A.data')),k)
499+
return triu(A.data,1) + triu_maybe_inplace(tril_maybe_inplace(copy(A.data')),k)
500500
elseif A.uplo == 'L' && k >= 0
501-
return triu!(copy(A.data'),k)
501+
return triu_maybe_inplace(copy(A.data'),k)
502502
else
503-
return triu!(copy(A.data'),1) + triu!(tril(A.data),k)
503+
return triu_maybe_inplace(copy(A.data'),1) + triu_maybe_inplace(tril(A.data),k)
504504
end
505505
end
506506

507507
function triu(A::Symmetric, k::Integer=0)
508508
if A.uplo == 'U' && k >= 0
509509
return triu(A.data,k)
510510
elseif A.uplo == 'U' && k < 0
511-
return triu(A.data,1) + triu!(tril!(copy(transpose(A.data))),k)
511+
return triu(A.data,1) + triu_maybe_inplace(tril_maybe_inplace(copy(transpose(A.data))),k)
512512
elseif A.uplo == 'L' && k >= 0
513-
return triu!(copy(transpose(A.data)),k)
513+
return triu_maybe_inplace(copy(transpose(A.data)),k)
514514
else
515-
return triu!(copy(transpose(A.data)),1) + triu!(tril(A.data),k)
515+
return triu_maybe_inplace(copy(transpose(A.data)),1) + triu_maybe_inplace(tril(A.data),k)
516516
end
517517
end
518518

src/triangular.jl

+5
Original file line numberDiff line numberDiff line change
@@ -484,6 +484,11 @@ function tril!(A::UnitLowerTriangular, k::Integer=0)
484484
return tril!(LowerTriangular(A.data), k)
485485
end
486486

487+
tril_maybe_inplace(A, k::Integer=0) = tril(A, k)
488+
triu_maybe_inplace(A, k::Integer=0) = triu(A, k)
489+
tril_maybe_inplace(A::StridedMatrix, k::Integer=0) = tril!(A, k)
490+
triu_maybe_inplace(A::StridedMatrix, k::Integer=0) = triu!(A, k)
491+
487492
adjoint(A::LowerTriangular) = UpperTriangular(adjoint(A.data))
488493
adjoint(A::UpperTriangular) = LowerTriangular(adjoint(A.data))
489494
adjoint(A::UnitLowerTriangular) = UnitUpperTriangular(adjoint(A.data))

test/symmetric.jl

+23
Original file line numberDiff line numberDiff line change
@@ -1199,4 +1199,27 @@ end
11991199
end
12001200
end
12011201

1202+
@testset "triu/tril with immutable arrays" begin
1203+
struct ImmutableMatrix{T,A<:AbstractMatrix{T}} <: AbstractMatrix{T}
1204+
a :: A
1205+
end
1206+
Base.size(A::ImmutableMatrix) = size(A.a)
1207+
Base.getindex(A::ImmutableMatrix, i::Int, j::Int) = getindex(A.a, i, j)
1208+
Base.copy(A::ImmutableMatrix) = A
1209+
LinearAlgebra.adjoint(A::ImmutableMatrix) = ImmutableMatrix(adjoint(A.a))
1210+
LinearAlgebra.transpose(A::ImmutableMatrix) = ImmutableMatrix(transpose(A.a))
1211+
1212+
A = ImmutableMatrix([1 2; 3 4])
1213+
for T in (Symmetric, Hermitian), uplo in (:U, :L)
1214+
H = T(A, uplo)
1215+
MH = Matrix(H)
1216+
@test triu(H,-1) == triu(MH,-1)
1217+
@test triu(H) == triu(MH)
1218+
@test triu(H,1) == triu(MH,1)
1219+
@test tril(H,1) == tril(MH,1)
1220+
@test tril(H) == tril(MH)
1221+
@test tril(H,-1) == tril(MH,-1)
1222+
end
1223+
end
1224+
12021225
end # module TestSymmetric

0 commit comments

Comments
 (0)