From ef1cca481527816a1c6f4f8b6108c5d657ed1f91 Mon Sep 17 00:00:00 2001
From: Alexander Lekomtsev <alelekomtsev@gmail.com>
Date: Wed, 8 Jan 2025 18:32:14 +0300
Subject: [PATCH] cmd/compile/internal/ssa: use sequence of shNadd instructions
 instead of mul

Multiplication by certain small constants can be strength-reduced to a short
sequence of shNadd (Zba) instructions, available when GORISCV64 >= rva22u64.
A sequence of three shNadd instructions provides no performance benefit over
mul, but a sequence of one or two shNadd instructions is still faster than a
single mul instruction.

Co-Authored-By: Andrei <therain.i@yahoo.com>
---
 .../compile/internal/ssa/_gen/RISCV64.rules   |  18 ++
 .../compile/internal/ssa/rewriteRISCV64.go    | 251 ++++++++++++++++++
 test/codegen/shift.go                         |  23 ++
 3 files changed, 292 insertions(+)

diff --git a/src/cmd/compile/internal/ssa/_gen/RISCV64.rules b/src/cmd/compile/internal/ssa/_gen/RISCV64.rules
index 9ae96043810cfd..8e75f1c648ee27 100644
--- a/src/cmd/compile/internal/ssa/_gen/RISCV64.rules
+++ b/src/cmd/compile/internal/ssa/_gen/RISCV64.rules
@@ -843,6 +843,24 @@
 (ADD (SLLI [2] x) y) && buildcfg.GORISCV64 >= 22 => (SH2ADD x y)
 (ADD (SLLI [3] x) y) && buildcfg.GORISCV64 >= 22 => (SH3ADD x y)
 
+// Strength-reduce multiplication by selected constants to shNadd (Zba) sequences.
+(MUL x (MOVDconst [3])) && buildcfg.GORISCV64 >= 22 => (SH1ADD x x)
+(MUL x (MOVDconst [5])) && buildcfg.GORISCV64 >= 22 => (SH2ADD x x)
+(MUL x (MOVDconst [9])) && buildcfg.GORISCV64 >= 22 => (SH3ADD x x)
+
+(MUL <t> x (MOVDconst [11])) && buildcfg.GORISCV64 >= 22 => (SH1ADD (SH2ADD <t> x x) x)
+(MUL <t> x (MOVDconst [13])) && buildcfg.GORISCV64 >= 22 => (SH2ADD (SH1ADD <t> x x) x)
+(MUL <t> x (MOVDconst [19])) && buildcfg.GORISCV64 >= 22 => (SH1ADD (SH3ADD <t> x x) x)
+(MUL <t> x (MOVDconst [21])) && buildcfg.GORISCV64 >= 22 => (SH2ADD (SH2ADD <t> x x) x)
+(MUL <t> x (MOVDconst [25])) && buildcfg.GORISCV64 >= 22 => (SH3ADD (SH1ADD <t> x x) x)
+(MUL <t> x (MOVDconst [27])) && buildcfg.GORISCV64 >= 22 => (SH1ADD (SH3ADD <t> x x) (SH3ADD <t> x x))
+(MUL <t> x (MOVDconst [37])) && buildcfg.GORISCV64 >= 22 => (SH2ADD (SH3ADD <t> x x) x)
+(MUL <t> x (MOVDconst [41])) && buildcfg.GORISCV64 >= 22 => (SH3ADD (SH2ADD <t> x x) x)
+(MUL <t> x (MOVDconst [45])) && buildcfg.GORISCV64 >= 22 => (SH2ADD (SH3ADD <t> x x) (SH3ADD <t> x x))
+(MUL <t> x (MOVDconst [73])) && buildcfg.GORISCV64 >= 22 => (SH3ADD (SH3ADD <t> x x) x)
+(MUL <t> x (MOVDconst [81])) && buildcfg.GORISCV64 >= 22 => (SH3ADD (SH3ADD <t> x x) (SH3ADD <t> x x))
+
+
 // Integer minimum and maximum.
 (Min64  x y) && buildcfg.GORISCV64 >= 22 => (MIN  x y)
 (Max64  x y) && buildcfg.GORISCV64 >= 22 => (MAX  x y)
