From f56fd4f63dab04cdc4a5b8a8cc7770dc23dda0c3 Mon Sep 17 00:00:00 2001 From: Andy Ferris Date: Fri, 5 May 2017 15:35:10 +1000 Subject: [PATCH] Capture complex outer product Fixes #156 --- src/matrix_multiply.jl | 14 ++++++++++++++ test/matrix_multiply.jl | 7 +++++++ 2 files changed, 21 insertions(+) diff --git a/src/matrix_multiply.jl b/src/matrix_multiply.jl index 11142eb1..1f5da1cf 100644 --- a/src/matrix_multiply.jl +++ b/src/matrix_multiply.jl @@ -36,11 +36,13 @@ promote_matprod{T1,T2}(::Type{T1}, ::Type{T2}) = typeof(zero(T1)*zero(T2) + zero @inline *(A::StaticMatrix, B::StaticMatrix) = _A_mul_B(Size(A), Size(B), A, B) @inline *(A::StaticVector, B::StaticMatrix) = *(reshape(A, Size(Size(A)[1], 1)), B) @inline *(A::StaticVector, B::RowVector{<:Any, <:StaticVector}) = _A_mul_B(Size(A), Size(B), A, B) +@inline *(A::StaticVector, B::RowVector{<:Any, CV} where CV <: ConjVector{<:Any, <:StaticVector}) = _A_mul_B(Size(A), Size(B), A, B) @inline A_mul_B!(dest::StaticVecOrMat, A::StaticMatrix, B::StaticVector) = _A_mul_B!(Size(dest), dest, Size(A), Size(B), A, B) @inline A_mul_B!(dest::StaticVecOrMat, A::StaticMatrix, B::StaticMatrix) = _A_mul_B!(Size(dest), dest, Size(A), Size(B), A, B) @inline A_mul_B!(dest::StaticVecOrMat, A::StaticVector, B::StaticMatrix) = A_mul_B!(dest, reshape(A, Size(Size(A)[1], 1)), B) @inline A_mul_B!(dest::StaticVecOrMat, A::StaticVector, B::RowVector{<:Any, <:StaticVector}) = _A_mul_B!(Size(dest), dest, Size(A), Size(B), A, B) +@inline A_mul_B!(dest::StaticVecOrMat, A::StaticVector, B::RowVector{<:Any, CV} where CV <: ConjVector{<:Any, <:StaticVector}) = _A_mul_B!(Size(dest), dest, Size(A), Size(B), A, B) #@inline *{TA<:Base.LinAlg.BlasFloat,Tb}(A::StaticMatrix{TA}, b::StaticVector{Tb}) @@ -93,6 +95,18 @@ end end end +# complex outer product +@generated function _A_mul_B(::Size{sa}, ::Size{sb}, a::StaticVector{<: Any, Ta}, b::RowVector{Tb, CV} where CV <: ConjVector{<:Any, <:StaticVector}) where {sa, sb, Ta, Tb} + newsize = (sa[1], sb[2]) + exprs = [:(a[$i]*b[$j]) for i = 1:sa[1], j = 1:sb[2]] + + return quote + @_inline_meta + T = promote_op(*, Ta, Tb) + @inbounds return similar_type(b, T, Size($newsize))(tuple($(exprs...))) + end +end + @generated function _A_mul_B(Sa::Size{sa}, Sb::Size{sb}, a::StaticMatrix{<:Any, <:Any, Ta}, b::StaticMatrix{<:Any, <:Any, Tb}) where {sa, sb, Ta, Tb} # Heuristic choice for amount of codegen if sa[1]*sa[2]*sb[2] <= 8*8*8 diff --git a/test/matrix_multiply.jl b/test/matrix_multiply.jl index d1ee29d7..a78fe993 100644 --- a/test/matrix_multiply.jl +++ b/test/matrix_multiply.jl @@ -44,6 +44,13 @@ m = @SMatrix [1 2 3 4] v = @SVector [1, 2] @test @inferred(v*m) === @SMatrix [1 2 3 4; 2 4 6 8] + + # Outer product + v2 = zeros(SVector{3, Int}) + @test v2 * v2' === zeros(SMatrix{3, 3, Int}) + + v3 = zeros(SVector{3, Complex{Int}}) + @test v3 * v3' === zeros(SMatrix{3, 3, Complex{Int}}) end @testset "Matrix-matrix" begin