Skip to content

Commit f977faa

Browse files
committed
narrow array conversions. fixes #26294, fixes #26178.
Not all array types can convert from any AbstractArray via a 1-argument constructor call.
1 parent 73313f2 commit f977faa

File tree

12 files changed

+36
-3
lines changed

12 files changed

+36
-3
lines changed

base/abstractarray.jl

+2-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@ Supertype for `N`-dimensional arrays (or array-like types) with elements of type
1212
AbstractArray
1313

1414
convert(::Type{T}, a::T) where {T<:AbstractArray} = a
15-
convert(::Type{T}, a::AbstractArray) where {T<:AbstractArray} = T(a)
15+
convert(::Type{AbstractArray{T}}, a::AbstractArray) where {T} = AbstractArray{T}(a)
16+
convert(::Type{AbstractArray{T,N}}, a::AbstractArray{<:Any,N}) where {T,N} = AbstractArray{T,N}(a)
1617

1718
if nameof(@__MODULE__) === :Base # avoid method overwrite
1819
# catch undefined constructors before the deprecation kicks in

base/array.jl

+1-2
Original file line numberDiff line numberDiff line change
@@ -413,8 +413,7 @@ oneunit(x::AbstractMatrix{T}) where {T} = _one(oneunit(T), x)
413413

414414
## Conversions ##
415415

416-
# arises in similar(dest, Pair{Union{},Union{}}) where dest::Dict:
417-
convert(::Type{Vector{Union{}}}, a::Vector{Union{}}) = a
416+
convert(::Type{T}, a::AbstractArray) where {T<:Array} = a isa T ? a : T(a)
418417

419418
promote_rule(a::Type{Array{T,n}}, b::Type{Array{S,n}}) where {T,n,S} = el_same(promote_type(T,S), a, b)
420419

base/bitarray.jl