diff --git a/src/cmd/compile/internal/ssa/rewriteRISCV64.go b/src/cmd/compile/internal/ssa/rewriteRISCV64.go
index aa44ab311e92af..e38e4063b96bc1 100644
--- a/src/cmd/compile/internal/ssa/rewriteRISCV64.go
+++ b/src/cmd/compile/internal/ssa/rewriteRISCV64.go
@@ -531,6 +531,8 @@ func rewriteValueRISCV64(v *Value) bool {
 		return rewriteValueRISCV64_OpRISCV64MOVWstore(v)
 	case OpRISCV64MOVWstorezero:
 		return rewriteValueRISCV64_OpRISCV64MOVWstorezero(v)
+	case OpRISCV64MUL:
+		return rewriteValueRISCV64_OpRISCV64MUL(v)
 	case OpRISCV64NEG:
 		return rewriteValueRISCV64_OpRISCV64NEG(v)
 	case OpRISCV64NEGW:
@@ -6024,6 +6026,255 @@ func rewriteValueRISCV64_OpRISCV64MOVWstorezero(v *Value) bool {
 	}
 	return false
 }
+func rewriteValueRISCV64_OpRISCV64MUL(v *Value) bool {
+	v_1 := v.Args[1]
+	v_0 := v.Args[0]
+	b := v.Block
+	// match: (MUL x (MOVDconst [3]))
+	// cond: buildcfg.GORISCV64 >= 22
+	// result: (SH1ADD x x)
+	for {
+		for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
+			x := v_0
+			if v_1.Op != OpRISCV64MOVDconst || auxIntToInt64(v_1.AuxInt) != 3 || !(buildcfg.GORISCV64 >= 22) {
+				continue
+			}
+			v.reset(OpRISCV64SH1ADD)
+			v.AddArg2(x, x)
+			return true
+		}
+		break
+	}
+	// match: (MUL x (MOVDconst [5]))
+	// cond: buildcfg.GORISCV64 >= 22
+	// result: (SH2ADD x x)
+	for {
+		for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
+			x := v_0
+			if v_1.Op != OpRISCV64MOVDconst || auxIntToInt64(v_1.AuxInt) != 5 || !(buildcfg.GORISCV64 >= 22) {
+				continue
+			}
+			v.reset(OpRISCV64SH2ADD)
+			v.AddArg2(x, x)
+			return true
+		}
+		break
+	}
+	// match: (MUL x (MOVDconst [9]))
+	// cond: buildcfg.GORISCV64 >= 22
+	// result: (SH3ADD x x)
+	for {
+		for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
+			x := v_0
+			if v_1.Op != OpRISCV64MOVDconst || auxIntToInt64(v_1.AuxInt) != 9 || !(buildcfg.GORISCV64 >= 22) {
+				continue
+			}
+			v.reset(OpRISCV64SH3ADD)
+			v.AddArg2(x, x)
+			return true
+		}
+		break
+	}
+	// match: (MUL <t> x (MOVDconst [11]))
+	// cond: buildcfg.GORISCV64 >= 22
+	// result: (SH1ADD (SH2ADD <t> x x) x)
+	for {
+		t := v.Type
+		for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
+			x := v_0
+			if v_1.Op != OpRISCV64MOVDconst || auxIntToInt64(v_1.AuxInt) != 11 || !(buildcfg.GORISCV64 >= 22) {
+				continue
+			}
+			v.reset(OpRISCV64SH1ADD)
+			v0 := b.NewValue0(v.Pos, OpRISCV64SH2ADD, t)
+			v0.AddArg2(x, x)
+			v.AddArg2(v0, x)
+			return true
+		}
+		break
+	}
+	// match: (MUL <t> x (MOVDconst [13]))
+	// cond: buildcfg.GORISCV64 >= 22
+	// result: (SH2ADD (SH1ADD <t> x x) x)
+	for {
+		t := v.Type
+		for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
+			x := v_0
+			if v_1.Op != OpRISCV64MOVDconst || auxIntToInt64(v_1.AuxInt) != 13 || !(buildcfg.GORISCV64 >= 22) {
+				continue
+			}
+			v.reset(OpRISCV64SH2ADD)
+			v0 := b.NewValue0(v.Pos, OpRISCV64SH1ADD, t)
+			v0.AddArg2(x, x)
+			v.AddArg2(v0, x)
+			return true
+		}
+		break
+	}
+	// match: (MUL <t> x (MOVDconst [19]))
+	// cond: buildcfg.GORISCV64 >= 22
+	// result: (SH1ADD (SH3ADD <t> x x) x)
+	for {
+		t := v.Type
+		for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
+			x := v_0
+			if v_1.Op != OpRISCV64MOVDconst || auxIntToInt64(v_1.AuxInt) != 19 || !(buildcfg.GORISCV64 >= 22) {
+				continue
+			}
+			v.reset(OpRISCV64SH1ADD)
+			v0 := b.NewValue0(v.Pos, OpRISCV64SH3ADD, t)
+			v0.AddArg2(x, x)
+			v.AddArg2(v0, x)
+			return true
+		}
+		break
+	}
+	// match: (MUL <t> x (MOVDconst [21]))
+	// cond: buildcfg.GORISCV64 >= 22
+	// result: (SH2ADD (SH2ADD <t> x x) x)
+	for {
+		t := v.Type
+		for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
+			x := v_0
+			if v_1.Op != OpRISCV64MOVDconst || auxIntToInt64(v_1.AuxInt) != 21 || !(buildcfg.GORISCV64 >= 22) {
+				continue
+			}
+			v.reset(OpRISCV64SH2ADD)
+			v0 := b.NewValue0(v.Pos, OpRISCV64SH2ADD, t)
+			v0.AddArg2(x, x)
+			v.AddArg2(v0, x)
+			return true
+		}
+		break
+	}
+	// match: (MUL <t> x (MOVDconst [25]))
+	// cond: buildcfg.GORISCV64 >= 22
+	// result: (SH3ADD (SH1ADD <t> x x) x)
+	for {
+		t := v.Type
+		for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
+			x := v_0
+			if v_1.Op != OpRISCV64MOVDconst || auxIntToInt64(v_1.AuxInt) != 25 || !(buildcfg.GORISCV64 >= 22) {
+				continue
+			}
+			v.reset(OpRISCV64SH3ADD)
+			v0 := b.NewValue0(v.Pos, OpRISCV64SH1ADD, t)
+			v0.AddArg2(x, x)
+			v.AddArg2(v0, x)
+			return true
+		}
+		break
+	}
+	// match: (MUL <t> x (MOVDconst [27]))
+	// cond: buildcfg.GORISCV64 >= 22
+	// result: (SH1ADD (SH3ADD <t> x x) (SH3ADD <t> x x))
+	for {
+		t := v.Type
+		for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
+			x := v_0
+			if v_1.Op != OpRISCV64MOVDconst || auxIntToInt64(v_1.AuxInt) != 27 || !(buildcfg.GORISCV64 >= 22) {
+				continue
+			}
+			v.reset(OpRISCV64SH1ADD)
+			v0 := b.NewValue0(v.Pos, OpRISCV64SH3ADD, t)
+			v0.AddArg2(x, x)
+			v.AddArg2(v0, v0)
+			return true
+		}
+		break
+	}
+	// match: (MUL <t> x (MOVDconst [37]))
+	// cond: buildcfg.GORISCV64 >= 22
+	// result: (SH2ADD (SH3ADD <t> x x) x)
+	for {
+		t := v.Type
+		for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
+			x := v_0
+			if v_1.Op != OpRISCV64MOVDconst || auxIntToInt64(v_1.AuxInt) != 37 || !(buildcfg.GORISCV64 >= 22) {
+				continue
+			}
+			v.reset(OpRISCV64SH2ADD)
+			v0 := b.NewValue0(v.Pos, OpRISCV64SH3ADD, t)
+			v0.AddArg2(x, x)
+			v.AddArg2(v0, x)
+			return true
+		}
+		break
+	}
+	// match: (MUL <t> x (MOVDconst [41]))
+	// cond: buildcfg.GORISCV64 >= 22
+	// result: (SH3ADD (SH2ADD <t> x x) x)
+	for {
+		t := v.Type
+		for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
+			x := v_0
+			if v_1.Op != OpRISCV64MOVDconst || auxIntToInt64(v_1.AuxInt) != 41 || !(buildcfg.GORISCV64 >= 22) {
+				continue
+			}
+			v.reset(OpRISCV64SH3ADD)
+			v0 := b.NewValue0(v.Pos, OpRISCV64SH2ADD, t)
+			v0.AddArg2(x, x)
+			v.AddArg2(v0, x)
+			return true
+		}
+		break
+	}
+	// match: (MUL <t> x (MOVDconst [45]))
+	// cond: buildcfg.GORISCV64 >= 22
+	// result: (SH2ADD (SH3ADD <t> x x) (SH3ADD <t> x x))
+	for {
+		t := v.Type
+		for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
+			x := v_0
+			if v_1.Op != OpRISCV64MOVDconst || auxIntToInt64(v_1.AuxInt) != 45 || !(buildcfg.GORISCV64 >= 22) {
+				continue
+			}
+			v.reset(OpRISCV64SH2ADD)
+			v0 := b.NewValue0(v.Pos, OpRISCV64SH3ADD, t)
+			v0.AddArg2(x, x)
+			v.AddArg2(v0, v0)
+			return true
+		}
+		break
+	}
+	// match: (MUL <t> x (MOVDconst [73]))
+	// cond: buildcfg.GORISCV64 >= 22
+	// result: (SH3ADD (SH3ADD <t> x x) x)
+	for {
+		t := v.Type
+		for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
+			x := v_0
+			if v_1.Op != OpRISCV64MOVDconst || auxIntToInt64(v_1.AuxInt) != 73 || !(buildcfg.GORISCV64 >= 22) {
+				continue
+			}
+			v.reset(OpRISCV64SH3ADD)
+			v0 := b.NewValue0(v.Pos, OpRISCV64SH3ADD, t)
+			v0.AddArg2(x, x)
+			v.AddArg2(v0, x)
+			return true
+		}
+		break
+	}
+	// match: (MUL <t> x (MOVDconst [81]))
+	// cond: buildcfg.GORISCV64 >= 22
+	// result: (SH3ADD (SH3ADD <t> x x) (SH3ADD <t> x x))
+	for {
+		t := v.Type
+		for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
+			x := v_0
+			if v_1.Op != OpRISCV64MOVDconst || auxIntToInt64(v_1.AuxInt) != 81 || !(buildcfg.GORISCV64 >= 22) {
+				continue
+			}
+			v.reset(OpRISCV64SH3ADD)
+			v0 := b.NewValue0(v.Pos, OpRISCV64SH3ADD, t)
+			v0.AddArg2(x, x)
+			v.AddArg2(v0, v0)
+			return true
+		}
+		break
+	}
+	return false
+}
 func rewriteValueRISCV64_OpRISCV64NEG(v *Value) bool {
 	v_0 := v.Args[0]
 	b := v.Block
diff --git a/test/codegen/shift.go b/test/codegen/shift.go
index 2d8cf868571b7d..ad3629b02bd7f2 100644
--- a/test/codegen/shift.go
+++ b/test/codegen/shift.go
@@ -531,3 +531,26 @@ func checkLeftShiftWithAddition(a int64, b int64) int64 {
 	a = a + b<<3
 	return a
 }
+
+//
+// Multiplication by some constants
+//
+
+func checkMulByConsts(a int64, b int64, c int64, d int64, e int64) (int64, int64, int64, int64, int64) {
+	// riscv64/rva20u64: "MUL"
+	// riscv64/rva22u64: "SH1ADD"
+	a = a * 3
+	// riscv64/rva20u64: "MUL"
+	// riscv64/rva22u64: "SH2ADD"
+	b = b * 5
+	// riscv64/rva20u64: "MUL"
+	// riscv64/rva22u64: "SH2ADD", "SH1ADD"
+	c = c * 13
+	// riscv64/rva20u64: "MUL"
+	// riscv64/rva22u64: "SH1ADD", "SH3ADD"
+	d = d * 27
+	// riscv64/rva20u64: "MUL"
+	// riscv64/rva22u64: "SH3ADD", "SH3ADD"
+	e = e * 73
+	return a, b, c, d, e
+}