Skip to content

Commit c64c328

Browse files
authored
Size check in 2-argument mul (#1315)
This PR adds a size check in the 2-argument `mul`, so that now the destination array is allocated only if the sizes of the arguments are compatible with matrix multiplication. This means that we don't allocate in case of an error anymore. The performance for small-matrix multiplication seems largely similar (it's comparable to #1310, and seems identical within the noise limit): ```julia julia> A = [1 2; 3 4]; julia> @Btime $A * $A; 42.304 ns (2 allocations: 112 bytes) # before this PR 44.203 ns (2 allocations: 112 bytes) # this PR ``` We also redirect the generic `mul` to `_mul` now, which is the function that defines the multiplication code. This allows us to reuse the `_mul` definition elsewhere without having to repeat code. Currently, this is mainly necessary in the `Bidiagonal`-triangular multiplications.
2 parents 39ee3af + f2c5004 commit c64c328

File tree

3 files changed

+12
-9
lines changed

3 files changed

+12
-9
lines changed

src/bidiag.jl

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1261,26 +1261,22 @@ function _dibimul_nonzeroalpha!(C::Bidiagonal, A::Diagonal, B::Bidiagonal, _add)
12611261
end
12621262

12631263
function mul(A::UpperOrUnitUpperTriangular, B::Bidiagonal)
1264-
TS = promote_op(matprod, eltype(A), eltype(B))
1265-
C = mul!(similar(A, TS, size(A)), A, B)
1264+
C = _mul(A, B)
12661265
return B.uplo == 'U' ? UpperTriangular(C) : C
12671266
end
12681267

12691268
function mul(A::LowerOrUnitLowerTriangular, B::Bidiagonal)
1270-
TS = promote_op(matprod, eltype(A), eltype(B))
1271-
C = mul!(similar(A, TS, size(A)), A, B)
1269+
C = _mul(A, B)
12721270
return B.uplo == 'L' ? LowerTriangular(C) : C
12731271
end
12741272

12751273
function mul(A::Bidiagonal, B::UpperOrUnitUpperTriangular)
1276-
TS = promote_op(matprod, eltype(A), eltype(B))
1277-
C = mul!(similar(B, TS, size(B)), A, B)
1274+
C = _mul(A, B)
12781275
return A.uplo == 'U' ? UpperTriangular(C) : C
12791276
end
12801277

12811278
function mul(A::Bidiagonal, B::LowerOrUnitLowerTriangular)
1282-
TS = promote_op(matprod, eltype(A), eltype(B))
1283-
C = mul!(similar(B, TS, size(B)), A, B)
1279+
C = _mul(A, B)
12841280
return A.uplo == 'L' ? LowerTriangular(C) : C
12851281
end
12861282

src/matmul.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,13 @@ end
5151

5252
# Matrix-vector multiplication
5353
function (*)(A::StridedMaybeAdjOrTransMat{T}, x::StridedVector{S}) where {T<:BlasFloat,S<:Real}
54+
matmul_size_check(size(A), size(x))
5455
TS = promote_op(matprod, T, S)
5556
y = isconcretetype(TS) ? convert(AbstractVector{TS}, x) : x
5657
mul!(similar(x, TS, size(A,1)), A, y)
5758
end
5859
function (*)(A::AbstractMatrix{T}, x::AbstractVector{S}) where {T,S}
60+
matmul_size_check(size(A), size(x))
5961
TS = promote_op(matprod, T, S)
6062
mul!(similar(x, TS, axes(A,1)), A, x)
6163
end
@@ -113,7 +115,10 @@ julia> [1 1; 0 1] * [1 0; 1 1]
113115
"""
114116
(*)(A::AbstractMatrix, B::AbstractMatrix) = mul(A, B)
115117
# we add an extra level of indirection to avoid ambiguities in *
116-
function mul(A::AbstractMatrix, B::AbstractMatrix)
118+
# We also define the core functionality within _mul to reuse the code elsewhere
119+
mul(A::AbstractMatrix, B::AbstractMatrix) = _mul(A, B)
120+
function _mul(A::AbstractMatrix, B::AbstractMatrix)
121+
matmul_size_check(size(A), size(B))
117122
TS = promote_op(matprod, eltype(A), eltype(B))
118123
mul!(matprod_dest(A, B, TS), A, B)
119124
end

src/symmetric.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -710,12 +710,14 @@ end
710710
mul(A::HermOrSym, B::HermOrSym) = A * copyto!(similar(parent(B)), B)
711711
# catch a few potential BLAS-cases
712712
function mul(A::HermOrSym{<:BlasFloat,<:StridedMatrix}, B::AdjOrTrans{<:BlasFloat,<:StridedMatrix})
713+
matmul_size_check(size(A), size(B))
713714
T = promote_type(eltype(A), eltype(B))
714715
mul!(similar(B, T, (size(A, 1), size(B, 2))),
715716
convert(AbstractMatrix{T}, A),
716717
copy_oftype(B, T)) # make sure the AdjOrTrans wrapper is resolved
717718
end
718719
function mul(A::AdjOrTrans{<:BlasFloat,<:StridedMatrix}, B::HermOrSym{<:BlasFloat,<:StridedMatrix})
720+
matmul_size_check(size(A), size(B))
719721
T = promote_type(eltype(A), eltype(B))
720722
mul!(similar(B, T, (size(A, 1), size(B, 2))),
721723
copy_oftype(A, T), # make sure the AdjOrTrans wrapper is resolved

0 commit comments

Comments
 (0)