+2
Original file line numberDiff line numberDiff line change
@@ -532,6 +532,8 @@ julia> BitArray(x+y == 3 for x = 1:2 for y = 1:3)
532532
"""
533533
BitArray(itr) = gen_bitarray(IteratorSize(itr), itr)
534534

535+
convert(T::Type{<:BitArray}, a::AbstractArray) = a isa T ? a : T(a)
536+
535537
# generic constructor from an iterable without compile-time info
536538
# (we pass start(itr) explicitly to avoid a type-instability with filters)
537539
gen_bitarray(isz::IteratorSize, itr) = gen_bitarray_from_itr(itr, start(itr))

base/range.jl

+2
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,8 @@ abstract type AbstractRange{T} <: AbstractArray{T,1} end
9797
RangeStepStyle(::Type{<:AbstractRange}) = RangeStepIrregular()
9898
RangeStepStyle(::Type{<:AbstractRange{<:Integer}}) = RangeStepRegular()
9999

100+
convert(::Type{T}, r::AbstractRange) where {T<:AbstractRange} = r isa T ? r : T(r)
101+
100102
## ordinal ranges
101103

102104
abstract type OrdinalRange{T,S} <: AbstractRange{T} end

stdlib/LinearAlgebra/src/bidiag.jl

+2
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,8 @@ Bidiagonal{T}(A::Bidiagonal) where {T} =
172172
# When asked to convert Bidiagonal to AbstractMatrix{T}, preserve structure by converting to Bidiagonal{T} <: AbstractMatrix{T}
173173
AbstractMatrix{T}(A::Bidiagonal) where {T} = convert(Bidiagonal{T}, A)
174174

175+
convert(T::Type{<:Bidiagonal}, m::AbstractMatrix) = m isa T ? m : T(m)
176+
175177
broadcast(::typeof(big), B::Bidiagonal) = Bidiagonal(big.(B.dv), big.(B.ev), B.uplo)
176178

177179
# For B<:Bidiagonal, similar(B[, neweltype]) should yield a Bidiagonal matrix.

stdlib/LinearAlgebra/src/special.jl

+9
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,15 @@ Tridiagonal(A::AbstractTriangular) =
5858
throw(ArgumentError("matrix cannot be represented as Tridiagonal"))
5959

6060

61+
const ConvertibleSpecialMatrix = Union{Diagonal,Bidiagonal,SymTridiagonal,Tridiagonal,AbstractTriangular}
62+
63+
convert(T::Type{<:Diagonal}, m::ConvertibleSpecialMatrix) = m isa T ? m : T(m)
64+
convert(T::Type{<:SymTridiagonal}, m::ConvertibleSpecialMatrix) = m isa T ? m : T(m)
65+
convert(T::Type{<:Tridiagonal}, m::ConvertibleSpecialMatrix) = m isa T ? m : T(m)
66+
67+
convert(T::Type{<:LowerTriangular}, m::Union{LowerTriangular,UnitLowerTriangular}) = m isa T ? m : T(m)
68+
convert(T::Type{<:UpperTriangular}, m::Union{UpperTriangular,UnitUpperTriangular}) = m isa T ? m : T(m)
69+
6170
# Constructs two method definitions taking into account (assumed) commutativity
6271
# e.g. @commutative f(x::S, y::T) where {S,T} = x+y is the same is defining
6372
# f(x::S, y::T) where {S,T} = x+y

stdlib/LinearAlgebra/src/symmetric.jl

+3
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,9 @@ for (S, H) in ((:Symmetric, :Hermitian), (:Hermitian, :Symmetric))
163163
end
164164
end
165165

166+
convert(T::Type{<:Symmetric}, m::Union{Symmetric,Hermitian}) = m isa T ? m : T(m)
167+
convert(T::Type{<:Hermitian}, m::Union{Symmetric,Hermitian}) = m isa T ? m : T(m)
168+
166169
const HermOrSym{T,S} = Union{Hermitian{T,S}, Symmetric{T,S}}
167170
const RealHermSymComplexHerm{T<:Real,S} = Union{Hermitian{T,S}, Symmetric{T,S}, Hermitian{Complex{T},S}}
168171
const RealHermSymComplexSym{T<:Real,S} = Union{Hermitian{T,S}, Symmetric{T,S}, Symmetric{Complex{T},S}}

stdlib/LinearAlgebra/test/diagonal.jl

+2
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ srand(1)
2727
@test Diagonal{elty}(x)::Diagonal{elty,typeof(x)} == DM
2828
@test Diagonal{elty}(x).diag === x
2929
end
30+
# issue #26178
31+
@test_throws MethodError convert(Diagonal, [1, 2, 3, 4])
3032
end
3133

3234
@testset "Basic properties" begin

stdlib/SharedArrays/src/SharedArrays.jl

+2
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,8 @@ function SharedArray{TS,N}(A::Array{TA,N}) where {TS,TA,N}
358358
copyto!(S, A)
359359
end
360360

361+
convert(T::Type{<:SharedArray}, a::Array) = T(a)
362+
361363
function deepcopy_internal(S::SharedArray, stackdict::IdDict)
362364
haskey(stackdict, S) && return stackdict[S]
363365
R = SharedArray{eltype(S),ndims(S)}(size(S); pids = S.pids)

stdlib/SparseArrays/src/sparsematrix.jl

+2
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,8 @@ function Matrix(S::SparseMatrixCSC{Tv}) where Tv
412412
end
413413
Array(S::SparseMatrixCSC) = Matrix(S)
414414

415+
convert(T::Type{<:SparseMatrixCSC}, m::AbstractMatrix) = m isa T ? m : T(m)
416+
415417
float(S::SparseMatrixCSC) = SparseMatrixCSC(S.m, S.n, copy(S.colptr), copy(S.rowval), float.(S.nzval))
416418
complex(S::SparseMatrixCSC) = SparseMatrixCSC(S.m, S.n, copy(S.colptr), copy(S.rowval), complex(copy(S.nzval)))
417419

stdlib/SparseArrays/src/sparsevector.jl

+4
Original file line numberDiff line numberDiff line change
@@ -420,6 +420,10 @@ SparseVector{Tv,Ti}(s::SparseVector) where {Tv,Ti} =
420420
SparseVector{Tv}(s::SparseVector{<:Any,Ti}) where {Tv,Ti} =
421421
SparseVector{Tv,Ti}(s.n, s.nzind, convert(Vector{Tv}, s.nzval))
422422

423+
convert(T::Type{<:SparseVector}, m::AbstractVector) = m isa T ? m : T(m)
424+
425+
convert(T::Type{<:SparseVector}, m::SparseMatrixCSC) = T(m)
426+
convert(T::Type{<:SparseMatrixCSC}, v::SparseVector) = T(v)
423427

424428
### copying
425429
function prep_sparsevec_copy_dest!(A::SparseVector, lB, nnzB)

stdlib/SuiteSparse/src/cholmod.jl

+5
Original file line numberDiff line numberDiff line change
@@ -1082,6 +1082,7 @@ function SparseMatrixCSC{Tv,SuiteSparse_long}(A::Sparse{Tv}) where Tv
10821082
return B
10831083
end
10841084
end
1085+
10851086
function (::Type{Symmetric{Float64,SparseMatrixCSC{Float64,SuiteSparse_long}}})(A::Sparse{Float64})
10861087
s = unsafe_load(pointer(A))
10871088
if !issymmetric(A)
@@ -1099,6 +1100,8 @@ function (::Type{Symmetric{Float64,SparseMatrixCSC{Float64,SuiteSparse_long}}})(
10991100
return B
11001101
end
11011102
end
1103+
convert(T::Type{Symmetric{Float64,SparseMatrixCSC{Float64,SuiteSparse_long}}}, A::Sparse{Float64}) = T(A)
1104+
11021105
function Hermitian{Tv,SparseMatrixCSC{Tv,SuiteSparse_long}}(A::Sparse{Tv}) where Tv<:VTypes
11031106
s = unsafe_load(pointer(A))
11041107
if !ishermitian(A)
@@ -1116,6 +1119,8 @@ function Hermitian{Tv,SparseMatrixCSC{Tv,SuiteSparse_long}}(A::Sparse{Tv}) where
11161119
return B
11171120
end
11181121
end
1122+
convert(T::Type{Hermitian{Tv,SparseMatrixCSC{Tv,SuiteSparse_long}}}, A::Sparse{Tv}) where {Tv<:VTypes} = T(A)
1123+
11191124
function sparse(A::Sparse{Float64}) # Notice! Cannot be type stable because of stype
11201125
s = unsafe_load(pointer(A))
11211126
if s.stype == 0

0 commit comments

Comments
 (0)