diff --git a/bolt/include/bolt/Core/MCPlusBuilder.h b/bolt/include/bolt/Core/MCPlusBuilder.h index bbef65700b2a5..5b7779f255556 100644 --- a/bolt/include/bolt/Core/MCPlusBuilder.h +++ b/bolt/include/bolt/Core/MCPlusBuilder.h @@ -577,12 +577,12 @@ class MCPlusBuilder { return getNoRegister(); } - /// Returns the register used as call destination, or no-register, if not - /// an indirect call. Sets IsAuthenticatedInternally if the instruction - /// accepts a signed pointer as its operand and authenticates it internally. + /// Returns the register used as the destination of an indirect branch or call + /// instruction. Sets IsAuthenticatedInternally if the instruction accepts + /// a signed pointer as its operand and authenticates it internally. virtual MCPhysReg - getRegUsedAsCallDest(const MCInst &Inst, - bool &IsAuthenticatedInternally) const { + getRegUsedAsIndirectBranchDest(const MCInst &Inst, + bool &IsAuthenticatedInternally) const { llvm_unreachable("not implemented"); return getNoRegister(); } diff --git a/bolt/lib/Passes/PAuthGadgetScanner.cpp b/bolt/lib/Passes/PAuthGadgetScanner.cpp index a3b320c545734..8710eba77097d 100644 --- a/bolt/lib/Passes/PAuthGadgetScanner.cpp +++ b/bolt/lib/Passes/PAuthGadgetScanner.cpp @@ -457,14 +457,16 @@ static std::shared_ptr shouldReportCallGadget(const BinaryContext &BC, const MCInstReference &Inst, const State &S) { static const GadgetKind CallKind("non-protected call found"); - if (!BC.MIB->isCall(Inst) && !BC.MIB->isBranch(Inst)) + if (!BC.MIB->isIndirectCall(Inst) && !BC.MIB->isIndirectBranch(Inst)) return nullptr; bool IsAuthenticated = false; - MCPhysReg DestReg = BC.MIB->getRegUsedAsCallDest(Inst, IsAuthenticated); - if (IsAuthenticated || DestReg == BC.MIB->getNoRegister()) + MCPhysReg DestReg = + BC.MIB->getRegUsedAsIndirectBranchDest(Inst, IsAuthenticated); + if (IsAuthenticated) return nullptr; + assert(DestReg != BC.MIB->getNoRegister()); LLVM_DEBUG({ traceInst(BC, "Found call inst", Inst); traceReg(BC, "Call destination reg", DestReg); diff --git a/bolt/lib/Target/AArch64/AArch64MCPlusBuilder.cpp b/bolt/lib/Target/AArch64/AArch64MCPlusBuilder.cpp index 2a648baa4d514..5ecc30b8bb107 100644 --- a/bolt/lib/Target/AArch64/AArch64MCPlusBuilder.cpp +++ b/bolt/lib/Target/AArch64/AArch64MCPlusBuilder.cpp @@ -10,6 +10,7 @@ // //===----------------------------------------------------------------------===// +#include "AArch64InstrInfo.h" #include "AArch64MCSymbolizer.h" #include "MCTargetDesc/AArch64AddressingModes.h" #include "MCTargetDesc/AArch64FixupKinds.h" @@ -277,15 +278,14 @@ class AArch64MCPlusBuilder : public MCPlusBuilder { } } - MCPhysReg - getRegUsedAsCallDest(const MCInst &Inst, - bool &IsAuthenticatedInternally) const override { - assert(isCall(Inst) || isBranch(Inst)); - IsAuthenticatedInternally = false; + MCPhysReg getRegUsedAsIndirectBranchDest( + const MCInst &Inst, bool &IsAuthenticatedInternally) const override { + assert(isIndirectCall(Inst) || isIndirectBranch(Inst)); switch (Inst.getOpcode()) { case AArch64::BR: case AArch64::BLR: + IsAuthenticatedInternally = false; return Inst.getOperand(0).getReg(); case AArch64::BRAA: case AArch64::BRAB: @@ -298,9 +298,7 @@ class AArch64MCPlusBuilder : public MCPlusBuilder { IsAuthenticatedInternally = true; return Inst.getOperand(0).getReg(); default: - if (isIndirectCall(Inst) || isIndirectBranch(Inst)) - llvm_unreachable("Unhandled indirect branch"); - return getNoRegister(); + llvm_unreachable("Unhandled indirect branch or call"); } } @@ -662,7 +660,7 @@ class AArch64MCPlusBuilder : public MCPlusBuilder { } bool isIndirectCall(const MCInst &Inst) const override { - return Inst.getOpcode() == AArch64::BLR; + return isIndirectCallOpcode(Inst.getOpcode()); } MCPhysReg getSpRegister(int Size) const { diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.h b/llvm/lib/Target/AArch64/AArch64InstrInfo.h index b3d3ec1455c8b..0ffaca9af4006 100644 --- a/llvm/lib/Target/AArch64/AArch64InstrInfo.h +++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.h @@ -726,6 +726,19 @@ static inline bool isIndirectBranchOpcode(int Opc) { return false; } +static inline bool isIndirectCallOpcode(unsigned Opc) { + switch (Opc) { + case AArch64::BLR: + case AArch64::BLRAA: + case AArch64::BLRAB: + case AArch64::BLRAAZ: + case AArch64::BLRABZ: + return true; + default: + return false; + } +} + static inline bool isPTrueOpcode(unsigned Opc) { switch (Opc) { case AArch64::PTRUE_B: