diff --git a/bolt/include/bolt/Passes/PAuthGadgetScanner.h b/bolt/include/bolt/Passes/PAuthGadgetScanner.h index 75a8d26c64537..ccfe632889c7a 100644 --- a/bolt/include/bolt/Passes/PAuthGadgetScanner.h +++ b/bolt/include/bolt/Passes/PAuthGadgetScanner.h @@ -12,7 +12,6 @@ #include "bolt/Core/BinaryContext.h" #include "bolt/Core/BinaryFunction.h" #include "bolt/Passes/BinaryPasses.h" -#include "llvm/ADT/SmallSet.h" #include "llvm/Support/raw_ostream.h" #include @@ -197,9 +196,6 @@ raw_ostream &operator<<(raw_ostream &OS, const MCInstReference &); namespace PAuthGadgetScanner { -class SrcSafetyAnalysis; -struct SrcState; - /// Description of a gadget kind that can be detected. Intended to be /// statically allocated to be attached to reports by reference. class GadgetKind { @@ -208,7 +204,7 @@ class GadgetKind { public: GadgetKind(const char *Description) : Description(Description) {} - const StringRef getDescription() const { return Description; } + StringRef getDescription() const { return Description; } }; /// Base report located at some instruction, without any additional information. @@ -223,8 +219,8 @@ struct Report { // The two methods below are called by Analysis::computeDetailedInfo when // iterating over the reports. - virtual const ArrayRef getAffectedRegisters() const { return {}; } - virtual void setOverwritingInstrs(const ArrayRef Instrs) {} + virtual ArrayRef getAffectedRegisters() const { return {}; } + virtual void setOverwritingInstrs(ArrayRef Instrs) {} void printBasicInfo(raw_ostream &OS, const BinaryContext &BC, StringRef IssueKind) const; @@ -247,11 +243,11 @@ struct GadgetReport : public Report { void generateReport(raw_ostream &OS, const BinaryContext &BC) const override; - const ArrayRef getAffectedRegisters() const override { + ArrayRef getAffectedRegisters() const override { return AffectedRegisters; } - void setOverwritingInstrs(const ArrayRef Instrs) override { + void setOverwritingInstrs(ArrayRef Instrs) override { OverwritingInstrs.assign(Instrs.begin(), Instrs.end()); } }; @@ -259,7 +255,7 @@ struct GadgetReport : public Report { /// Report with a free-form message attached. struct GenericReport : public Report { std::string Text; - GenericReport(MCInstReference Location, const std::string &Text) + GenericReport(MCInstReference Location, StringRef Text) : Report(Location), Text(Text) {} virtual void generateReport(raw_ostream &OS, const BinaryContext &BC) const override; diff --git a/bolt/lib/Passes/PAuthGadgetScanner.cpp b/bolt/lib/Passes/PAuthGadgetScanner.cpp index 12eb9c66130b9..92608aebce3ee 100644 --- a/bolt/lib/Passes/PAuthGadgetScanner.cpp +++ b/bolt/lib/Passes/PAuthGadgetScanner.cpp @@ -91,21 +91,21 @@ class TrackedRegisters { const std::vector Registers; std::vector RegToIndexMapping; - static size_t getMappingSize(const std::vector &RegsToTrack) { + static size_t getMappingSize(ArrayRef RegsToTrack) { if (RegsToTrack.empty()) return 0; return 1 + *llvm::max_element(RegsToTrack); } public: - TrackedRegisters(const std::vector &RegsToTrack) + TrackedRegisters(ArrayRef RegsToTrack) : Registers(RegsToTrack), RegToIndexMapping(getMappingSize(RegsToTrack), NoIndex) { for (unsigned I = 0; I < RegsToTrack.size(); ++I) RegToIndexMapping[RegsToTrack[I]] = I; } - const ArrayRef getRegisters() const { return Registers; } + ArrayRef getRegisters() const { return Registers; } size_t getNumTrackedRegisters() const { return Registers.size(); } @@ -232,9 +232,9 @@ struct SrcState { bool operator!=(const SrcState &RHS) const { return !((*this) == RHS); } }; -static void printLastInsts( - raw_ostream &OS, - const std::vector> &LastInstWritingReg) { +static void +printLastInsts(raw_ostream &OS, + ArrayRef> LastInstWritingReg) { OS << "Insts: "; for (unsigned I = 0; I < LastInstWritingReg.size(); ++I) { auto &Set = LastInstWritingReg[I]; @@ -294,8 +294,7 @@ void SrcStatePrinter::print(raw_ostream &OS, const SrcState &S) const { /// version for functions without reconstructed CFG. class SrcSafetyAnalysis { public: - SrcSafetyAnalysis(BinaryFunction &BF, - const std::vector &RegsToTrackInstsFor) + SrcSafetyAnalysis(BinaryFunction &BF, ArrayRef RegsToTrackInstsFor) : BC(BF.getBinaryContext()), NumRegs(BC.MRI->getNumRegs()), RegsToTrackInstsFor(RegsToTrackInstsFor) {} @@ -303,11 +302,10 @@ class SrcSafetyAnalysis { static std::shared_ptr create(BinaryFunction &BF, MCPlusBuilder::AllocatorIdTy AllocId, - const std::vector &RegsToTrackInstsFor); + ArrayRef RegsToTrackInstsFor); virtual void run() = 0; - virtual ErrorOr - getStateBefore(const MCInst &Inst) const = 0; + virtual const SrcState &getStateBefore(const MCInst &Inst) const = 0; protected: BinaryContext &BC; @@ -347,7 +345,7 @@ class SrcSafetyAnalysis { } BitVector getClobberedRegs(const MCInst &Point) const { - BitVector Clobbered(NumRegs, false); + BitVector Clobbered(NumRegs); // Assume a call can clobber all registers, including callee-saved // registers. There's a good chance that callee-saved registers will be // saved on the stack at some point during execution of the callee. @@ -409,8 +407,7 @@ class SrcSafetyAnalysis { // FirstCheckerInst should belong to the same basic block (see the // assertion in DataflowSrcSafetyAnalysis::run()), meaning it was // deterministically processed a few steps before this instruction. - const SrcState &StateBeforeChecker = - getStateBefore(*FirstCheckerInst).get(); + const SrcState &StateBeforeChecker = getStateBefore(*FirstCheckerInst); if (StateBeforeChecker.SafeToDerefRegs[CheckedReg]) Regs.push_back(CheckedReg); } @@ -520,13 +517,10 @@ class SrcSafetyAnalysis { public: std::vector getLastClobberingInsts(const MCInst &Inst, BinaryFunction &BF, - const ArrayRef UsedDirtyRegs) const { + ArrayRef UsedDirtyRegs) const { if (RegsToTrackInstsFor.empty()) return {}; - auto MaybeState = getStateBefore(Inst); - if (!MaybeState) - llvm_unreachable("Expected state to be present"); - const SrcState &S = *MaybeState; + const SrcState &S = getStateBefore(Inst); // Due to aliasing registers, multiple registers may have been tracked. std::set LastWritingInsts; for (MCPhysReg TrackedReg : UsedDirtyRegs) { @@ -537,7 +531,7 @@ class SrcSafetyAnalysis { for (const MCInst *Inst : LastWritingInsts) { MCInstReference Ref = MCInstReference::get(Inst, BF); assert(Ref && "Expected Inst to be found"); - Result.push_back(MCInstReference(Ref)); + Result.push_back(Ref); } return Result; } @@ -557,11 +551,11 @@ class DataflowSrcSafetyAnalysis public: DataflowSrcSafetyAnalysis(BinaryFunction &BF, MCPlusBuilder::AllocatorIdTy AllocId, - const std::vector &RegsToTrackInstsFor) + ArrayRef RegsToTrackInstsFor) : SrcSafetyAnalysis(BF, RegsToTrackInstsFor), DFParent(BF, AllocId) {} - ErrorOr getStateBefore(const MCInst &Inst) const override { - return DFParent::getStateBefore(Inst); + const SrcState &getStateBefore(const MCInst &Inst) const override { + return DFParent::getStateBefore(Inst).get(); } void run() override { @@ -674,7 +668,7 @@ class CFGUnawareSrcSafetyAnalysis : public SrcSafetyAnalysis { public: CFGUnawareSrcSafetyAnalysis(BinaryFunction &BF, MCPlusBuilder::AllocatorIdTy AllocId, - const std::vector &RegsToTrackInstsFor) + ArrayRef RegsToTrackInstsFor) : SrcSafetyAnalysis(BF, RegsToTrackInstsFor), BF(BF), AllocId(AllocId) { StateAnnotationIndex = BC.MIB->getOrCreateAnnotationIndex("CFGUnawareSrcSafetyAnalysis"); @@ -708,7 +702,7 @@ class CFGUnawareSrcSafetyAnalysis : public SrcSafetyAnalysis { } } - ErrorOr getStateBefore(const MCInst &Inst) const override { + const SrcState &getStateBefore(const MCInst &Inst) const override { return BC.MIB->getAnnotationAs(Inst, StateAnnotationIndex); } @@ -718,7 +712,7 @@ class CFGUnawareSrcSafetyAnalysis : public SrcSafetyAnalysis { std::shared_ptr SrcSafetyAnalysis::create(BinaryFunction &BF, MCPlusBuilder::AllocatorIdTy AllocId, - const std::vector &RegsToTrackInstsFor) { + ArrayRef RegsToTrackInstsFor) { if (BF.hasCFG()) return std::make_shared(BF, AllocId, RegsToTrackInstsFor); @@ -825,7 +819,7 @@ Analysis::findGadgets(BinaryFunction &BF, BinaryContext &BC = BF.getBinaryContext(); iterateOverInstrs(BF, [&](MCInstReference Inst) { - const SrcState &S = *Analysis->getStateBefore(Inst); + const SrcState &S = Analysis->getStateBefore(Inst); // If non-empty state was never propagated from the entry basic block // to Inst, assume it to be unreachable and report a warning.