Skip to content

Commit ee6ebe2

Browse files
evelyne-ringootmaleadt
authored andcommitted
linalg: Support more inputs to tril! and triu!; printarray and getproperty of QR.
1 parent 40fa8c0 commit ee6ebe2

File tree

3 files changed

+29
-3
lines changed

3 files changed

+29
-3
lines changed

lib/GPUArraysCore/src/GPUArraysCore.jl

+2-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ using Adapt
77

88
export AbstractGPUArray, AbstractGPUVector, AbstractGPUMatrix, AbstractGPUVecOrMat,
99
WrappedGPUArray, AnyGPUArray, AbstractGPUArrayStyle,
10-
AnyGPUArray, AnyGPUVector, AnyGPUMatrix
10+
AnyGPUArray, AnyGPUVector, AnyGPUMatrix, AnyGPUVecOrMat
1111

1212
"""
1313
AbstractGPUArray{T, N} <: DenseArray{T, N}
@@ -27,6 +27,7 @@ const WrappedGPUArray{T,N} = WrappedArray{T,N,AbstractGPUArray,AbstractGPUArray{
2727
const AnyGPUArray{T,N} = Union{AbstractGPUArray{T,N}, WrappedGPUArray{T,N}}
2828
const AnyGPUVector{T} = AnyGPUArray{T, 1}
2929
const AnyGPUMatrix{T} = AnyGPUArray{T, 2}
30+
const AnyGPUVecOrMat{T} = Union{AnyGPUArray{T, 1}, AnyGPUArray{T, 2}}
3031

3132

3233
## broadcasting

src/host/linalg.jl

+20-2
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ for T in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriang
170170
@eval Base.copyto!(A::$T{T, <:AbstractGPUArray{T,N}}, B::$T{T, <:AbstractGPUArray{T,N}}) where {T,N} = $T(copyto!(parent(A), parent(B)))
171171
end
172172

173-
function LinearAlgebra.tril!(A::AbstractGPUMatrix{T}, d::Integer = 0) where T
173+
function LinearAlgebra.tril!(A::AnyGPUMatrix{T}, d::Integer = 0) where T
174174
gpu_call(A, d; name="tril!") do ctx, _A, _d
175175
I = @cartesianidx _A
176176
i, j = Tuple(I)
@@ -182,7 +182,7 @@ function LinearAlgebra.tril!(A::AbstractGPUMatrix{T}, d::Integer = 0) where T
182182
return A
183183
end
184184

185-
function LinearAlgebra.triu!(A::AbstractGPUMatrix{T}, d::Integer = 0) where T
185+
function LinearAlgebra.triu!(A::AnyGPUMatrix{T}, d::Integer = 0) where T
186186
gpu_call(A, d; name="triu!") do ctx, _A, _d
187187
I = @cartesianidx _A
188188
i, j = Tuple(I)
@@ -795,3 +795,21 @@ function Base.isone(x::AbstractGPUMatrix{T}) where {T}
795795

796796
Array(y)[]
797797
end
798+
799+
## QR
800+
801+
import LinearAlgebra: QRPackedQ
802+
803+
function LinearAlgebra.getproperty(F::QR{T,<:AnyGPUMatrix{T}}, d::Symbol) where {T}
804+
m, n = size(F)
805+
if d === :R
806+
return triu!(view(getfield(F, :factors), 1:min(m,n), 1:n))
807+
elseif d === :Q
808+
return LinearAlgebra.QRPackedQ(getfield(F, :factors), F.τ)
809+
else
810+
getfield(F, d)
811+
end
812+
end
813+
814+
Base.print_array(io::IO, Q::QRPackedQ{T,<:AnyGPUMatrix{T},<:AnyGPUMatrix{T}}) where {T} =
815+
Base.print_array(io, collect(adapt(ToArray(), Q)))

test/testsuite/linalg.jl

+7
Original file line numberDiff line numberDiff line change
@@ -378,3 +378,10 @@ end
378378
@test isrealfloattype(typeof(opnorm(AT(mat), p)))
379379
end
380380
end
381+
382+
@testsuite "QR" (AT, eltypes)->begin
383+
@testset "get property" for dims in [(3,5),(3,3),(5,3)],
384+
prop in [:Q, :R], T in eltypes
385+
@test compare(x -> getproperty(qr(x), prop), AT, rand(T, dims))
386+
end
387+
end

0 commit comments

Comments
 (0)