This issue is mostly to document that there is a potential speedup to be gained in e.g. 3x3 matrix multiply (and possibly other operations where the matrix size is "odd").
Current StaticArrays.jl:
julia> for n in (2,3,4)
           s = Ref(rand(SMatrix{n,n}))
           @btime $(s)[] * $(s)[]
       end
2.709 ns (0 allocations: 0 bytes)
10.274 ns (0 allocations: 0 bytes)
6.059 ns (0 allocations: 0 bytes)
3x3 is quite slow (the Ref shenanigans are there to prevent spurious benchmark optimizations).
Handcoding something with SIMD.jl:
using StaticArrays
import SIMD
const SVec{N, T} = SIMD.Vec{N, T}
# load a given range of linear indices into an SVec
@generated function tosimd(D::NTuple{N, T}, ::Type{Val{strt}}, ::Type{Val{stp}}) where {N, T, strt, stp}
    expr = Expr(:tuple, [:(D[$i]) for i in strt:stp]...)
    M = length(expr.args)
    return quote
        $(Expr(:meta, :inline))
        @inbounds return SVec{$M, T}($expr)
    end
end
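# Note (my annotation, not part of the original): for strt = 1, stp = 3 the
# generated body is roughly `@inbounds return SVec{3, T}((D[1], D[2], D[3]))`,
# i.e. three contiguous tuple elements packed into a single SIMD vector.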
# construct an SMatrix from several SVecs
@generated function (::Type{SMatrix{dim, dim}})(r::NTuple{M, SVec{N}}) where {dim, M, N}
    return quote
        $(Expr(:meta, :inline))
        @inbounds return SMatrix{$dim, $dim}($(Expr(:tuple, [:(r[$j][$i]) for i in 1:N, j in 1:M]...)))
    end
end
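# Note (my annotation): for dim = 3 with three SVec{3} inputs the tuple above
# expands column-major, roughly
#     SMatrix{3, 3}((r[1][1], r[1][2], r[1][3], r[2][1], ..., r[3][3]))
# so each SVec ends up as one column of the result.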
function matmul3x3(a::SMatrix, b::SMatrix)
    D1 = a.data; D2 = b.data
    SV11 = tosimd(D1, Val{1}, Val{3})
    SV12 = tosimd(D1, Val{4}, Val{6})
    SV13 = tosimd(D1, Val{7}, Val{9})
    r1 = muladd(SV13, D2[3], muladd(SV12, D2[2], SV11 * D2[1]))
    r2 = muladd(SV13, D2[6], muladd(SV12, D2[5], SV11 * D2[4]))
    r3 = muladd(SV13, D2[9], muladd(SV12, D2[8], SV11 * D2[7]))
    return SMatrix{3,3}((r1, r2, r3))
end
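For clarity, the kernel above is just the column-wise formulation of the product: column j of a*b is b[1,j]*a[:,1] + b[2,j]*a[:,2] + b[3,j]*a[:,3], with each column of a kept in a single SIMD register. A plain StaticArrays sketch of the same idea (the name matmul3x3_cols and the broadcasting formulation are mine, added only for illustration; it leaves vectorization to the compiler rather than using SIMD.jl) looks like:

# Same column-combination formulation without SIMD.jl, for reference only.
function matmul3x3_cols(a::SMatrix{3,3}, b::SMatrix{3,3})
    c1, c2, c3 = a[:, 1], a[:, 2], a[:, 3]   # columns of a as SVector{3}s
    r1 = muladd.(c3, b[3, 1], muladd.(c2, b[2, 1], c1 .* b[1, 1]))
    r2 = muladd.(c3, b[3, 2], muladd.(c2, b[2, 2], c1 .* b[1, 2]))
    r3 = muladd.(c3, b[3, 3], muladd.(c2, b[2, 3], c1 .* b[1, 3]))
    return hcat(r1, r2, r3)                  # reassemble the 3x3 result
end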
julia> s = rand(SMatrix{3,3});

julia> @btime matmul3x3($(Ref(s))[], $(Ref(s))[])
4.391 ns (0 allocations: 0 bytes)
julia> matmul3x3(s,s) - s*s
3×3 SArray{Tuple{3,3},Float64,2,9}:
0.0 0.0 0.0
0.0 0.0 -1.11022e-16
0.0 0.0 0.0
which is a ~2.3x speedup over the current 3x3 kernel (the ~1e-16 entry is just a rounding difference from the fused multiply-adds).
The native code for the handcoded version is
julia> @code_native matmul3x3(s,s)
.section __TEXT,__text,regular,pure_instructions
vmovupd (%rsi), %xmm0
vmovsd 16(%rsi), %xmm1 ## xmm1 = mem[0],zero
vinsertf128 $1, %xmm1, %ymm0, %ymm0
vmovupd 24(%rsi), %xmm1
vmovsd 40(%rsi), %xmm2 ## xmm2 = mem[0],zero
vinsertf128 $1, %xmm2, %ymm1, %ymm1
vmovupd 48(%rsi), %xmm2
vmovsd 64(%rsi), %xmm3 ## xmm3 = mem[0],zero
vinsertf128 $1, %xmm3, %ymm2, %ymm2
vbroadcastsd (%rdx), %ymm3
vmulpd %ymm3, %ymm0, %ymm3
vbroadcastsd 8(%rdx), %ymm4
vfmadd213pd %ymm3, %ymm1, %ymm4
vbroadcastsd 16(%rdx), %ymm3
vfmadd213pd %ymm4, %ymm2, %ymm3
vbroadcastsd 24(%rdx), %ymm4
vmulpd %ymm4, %ymm0, %ymm4
vbroadcastsd 32(%rdx), %ymm5
vfmadd213pd %ymm4, %ymm1, %ymm5
vbroadcastsd 40(%rdx), %ymm4
vfmadd213pd %ymm5, %ymm2, %ymm4
vbroadcastsd 48(%rdx), %ymm5
vmulpd %ymm5, %ymm0, %ymm0
vbroadcastsd 56(%rdx), %ymm5
vfmadd213pd %ymm0, %ymm1, %ymm5
vbroadcastsd 64(%rdx), %ymm0
vfmadd213pd %ymm5, %ymm2, %ymm0
vbroadcastsd %xmm4, %ymm1
vblendpd $8, %ymm1, %ymm3, %ymm1 ## ymm1 = ymm3[0,1,2],ymm1[3]
vmovupd %ymm1, (%rdi)
vextractf128 $1, %ymm0, %xmm1
vinsertf128 $1, %xmm0, %ymm0, %ymm0
vpermpd $233, %ymm4, %ymm2 ## ymm2 = ymm4[1,2,2,3]
vblendpd $12, %ymm0, %ymm2, %ymm0 ## ymm0 = ymm2[0,1],ymm0[2,3]
vmovupd %ymm0, 32(%rdi)
vmovlpd %xmm1, 64(%rdi)
movq %rdi, %rax
vzeroupper
retq
nopw %cs:(%rax,%rax)
;}
while for the standard * it is
julia> @code_native s*s
.section __TEXT,__text,regular,pure_instructions
vmovsd (%rdx), %xmm0 ## xmm0 = mem[0],zero
vmovsd 8(%rdx), %xmm7 ## xmm7 = mem[0],zero
vmovsd 16(%rdx), %xmm6 ## xmm6 = mem[0],zero
vmovsd 16(%rsi), %xmm11 ## xmm11 = mem[0],zero
vmovsd 40(%rsi), %xmm12 ## xmm12 = mem[0],zero
vmovsd 64(%rsi), %xmm9 ## xmm9 = mem[0],zero
vmovsd 24(%rdx), %xmm3 ## xmm3 = mem[0],zero
vmovupd (%rsi), %xmm4
vmovupd 8(%rsi), %xmm10
vmovhpd (%rsi), %xmm11, %xmm5 ## xmm5 = xmm11[0],mem[0]
vinsertf128 $1, %xmm5, %ymm4, %ymm5
vunpcklpd %xmm3, %xmm0, %xmm1 ## xmm1 = xmm0[0],xmm3[0]
vmovddup %xmm0, %xmm0 ## xmm0 = xmm0[0,0]
vinsertf128 $1, %xmm1, %ymm0, %ymm0
vmulpd %ymm0, %ymm5, %ymm0
vmovsd 32(%rdx), %xmm5 ## xmm5 = mem[0],zero
vmovupd 24(%rsi), %xmm8
vmovhpd 24(%rsi), %xmm12, %xmm1 ## xmm1 = xmm12[0],mem[0]
vinsertf128 $1, %xmm1, %ymm8, %ymm1
vunpcklpd %xmm5, %xmm7, %xmm2 ## xmm2 = xmm7[0],xmm5[0]
vmovddup %xmm7, %xmm7 ## xmm7 = xmm7[0,0]
vinsertf128 $1, %xmm2, %ymm7, %ymm2
vmulpd %ymm2, %ymm1, %ymm1
vaddpd %ymm1, %ymm0, %ymm1
vmovsd 40(%rdx), %xmm0 ## xmm0 = mem[0],zero
vmovhpd 48(%rsi), %xmm9, %xmm2 ## xmm2 = xmm9[0],mem[0]
vmovupd 48(%rsi), %xmm13
vinsertf128 $1, %xmm2, %ymm13, %ymm2
vunpcklpd %xmm0, %xmm6, %xmm7 ## xmm7 = xmm6[0],xmm0[0]
vmovddup %xmm6, %xmm6 ## xmm6 = xmm6[0,0]
vinsertf128 $1, %xmm7, %ymm6, %ymm6
vmulpd %ymm6, %ymm2, %ymm2
vaddpd %ymm2, %ymm1, %ymm14
vmovsd 48(%rdx), %xmm1 ## xmm1 = mem[0],zero
vmovsd 56(%rdx), %xmm2 ## xmm2 = mem[0],zero
vmovsd 64(%rdx), %xmm6 ## xmm6 = mem[0],zero
vinsertf128 $1, %xmm4, %ymm10, %ymm4
vmovddup %xmm3, %xmm3 ## xmm3 = xmm3[0,0]
vmovddup %xmm1, %xmm7 ## xmm7 = xmm1[0,0]
vinsertf128 $1, %xmm7, %ymm3, %ymm3
vmulpd %ymm3, %ymm4, %ymm3
vmovupd 32(%rsi), %xmm4
vinsertf128 $1, %xmm8, %ymm4, %ymm4
vmovddup %xmm5, %xmm5 ## xmm5 = xmm5[0,0]
vmovddup %xmm2, %xmm7 ## xmm7 = xmm2[0,0]
vinsertf128 $1, %xmm7, %ymm5, %ymm5
vmulpd %ymm5, %ymm4, %ymm4
vaddpd %ymm4, %ymm3, %ymm3
vmovupd 56(%rsi), %xmm4
vinsertf128 $1, %xmm13, %ymm4, %ymm4
vmovddup %xmm0, %xmm0 ## xmm0 = xmm0[0,0]
vmovddup %xmm6, %xmm5 ## xmm5 = xmm6[0,0]
vinsertf128 $1, %xmm5, %ymm0, %ymm0
vmulpd %ymm0, %ymm4, %ymm0
vaddpd %ymm0, %ymm3, %ymm0
vmulsd %xmm1, %xmm11, %xmm1
vmulsd %xmm2, %xmm12, %xmm2
vaddsd %xmm2, %xmm1, %xmm1
vmulsd %xmm6, %xmm9, %xmm2
vaddsd %xmm2, %xmm1, %xmm1
vmovupd %ymm14, (%rdi)
vmovupd %ymm0, 32(%rdi)
vmovsd %xmm1, 64(%rdi)
movq %rdi, %rax
vzeroupper
retq
nop
;}
Note how heavily the xmm registers are used here compared to the SIMD version. Comparing to the 4x4 case
julia> s = rand(SMatrix{4,4})
julia> @code_native s*s
.section __TEXT,__text,regular,pure_instructions
vbroadcastsd (%rdx), %ymm4
vmovupd (%rsi), %ymm3
vmovupd 32(%rsi), %ymm2
vmovupd 64(%rsi), %ymm1
vmovupd 96(%rsi), %ymm0
vmulpd %ymm4, %ymm3, %ymm4
vbroadcastsd 8(%rdx), %ymm5
vmulpd %ymm5, %ymm2, %ymm5
vaddpd %ymm5, %ymm4, %ymm4
...
vbroadcastsd 96(%rdx), %ymm7
vmulpd %ymm7, %ymm3, %ymm3
vbroadcastsd 104(%rdx), %ymm7
vmulpd %ymm7, %ymm2, %ymm2
vaddpd %ymm2, %ymm3, %ymm2
vbroadcastsd 112(%rdx), %ymm3
vmulpd %ymm3, %ymm1, %ymm1
vaddpd %ymm1, %ymm2, %ymm1
vbroadcastsd 120(%rdx), %ymm2
vmulpd %ymm2, %ymm0, %ymm0
vaddpd %ymm0, %ymm1, %ymm0
vmovupd %ymm4, (%rdi)
vmovupd %ymm5, 32(%rdi)
vmovupd %ymm6, 64(%rdi)
vmovupd %ymm0, 96(%rdi)
movq %rdi, %rax
vzeroupper
retq
nopl (%rax)
we can see that everything now uses the ymm registers, which is what we want.