[RISCV][llvm] Support BUILD_VECTOR codegen for P extension #169083
Conversation
@llvm/pr-subscribers-backend-risc-v
Author: Brandon Wu (4vtomat)
Changes
Full diff: https://github.com/llvm/llvm-project/pull/169083.diff
5 Files Affected:
diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
index 5025122db3681..fc2d034f4d589 100644
--- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
@@ -1867,6 +1867,43 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) {
CurDAG->RemoveDeadNode(Node);
return;
}
+ case RISCVISD::PPACK_DH: {
+ assert(Subtarget->enablePExtCodeGen() && Subtarget->isRV32());
+
+ SDValue Val0 = Node->getOperand(0);
+ SDValue Val1 = Node->getOperand(1);
+ SDValue Val2 = Node->getOperand(2);
+ SDValue Val3 = Node->getOperand(3);
+
+ SDValue Ops[] = {
+ CurDAG->getTargetConstant(RISCV::GPRPairRegClassID, DL, MVT::i32), Val0,
+ CurDAG->getTargetConstant(RISCV::sub_gpr_even, DL, MVT::i32), Val2,
+ CurDAG->getTargetConstant(RISCV::sub_gpr_odd, DL, MVT::i32)};
+ SDValue RegPair0 =
+ SDValue(CurDAG->getMachineNode(TargetOpcode::REG_SEQUENCE, DL,
+ MVT::Untyped, Ops),
+ 0);
+ SDValue Ops1[] = {
+ CurDAG->getTargetConstant(RISCV::GPRPairRegClassID, DL, MVT::i32), Val1,
+ CurDAG->getTargetConstant(RISCV::sub_gpr_even, DL, MVT::i32), Val3,
+ CurDAG->getTargetConstant(RISCV::sub_gpr_odd, DL, MVT::i32)};
+ SDValue RegPair1 =
+ SDValue(CurDAG->getMachineNode(TargetOpcode::REG_SEQUENCE, DL,
+ MVT::Untyped, Ops1),
+ 0);
+
+ MachineSDNode *PackDH = CurDAG->getMachineNode(
+ RISCV::PPACK_DH, DL, MVT::Untyped, {RegPair0, RegPair1});
+
+ SDValue Lo = CurDAG->getTargetExtractSubreg(RISCV::sub_gpr_even, DL,
+ MVT::i32, SDValue(PackDH, 0));
+ SDValue Hi = CurDAG->getTargetExtractSubreg(RISCV::sub_gpr_odd, DL,
+ MVT::i32, SDValue(PackDH, 0));
+ ReplaceUses(SDValue(Node, 0), Lo);
+ ReplaceUses(SDValue(Node, 1), Hi);
+ CurDAG->RemoveDeadNode(Node);
+ return;
+ }
case ISD::INTRINSIC_WO_CHAIN: {
unsigned IntNo = Node->getConstantOperandVal(0);
switch (IntNo) {
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 6020fb6ca16ce..2b814fab6a92e 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -519,6 +519,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
setTruncStoreAction(MVT::v4i16, MVT::v4i8, Expand);
} else {
VTs.append({MVT::v2i16, MVT::v4i8});
+ setOperationAction(ISD::BUILD_VECTOR, MVT::v4i8, Custom);
}
setOperationAction(ISD::UADDSAT, VTs, Legal);
setOperationAction(ISD::SADDSAT, VTs, Legal);
@@ -4434,6 +4435,29 @@ static SDValue lowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG,
SDLoc DL(Op);
+ if (Subtarget.isRV32() && Subtarget.enablePExtCodeGen()) {
+ if (VT != MVT::v4i8)
+ return SDValue();
+
+ // <4 x i8> BUILD_VECTOR a, b, c, d -> PACK(PPACK.DH pair(a, b), pair(c, d))
+ SDValue Val0 = DAG.getNode(ISD::BITCAST, DL, MVT::v4i8, Op->getOperand(0));
+ SDValue Val1 = DAG.getNode(ISD::BITCAST, DL, MVT::v4i8, Op->getOperand(1));
+ SDValue Val2 = DAG.getNode(ISD::BITCAST, DL, MVT::v4i8, Op->getOperand(2));
+ SDValue Val3 = DAG.getNode(ISD::BITCAST, DL, MVT::v4i8, Op->getOperand(3));
+ SDValue PackDH =
+ DAG.getNode(RISCVISD::PPACK_DH, DL, {MVT::v2i16, MVT::v2i16},
+ {Val0, Val1, Val2, Val3});
+
+ return DAG.getNode(
+ ISD::BITCAST, DL, MVT::v4i8,
+ SDValue(
+ DAG.getMachineNode(
+ RISCV::PACK, DL, MVT::i32,
+ {DAG.getNode(ISD::BITCAST, DL, MVT::i32, PackDH.getValue(0)),
+ DAG.getNode(ISD::BITCAST, DL, MVT::i32, PackDH.getValue(1))}),
+ 0));
+ }
+
// Proper support for f16 requires Zvfh. bf16 always requires special
// handling. We need to cast the scalar to integer and create an integer
// build_vector.
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoP.td b/llvm/lib/Target/RISCV/RISCVInstrInfoP.td
index 51339d66f6de1..db5bc19aa4487 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoP.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoP.td
@@ -24,6 +24,13 @@ def SImm8UnsignedAsmOperand : SImmAsmOperand<8, "Unsigned"> {
let RenderMethod = "addSImm8UnsignedOperands";
}
+// (<2 x i16>, <2 x i16>) PPACK_DH (<4 x i8>, <4 x i8>, <4 x i8>, <4 x i8>)
+def SDT_RISCVPPackDH
+ : SDTypeProfile<2, 4, [SDTCisVT<0, v2i16>, SDTCisSameAs<0, 1>,
+ SDTCisVT<2, v4i8>, SDTCisSameAs<0, 3>,
+ SDTCisSameAs<0, 4>, SDTCisSameAs<0, 5>]>;
+def riscv_ppack_dh : RVSDNode<"PPACK_DH", SDT_RISCVPPackDH>;
+
// A 8-bit signed immediate allowing range [-128, 255]
// but represented as [-128, 127].
def simm8_unsigned : RISCVOp, ImmLeaf<XLenVT, "return isInt<8>(Imm);"> {
@@ -1530,6 +1537,13 @@ let Predicates = [HasStdExtP, IsRV32] in {
def : StPat<store, SW, GPR, v2i16>;
def : LdPat<load, LW, v4i8>;
def : LdPat<load, LW, v2i16>;
+
+ // Build vector patterns
+ def : Pat<(v4i8 (build_vector (XLenVT GPR:$a), (XLenVT GPR:$b),
+ (XLenVT GPR:$c), (XLenVT GPR:$d))),
+ (PACK (PPACK_H GPR:$a, GPR:$b), (PPACK_H GPR:$c, GPR:$d))>;
+ def : Pat<(v2i16 (build_vector (XLenVT GPR:$a), (XLenVT GPR:$b))),
+ (PACK GPR:$a, GPR:$b)>;
} // Predicates = [HasStdExtP, IsRV32]
let Predicates = [HasStdExtP, IsRV64] in {
@@ -1566,4 +1580,29 @@ let Predicates = [HasStdExtP, IsRV64] in {
def : LdPat<load, LD, v8i8>;
def : LdPat<load, LD, v4i16>;
def : LdPat<load, LD, v2i32>;
+
+ // Build vector patterns
+ def : Pat<(v8i8 (build_vector (XLenVT GPR:$a), (XLenVT GPR:$b),
+ (XLenVT GPR:$c), (XLenVT GPR:$d),
+ (XLenVT undef), (XLenVT undef),
+ (XLenVT undef), (XLenVT undef))),
+ (PPACK_W (PPACK_H GPR:$a, GPR:$b), (PPACK_H GPR:$c, GPR:$d))>;
+
+ def : Pat<(v8i8 (build_vector (XLenVT GPR:$a), (XLenVT GPR:$b),
+ (XLenVT GPR:$c), (XLenVT GPR:$d),
+ (XLenVT GPR:$e), (XLenVT GPR:$f),
+ (XLenVT GPR:$g), (XLenVT GPR:$h))),
+ (PACK(PPACK_W (PPACK_H GPR:$a, GPR:$b), (PPACK_H GPR:$c, GPR:$d)),
+ (PPACK_W (PPACK_H GPR:$e, GPR:$f), (PPACK_H GPR:$g, GPR:$h)))>;
+
+ def : Pat<(v4i16 (build_vector (XLenVT GPR:$a), (XLenVT GPR:$b),
+ (XLenVT undef), (XLenVT undef))),
+ (PPACK_W GPR:$a, GPR:$b)>;
+
+ def : Pat<(v4i16 (build_vector (XLenVT GPR:$a), (XLenVT GPR:$b),
+ (XLenVT GPR:$c), (XLenVT GPR:$d))),
+ (PACK (PPACK_W GPR:$a, GPR:$b), (PPACK_W GPR:$c, GPR:$d))>;
+
+ def : Pat<(v2i32 (build_vector (XLenVT GPR:$a), (XLenVT GPR:$b))),
+ (PACK GPR:$a, GPR:$b)>;
} // Predicates = [HasStdExtP, IsRV64]
diff --git a/llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll b/llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll
index bb3e691311cd8..79cf5b7903454 100644
--- a/llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll
+++ b/llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll
@@ -523,6 +523,47 @@ define void @test_non_const_splat_i16(ptr %ret_ptr, ptr %a_ptr, i16 %elt) {
ret void
}
+define void @test_build_vector_i8(i8 %a, i8 %c, i8 %b, i8 %d, ptr %ret_ptr) {
+; CHECK-RV32-LABEL: test_build_vector_i8:
+; CHECK-RV32: # %bb.0:
+; CHECK-RV32-NEXT: ppack.dh a0, a0, a2
+; CHECK-RV32-NEXT: pack a0, a0, a1
+; CHECK-RV32-NEXT: sw a0, 0(a4)
+; CHECK-RV32-NEXT: ret
+;
+; CHECK-RV64-LABEL: test_build_vector_i8:
+; CHECK-RV64: # %bb.0:
+; CHECK-RV64-NEXT: ppack.h a1, a1, a3
+; CHECK-RV64-NEXT: ppack.h a0, a0, a2
+; CHECK-RV64-NEXT: ppack.w a0, a0, a1
+; CHECK-RV64-NEXT: sw a0, 0(a4)
+; CHECK-RV64-NEXT: ret
+ %v0 = insertelement <4 x i8> poison, i8 %a, i32 0
+ %v1 = insertelement <4 x i8> %v0, i8 %b, i32 1
+ %v2 = insertelement <4 x i8> %v1, i8 %c, i32 2
+ %v3 = insertelement <4 x i8> %v2, i8 %d, i32 3
+ store <4 x i8> %v3, ptr %ret_ptr
+ ret void
+}
+
+define void @test_build_vector_i16(ptr %ret_ptr, i16 %a, i16 %b) {
+; CHECK-RV32-LABEL: test_build_vector_i16:
+; CHECK-RV32: # %bb.0:
+; CHECK-RV32-NEXT: pack a1, a1, a2
+; CHECK-RV32-NEXT: sw a1, 0(a0)
+; CHECK-RV32-NEXT: ret
+;
+; CHECK-RV64-LABEL: test_build_vector_i16:
+; CHECK-RV64: # %bb.0:
+; CHECK-RV64-NEXT: ppack.w a1, a1, a2
+; CHECK-RV64-NEXT: sw a1, 0(a0)
+; CHECK-RV64-NEXT: ret
+ %v0 = insertelement <2 x i16> poison, i16 %a, i32 0
+ %v1 = insertelement <2 x i16> %v0, i16 %b, i32 1
+ store <2 x i16> %v1, ptr %ret_ptr
+ ret void
+}
+
; Intrinsic declarations
declare <2 x i16> @llvm.sadd.sat.v2i16(<2 x i16>, <2 x i16>)
declare <2 x i16> @llvm.uadd.sat.v2i16(<2 x i16>, <2 x i16>)
diff --git a/llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll b/llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll
index f989b025a12dc..36996f0ac7ac8 100644
--- a/llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll
+++ b/llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll
@@ -685,6 +685,59 @@ define void @test_non_const_splat_i32(ptr %ret_ptr, ptr %a_ptr, i32 %elt) {
ret void
}
+define void @test_build_vector_i8(ptr %ret_ptr, i8 %a, i8 %b, i8 %c, i8 %d, i8 %e, i8 %f, i8 %g, i8 %h) {
+; CHECK-LABEL: test_build_vector_i8:
+; CHECK: # %bb.0:
+; CHECK-NEXT: lbu t0, 0(sp)
+; CHECK-NEXT: ppack.h a5, a5, a6
+; CHECK-NEXT: ppack.h a3, a3, a4
+; CHECK-NEXT: ppack.h a1, a1, a2
+; CHECK-NEXT: ppack.h a2, a7, t0
+; CHECK-NEXT: ppack.w a2, a5, a2
+; CHECK-NEXT: ppack.w a1, a1, a3
+; CHECK-NEXT: pack a1, a1, a2
+; CHECK-NEXT: sd a1, 0(a0)
+; CHECK-NEXT: ret
+ %v0 = insertelement <8 x i8> poison, i8 %a, i32 0
+ %v1 = insertelement <8 x i8> %v0, i8 %b, i32 1
+ %v2 = insertelement <8 x i8> %v1, i8 %c, i32 2
+ %v3 = insertelement <8 x i8> %v2, i8 %d, i32 3
+ %v4 = insertelement <8 x i8> %v3, i8 %e, i32 4
+ %v5 = insertelement <8 x i8> %v4, i8 %f, i32 5
+ %v6 = insertelement <8 x i8> %v5, i8 %g, i32 6
+ %v7 = insertelement <8 x i8> %v6, i8 %h, i32 7
+ store <8 x i8> %v7, ptr %ret_ptr
+ ret void
+}
+
+define void @test_build_vector_i16(ptr %ret_ptr, i16 %a, i16 %b, i16 %c, i16 %d) {
+; CHECK-LABEL: test_build_vector_i16:
+; CHECK: # %bb.0:
+; CHECK-NEXT: ppack.w a3, a3, a4
+; CHECK-NEXT: ppack.w a1, a1, a2
+; CHECK-NEXT: pack a1, a1, a3
+; CHECK-NEXT: sd a1, 0(a0)
+; CHECK-NEXT: ret
+ %v0 = insertelement <4 x i16> poison, i16 %a, i32 0
+ %v1 = insertelement <4 x i16> %v0, i16 %b, i32 1
+ %v2 = insertelement <4 x i16> %v1, i16 %c, i32 2
+ %v3 = insertelement <4 x i16> %v2, i16 %d, i32 3
+ store <4 x i16> %v3, ptr %ret_ptr
+ ret void
+}
+
+define void @test_build_vector_i32(ptr %ret_ptr, i32 %a, i32 %b) {
+; CHECK-LABEL: test_build_vector_i32:
+; CHECK: # %bb.0:
+; CHECK-NEXT: pack a1, a1, a2
+; CHECK-NEXT: sd a1, 0(a0)
+; CHECK-NEXT: ret
+ %v0 = insertelement <2 x i32> poison, i32 %a, i32 0
+ %v1 = insertelement <2 x i32> %v0, i32 %b, i32 1
+ store <2 x i32> %v1, ptr %ret_ptr
+ ret void
+}
+
; Intrinsic declarations
declare <4 x i16> @llvm.sadd.sat.v4i16(<4 x i16>, <4 x i16>)
declare <4 x i16> @llvm.uadd.sat.v4i16(<4 x i16>, <4 x i16>)
Review comment context (llvm/lib/Target/RISCV/RISCVISelLowering.cpp):

        return SDValue();

      // <4 x i8> BUILD_VECTOR a, b, c, d -> PACK(PPACK.DH pair(a, b), pair(c, d))
      SDValue Val0 = DAG.getNode(ISD::BITCAST, DL, MVT::v4i8, Op->getOperand(0));
These should be SCALAR_TO_VECTOR. Other optimizations understand that as a transition from scalar to vector better than a bitcast.
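A minimal sketch of what that suggestion could look like, assuming the surrounding lowering from the diff above stays the same (illustrative only, not the code committed in this PR):

    // Illustrative only: wrap each scalar operand in SCALAR_TO_VECTOR instead
    // of BITCAST, so later DAG combines see an explicit scalar-to-vector
    // transition. Types and operand order match the diff above.
    SDValue Val0 =
        DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, MVT::v4i8, Op->getOperand(0));
    SDValue Val1 =
        DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, MVT::v4i8, Op->getOperand(1));
    SDValue Val2 =
        DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, MVT::v4i8, Op->getOperand(2));
    SDValue Val3 =
        DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, MVT::v4i8, Op->getOperand(3));
    // The RISCVISD::PPACK_DH node built from Val0..Val3 would stay unchanged.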
I see, thanks for the information!
Review comment context (llvm/lib/Target/RISCV/RISCVInstrInfoP.td):

      def : LdPat<load, LW, v2i16>;

      // Build vector patterns
      def : Pat<(v4i8 (build_vector (XLenVT GPR:$a), (XLenVT GPR:$b),
Isn't this custom legalized?
Oh I forgot to remove it lol
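For context, the constructor hunk earlier in this diff already marks v4i8 BUILD_VECTOR as Custom on the RV32 P-extension path, which appears to be what makes the TableGen pattern quoted above redundant:

    // From RISCVISelLowering.cpp in this PR (RV32 + P-extension codegen path):
    // BUILD_VECTOR of v4i8 goes through the custom lowering, so an isel
    // pattern for it in RISCVInstrInfoP.td would not normally be reached.
    setOperationAction(ISD::BUILD_VECTOR, MVT::v4i8, Custom);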
🐧 Linux x64 Test Results
No description provided.