3x3 matrix multiply could potentially be faster #512

Open
@KristofferC

Description

This issue is mostly to document that there is a potential speedup to be gained in e.g. 3x3 matrix multiply (and possibly other operations where the matrix size is "odd").

Current StaticArrays.jl:

julia> for n in (2,3,4)
           s = Ref(rand(SMatrix{n,n}))
           @btime $(s)[] * $(s)[]
       end
  2.709 ns (0 allocations: 0 bytes)
  10.274 ns (0 allocations: 0 bytes)
  6.059 ns (0 allocations: 0 bytes)

3x3 is quite slow (the Ref shenanigans are there to prevent spurious benchmark optimizations).
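
To make the Ref trick concrete, this is the failure mode it guards against (a minimal sketch; whether the constant version actually gets folded depends on the Julia version):

using BenchmarkTools, StaticArrays

s = rand(SMatrix{3,3})
@btime $s * $s            # the interpolated value is a compile-time
                          # constant, so the multiply may be folded away
r = Ref(s)
@btime $(r)[] * $(r)[]    # the load from the Ref forces a real multiply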

Handcoding something with SIMD.jl:

using StaticArrays
import SIMD
const SVec{N, T} = SIMD.Vec{N, T}

# load a given range of linear indices into an SVec
@generated function tosimd(D::NTuple{N, T}, ::Type{Val{strt}}, ::Type{Val{stp}}) where {N, T, strt, stp}
    expr = Expr(:tuple, [:(D[$i]) for i in strt:stp]...)
    M = length(expr.args)
    return quote
        $(Expr(:meta, :inline))
        @inbounds return SVec{$M, T}($expr)
    end
end
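
For example, applied to the flat tuple of a 3x3 matrix this pulls out one column at a time (a quick sanity check with illustrative values; dispatch is on the Val types so the range is known at compile time):

D = (1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0)
tosimd(D, Val{1}, Val{3})   # SVec{3,Float64} holding the first column (1.0, 2.0, 3.0)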

# construct an SMatrix from several SVecs
@generated function (::Type{SMatrix{dim, dim}})(r::NTuple{M, SVec{N}}) where {dim, M, N}
    return quote
        $(Expr(:meta, :inline))
        @inbounds return SMatrix{$dim, $dim}($(Expr(:tuple, [:(r[$j][$i]) for i in 1:N, j in 1:M]...)))
    end
end
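
Going the other way, the new constructor stitches column vectors back into a matrix (again illustrative values):

c1 = SVec{3, Float64}((1.0, 2.0, 3.0))
c2 = SVec{3, Float64}((4.0, 5.0, 6.0))
c3 = SVec{3, Float64}((7.0, 8.0, 9.0))
SMatrix{3,3}((c1, c2, c3))   # c1, c2, c3 become the columns of the result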


function matmul3x3(a::SMatrix, b::SMatrix)
    D1 = a.data; D2 = b.data
    # load the three columns of a as SIMD vectors
    SV11 = tosimd(D1, Val{1}, Val{3})
    SV12 = tosimd(D1, Val{4}, Val{6})
    SV13 = tosimd(D1, Val{7}, Val{9})
    # column j of the result is sum_k a_col_k * b[k, j], as a muladd chain
    r1 = muladd(SV13, D2[3], muladd(SV12, D2[2], SV11 * D2[1]))
    r2 = muladd(SV13, D2[6], muladd(SV12, D2[5], SV11 * D2[4]))
    r3 = muladd(SV13, D2[9], muladd(SV12, D2[8], SV11 * D2[7]))
    return SMatrix{3,3}((r1, r2, r3))
end
julia> s = rand(SMatrix{3,3});

julia> @btime matmul3x3($(Ref(s))[], $(Ref(s))[])
  4.391 ns (0 allocations: 0 bytes)

julia> matmul3x3(s,s) - s*s
3×3 SArray{Tuple{3,3},Float64,2,9}:
 0.0  0.0   0.0        
 0.0  0.0  -1.11022e-16
 0.0  0.0   0.0        

which is a ~2.3x speedup (10.274 ns / 4.391 ns ≈ 2.3).

The generated code for the handcoded version is

julia> @code_native matmul3x3(s,s)
	.section	__TEXT,__text,regular,pure_instructions
	vmovupd	(%rsi), %xmm0
	vmovsd	16(%rsi), %xmm1         ## xmm1 = mem[0],zero
	vinsertf128	$1, %xmm1, %ymm0, %ymm0
	vmovupd	24(%rsi), %xmm1
	vmovsd	40(%rsi), %xmm2         ## xmm2 = mem[0],zero
	vinsertf128	$1, %xmm2, %ymm1, %ymm1
	vmovupd	48(%rsi), %xmm2
	vmovsd	64(%rsi), %xmm3         ## xmm3 = mem[0],zero
	vinsertf128	$1, %xmm3, %ymm2, %ymm2
	vbroadcastsd	(%rdx), %ymm3
	vmulpd	%ymm3, %ymm0, %ymm3
	vbroadcastsd	8(%rdx), %ymm4
	vfmadd213pd	%ymm3, %ymm1, %ymm4
	vbroadcastsd	16(%rdx), %ymm3
	vfmadd213pd	%ymm4, %ymm2, %ymm3
	vbroadcastsd	24(%rdx), %ymm4
	vmulpd	%ymm4, %ymm0, %ymm4
	vbroadcastsd	32(%rdx), %ymm5
	vfmadd213pd	%ymm4, %ymm1, %ymm5
	vbroadcastsd	40(%rdx), %ymm4
	vfmadd213pd	%ymm5, %ymm2, %ymm4
	vbroadcastsd	48(%rdx), %ymm5
	vmulpd	%ymm5, %ymm0, %ymm0
	vbroadcastsd	56(%rdx), %ymm5
	vfmadd213pd	%ymm0, %ymm1, %ymm5
	vbroadcastsd	64(%rdx), %ymm0
	vfmadd213pd	%ymm5, %ymm2, %ymm0
	vbroadcastsd	%xmm4, %ymm1
	vblendpd	$8, %ymm1, %ymm3, %ymm1 ## ymm1 = ymm3[0,1,2],ymm1[3]
	vmovupd	%ymm1, (%rdi)
	vextractf128	$1, %ymm0, %xmm1
	vinsertf128	$1, %xmm0, %ymm0, %ymm0
	vpermpd	$233, %ymm4, %ymm2      ## ymm2 = ymm4[1,2,2,3]
	vblendpd	$12, %ymm0, %ymm2, %ymm0 ## ymm0 = ymm2[0,1],ymm0[2,3]
	vmovupd	%ymm0, 32(%rdi)
	vmovlpd	%xmm1, 64(%rdi)
	movq	%rdi, %rax
	vzeroupper
	retq
	nopw	%cs:(%rax,%rax)
;}

while for the standard * it is

julia> @code_native s*s
	.section	__TEXT,__text,regular,pure_instructions
	vmovsd	(%rdx), %xmm0           ## xmm0 = mem[0],zero
	vmovsd	8(%rdx), %xmm7          ## xmm7 = mem[0],zero
	vmovsd	16(%rdx), %xmm6         ## xmm6 = mem[0],zero
	vmovsd	16(%rsi), %xmm11        ## xmm11 = mem[0],zero
	vmovsd	40(%rsi), %xmm12        ## xmm12 = mem[0],zero
	vmovsd	64(%rsi), %xmm9         ## xmm9 = mem[0],zero
	vmovsd	24(%rdx), %xmm3         ## xmm3 = mem[0],zero
	vmovupd	(%rsi), %xmm4
	vmovupd	8(%rsi), %xmm10
	vmovhpd	(%rsi), %xmm11, %xmm5   ## xmm5 = xmm11[0],mem[0]
	vinsertf128	$1, %xmm5, %ymm4, %ymm5
	vunpcklpd	%xmm3, %xmm0, %xmm1 ## xmm1 = xmm0[0],xmm3[0]
	vmovddup	%xmm0, %xmm0    ## xmm0 = xmm0[0,0]
	vinsertf128	$1, %xmm1, %ymm0, %ymm0
	vmulpd	%ymm0, %ymm5, %ymm0
	vmovsd	32(%rdx), %xmm5         ## xmm5 = mem[0],zero
	vmovupd	24(%rsi), %xmm8
	vmovhpd	24(%rsi), %xmm12, %xmm1 ## xmm1 = xmm12[0],mem[0]
	vinsertf128	$1, %xmm1, %ymm8, %ymm1
	vunpcklpd	%xmm5, %xmm7, %xmm2 ## xmm2 = xmm7[0],xmm5[0]
	vmovddup	%xmm7, %xmm7    ## xmm7 = xmm7[0,0]
	vinsertf128	$1, %xmm2, %ymm7, %ymm2
	vmulpd	%ymm2, %ymm1, %ymm1
	vaddpd	%ymm1, %ymm0, %ymm1
	vmovsd	40(%rdx), %xmm0         ## xmm0 = mem[0],zero
	vmovhpd	48(%rsi), %xmm9, %xmm2  ## xmm2 = xmm9[0],mem[0]
	vmovupd	48(%rsi), %xmm13
	vinsertf128	$1, %xmm2, %ymm13, %ymm2
	vunpcklpd	%xmm0, %xmm6, %xmm7 ## xmm7 = xmm6[0],xmm0[0]
	vmovddup	%xmm6, %xmm6    ## xmm6 = xmm6[0,0]
	vinsertf128	$1, %xmm7, %ymm6, %ymm6
	vmulpd	%ymm6, %ymm2, %ymm2
	vaddpd	%ymm2, %ymm1, %ymm14
	vmovsd	48(%rdx), %xmm1         ## xmm1 = mem[0],zero
	vmovsd	56(%rdx), %xmm2         ## xmm2 = mem[0],zero
	vmovsd	64(%rdx), %xmm6         ## xmm6 = mem[0],zero
	vinsertf128	$1, %xmm4, %ymm10, %ymm4
	vmovddup	%xmm3, %xmm3    ## xmm3 = xmm3[0,0]
	vmovddup	%xmm1, %xmm7    ## xmm7 = xmm1[0,0]
	vinsertf128	$1, %xmm7, %ymm3, %ymm3
	vmulpd	%ymm3, %ymm4, %ymm3
	vmovupd	32(%rsi), %xmm4
	vinsertf128	$1, %xmm8, %ymm4, %ymm4
	vmovddup	%xmm5, %xmm5    ## xmm5 = xmm5[0,0]
	vmovddup	%xmm2, %xmm7    ## xmm7 = xmm2[0,0]
	vinsertf128	$1, %xmm7, %ymm5, %ymm5
	vmulpd	%ymm5, %ymm4, %ymm4
	vaddpd	%ymm4, %ymm3, %ymm3
	vmovupd	56(%rsi), %xmm4
	vinsertf128	$1, %xmm13, %ymm4, %ymm4
	vmovddup	%xmm0, %xmm0    ## xmm0 = xmm0[0,0]
	vmovddup	%xmm6, %xmm5    ## xmm5 = xmm6[0,0]
	vinsertf128	$1, %xmm5, %ymm0, %ymm0
	vmulpd	%ymm0, %ymm4, %ymm0
	vaddpd	%ymm0, %ymm3, %ymm0
	vmulsd	%xmm1, %xmm11, %xmm1
	vmulsd	%xmm2, %xmm12, %xmm2
	vaddsd	%xmm2, %xmm1, %xmm1
	vmulsd	%xmm6, %xmm9, %xmm2
	vaddsd	%xmm2, %xmm1, %xmm1
	vmovupd	%ymm14, (%rdi)
	vmovupd	%ymm0, 32(%rdi)
	vmovsd	%xmm1, 64(%rdi)
	movq	%rdi, %rax
	vzeroupper
	retq
	nop
;}

We can see how heavily the xmm registers are used here compared to the SIMD version. Compare to the 4x4 case:

julia> s = rand(SMatrix{4,4});

julia> @code_native s*s
	.section	__TEXT,__text,regular,pure_instructions
	vbroadcastsd	(%rdx), %ymm4
	vmovupd	(%rsi), %ymm3
	vmovupd	32(%rsi), %ymm2
	vmovupd	64(%rsi), %ymm1
	vmovupd	96(%rsi), %ymm0
	vmulpd	%ymm4, %ymm3, %ymm4
	vbroadcastsd	8(%rdx), %ymm5
	vmulpd	%ymm5, %ymm2, %ymm5
	vaddpd	%ymm5, %ymm4, %ymm4
        ...
	vbroadcastsd	96(%rdx), %ymm7
	vmulpd	%ymm7, %ymm3, %ymm3
	vbroadcastsd	104(%rdx), %ymm7
	vmulpd	%ymm7, %ymm2, %ymm2
	vaddpd	%ymm2, %ymm3, %ymm2
	vbroadcastsd	112(%rdx), %ymm3
	vmulpd	%ymm3, %ymm1, %ymm1
	vaddpd	%ymm1, %ymm2, %ymm1
	vbroadcastsd	120(%rdx), %ymm2
	vmulpd	%ymm2, %ymm0, %ymm0
	vaddpd	%ymm0, %ymm1, %ymm0
	vmovupd	%ymm4, (%rdi)
	vmovupd	%ymm5, 32(%rdi)
	vmovupd	%ymm6, 64(%rdi)
	vmovupd	%ymm0, 96(%rdi)
	movq	%rdi, %rax
	vzeroupper
	retq
	nopl	(%rax)

we can now see that everything uses the ymm registers, which is what we want. (An xmm register holds two Float64 values while a ymm register holds four, so code stuck in xmm does half the vector work per instruction.)
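
Since the same trick should apply to other "odd" sizes, here is a rough generalization of matmul3x3 built on the helpers above (matmul_simd is a hypothetical name; unrolled with a generated function, untested and not benchmarked):

# sketch: generalize the column-broadcast pattern to any dim×dim size
@generated function matmul_simd(a::SMatrix{dim,dim,T}, b::SMatrix{dim,dim,T}) where {dim, T}
    cols = [Symbol(:c, k) for k in 1:dim]
    # load each column of a into a SIMD vector
    loads = [:($(cols[k]) = tosimd(a.data, Val{$((k - 1) * dim + 1)}, Val{$(k * dim)})) for k in 1:dim]
    # column j of the result: sum_k a_col_k * b[k, j] as a muladd chain
    rescols = map(1:dim) do j
        ex = :($(cols[1]) * b.data[$((j - 1) * dim + 1)])
        for k in 2:dim
            ex = :(muladd($(cols[k]), b.data[$((j - 1) * dim + k)], $ex))
        end
        ex
    end
    return quote
        $(Expr(:meta, :inline))
        $(loads...)
        @inbounds return SMatrix{dim, dim}(($(rescols...),))
    end
end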
