diff --git a/src/linalg.jl b/src/linalg.jl index 0c540473..7a0a5064 100644 --- a/src/linalg.jl +++ b/src/linalg.jl @@ -103,7 +103,7 @@ end return quote @_inline_meta - @inbounds return similar_type(a, Size($Snew))(tuple($(exprs...))) + @inbounds return similar_type(a, promote_type(eltype(a), eltype(b)), Size($Snew))(tuple($(exprs...))) end end # TODO make these more efficient @@ -129,7 +129,7 @@ end return quote @_inline_meta - @inbounds return similar_type(a, Size($Snew))(tuple($(exprs...))) + @inbounds return similar_type(a, promote_type(eltype(a), eltype(b)), Size($Snew))(tuple($(exprs...))) end end # TODO make these more efficient @@ -297,8 +297,8 @@ end end end -# TODO same for `RowVector`? @inline Size(::Union{RowVector{T, SA}, Type{RowVector{T, SA}}}) where {T, SA <: StaticArray} = Size(1, Size(SA)[1]) +@inline Size(::Union{RowVector{T, CA}, Type{RowVector{T, CA}}} where CA <: ConjVector{<:Any, SA}) where {T, SA <: StaticArray} = Size(1, Size(SA)[1]) @inline Size(::Union{Symmetric{T,SA}, Type{Symmetric{T,SA}}}) where {T,SA<:StaticArray} = Size(SA) @inline Size(::Union{Hermitian{T,SA}, Type{Hermitian{T,SA}}}) where {T,SA<:StaticArray} = Size(SA) diff --git a/src/matrix_multiply.jl b/src/matrix_multiply.jl index 11142eb1..628fbe4f 100644 --- a/src/matrix_multiply.jl +++ b/src/matrix_multiply.jl @@ -36,11 +36,13 @@ promote_matprod{T1,T2}(::Type{T1}, ::Type{T2}) = typeof(zero(T1)*zero(T2) + zero @inline *(A::StaticMatrix, B::StaticMatrix) = _A_mul_B(Size(A), Size(B), A, B) @inline *(A::StaticVector, B::StaticMatrix) = *(reshape(A, Size(Size(A)[1], 1)), B) @inline *(A::StaticVector, B::RowVector{<:Any, <:StaticVector}) = _A_mul_B(Size(A), Size(B), A, B) +@inline *(A::StaticVector, B::RowVector{<:Any, <:ConjVector{<:Any, <:StaticVector}}) = _A_mul_B(Size(A), Size(B), A, B) @inline A_mul_B!(dest::StaticVecOrMat, A::StaticMatrix, B::StaticVector) = _A_mul_B!(Size(dest), dest, Size(A), Size(B), A, B) @inline A_mul_B!(dest::StaticVecOrMat, A::StaticMatrix, B::StaticMatrix) = _A_mul_B!(Size(dest), dest, Size(A), Size(B), A, B) @inline A_mul_B!(dest::StaticVecOrMat, A::StaticVector, B::StaticMatrix) = A_mul_B!(dest, reshape(A, Size(Size(A)[1], 1)), B) @inline A_mul_B!(dest::StaticVecOrMat, A::StaticVector, B::RowVector{<:Any, <:StaticVector}) = _A_mul_B!(Size(dest), dest, Size(A), Size(B), A, B) +@inline A_mul_B!(dest::StaticVecOrMat, A::StaticVector, B::RowVector{<:Any, <:ConjVector{<:Any, <:StaticVector}}) = _A_mul_B!(Size(dest), dest, Size(A), Size(B), A, B) #@inline *{TA<:Base.LinAlg.BlasFloat,Tb}(A::StaticMatrix{TA}, b::StaticVector{Tb}) @@ -93,6 +95,18 @@ end end end +# complex outer product +@generated function _A_mul_B(::Size{sa}, ::Size{sb}, a::StaticVector{<: Any, Ta}, b::RowVector{Tb, <:ConjVector{<:Any, <:StaticVector}}) where {sa, sb, Ta, Tb} + newsize = (sa[1], sb[2]) + exprs = [:(a[$i]*b[$j]) for i = 1:sa[1], j = 1:sb[2]] + + return quote + @_inline_meta + T = promote_op(*, Ta, Tb) + @inbounds return similar_type(b, T, Size($newsize))(tuple($(exprs...))) + end +end + @generated function _A_mul_B(Sa::Size{sa}, Sb::Size{sb}, a::StaticMatrix{<:Any, <:Any, Ta}, b::StaticMatrix{<:Any, <:Any, Tb}) where {sa, sb, Ta, Tb} # Heuristic choice for amount of codegen if sa[1]*sa[2]*sb[2] <= 8*8*8 diff --git a/test/linalg.jl b/test/linalg.jl index 84de59be..1763cde8 100644 --- a/test/linalg.jl +++ b/test/linalg.jl @@ -100,6 +100,9 @@ @test @inferred(vcat(SVector(1),SVector(2),SVector(3),SVector(4))) === SVector(1,2,3,4) @test @inferred(hcat(SVector(1),SVector(2),SVector(3),SVector(4))) === SMatrix{1,4}(1,2,3,4) + + vcat(SVector(1.0f0), SVector(1.0)) === SVector(1.0, 1.0) + hcat(SVector(1.0f0), SVector(1.0)) === SMatrix{1,2}(1.0, 1.0) end @testset "normalization" begin diff --git a/test/matrix_multiply.jl b/test/matrix_multiply.jl index d1ee29d7..78aca908 100644 --- a/test/matrix_multiply.jl +++ b/test/matrix_multiply.jl @@ -44,6 +44,15 @@ m = @SMatrix [1 2 3 4] v = @SVector [1, 2] @test @inferred(v*m) === @SMatrix [1 2 3 4; 2 4 6 8] + + # Outer product + v2 = SVector(1, 2) + v3 = SVector(3, 4) + @test v2 * v3' === @SMatrix [3 4; 6 8] + + v4 = SVector(1+0im, 2+0im) + v5 = SVector(3+0im, 4+0im) + @test v4 * v5' === @SMatrix [3+0im 4+0im; 6+0im 8+0im] end @testset "Matrix-matrix" begin