Skip to content

[BOLT] Gadget scanner: use more appropriate types (NFC) #135661

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 6 additions & 10 deletions bolt/include/bolt/Passes/PAuthGadgetScanner.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
#include "bolt/Core/BinaryContext.h"
#include "bolt/Core/BinaryFunction.h"
#include "bolt/Passes/BinaryPasses.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Support/raw_ostream.h"
#include <memory>

Expand Down Expand Up @@ -197,9 +196,6 @@ raw_ostream &operator<<(raw_ostream &OS, const MCInstReference &);

namespace PAuthGadgetScanner {

class SrcSafetyAnalysis;
struct SrcState;

/// Description of a gadget kind that can be detected. Intended to be
/// statically allocated to be attached to reports by reference.
class GadgetKind {
Expand All @@ -208,7 +204,7 @@ class GadgetKind {
public:
GadgetKind(const char *Description) : Description(Description) {}

const StringRef getDescription() const { return Description; }
StringRef getDescription() const { return Description; }
};

/// Base report located at some instruction, without any additional information.
Expand All @@ -223,8 +219,8 @@ struct Report {

// The two methods below are called by Analysis::computeDetailedInfo when
// iterating over the reports.
virtual const ArrayRef<MCPhysReg> getAffectedRegisters() const { return {}; }
virtual void setOverwritingInstrs(const ArrayRef<MCInstReference> Instrs) {}
virtual ArrayRef<MCPhysReg> getAffectedRegisters() const { return {}; }
virtual void setOverwritingInstrs(ArrayRef<MCInstReference> Instrs) {}

void printBasicInfo(raw_ostream &OS, const BinaryContext &BC,
StringRef IssueKind) const;
Expand All @@ -247,19 +243,19 @@ struct GadgetReport : public Report {

void generateReport(raw_ostream &OS, const BinaryContext &BC) const override;

const ArrayRef<MCPhysReg> getAffectedRegisters() const override {
ArrayRef<MCPhysReg> getAffectedRegisters() const override {
return AffectedRegisters;
}

void setOverwritingInstrs(const ArrayRef<MCInstReference> Instrs) override {
void setOverwritingInstrs(ArrayRef<MCInstReference> Instrs) override {
OverwritingInstrs.assign(Instrs.begin(), Instrs.end());
}
};

/// Report with a free-form message attached.
struct GenericReport : public Report {
std::string Text;
GenericReport(MCInstReference Location, const std::string &Text)
GenericReport(MCInstReference Location, StringRef Text)
: Report(Location), Text(Text) {}
virtual void generateReport(raw_ostream &OS,
const BinaryContext &BC) const override;
Expand Down
48 changes: 21 additions & 27 deletions bolt/lib/Passes/PAuthGadgetScanner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,21 +91,21 @@ class TrackedRegisters {
const std::vector<MCPhysReg> Registers;
std::vector<uint16_t> RegToIndexMapping;

static size_t getMappingSize(const std::vector<MCPhysReg> &RegsToTrack) {
static size_t getMappingSize(ArrayRef<MCPhysReg> RegsToTrack) {
if (RegsToTrack.empty())
return 0;
return 1 + *llvm::max_element(RegsToTrack);
}

public:
TrackedRegisters(const std::vector<MCPhysReg> &RegsToTrack)
TrackedRegisters(ArrayRef<MCPhysReg> RegsToTrack)
: Registers(RegsToTrack),
RegToIndexMapping(getMappingSize(RegsToTrack), NoIndex) {
for (unsigned I = 0; I < RegsToTrack.size(); ++I)
RegToIndexMapping[RegsToTrack[I]] = I;
}

const ArrayRef<MCPhysReg> getRegisters() const { return Registers; }
ArrayRef<MCPhysReg> getRegisters() const { return Registers; }

size_t getNumTrackedRegisters() const { return Registers.size(); }

Expand Down Expand Up @@ -232,9 +232,9 @@ struct SrcState {
bool operator!=(const SrcState &RHS) const { return !((*this) == RHS); }
};

static void printLastInsts(
raw_ostream &OS,
const std::vector<SmallPtrSet<const MCInst *, 4>> &LastInstWritingReg) {
static void
printLastInsts(raw_ostream &OS,
ArrayRef<SmallPtrSet<const MCInst *, 4>> LastInstWritingReg) {
OS << "Insts: ";
for (unsigned I = 0; I < LastInstWritingReg.size(); ++I) {
auto &Set = LastInstWritingReg[I];
Expand Down Expand Up @@ -294,20 +294,18 @@ void SrcStatePrinter::print(raw_ostream &OS, const SrcState &S) const {
/// version for functions without reconstructed CFG.
class SrcSafetyAnalysis {
public:
SrcSafetyAnalysis(BinaryFunction &BF,
const std::vector<MCPhysReg> &RegsToTrackInstsFor)
SrcSafetyAnalysis(BinaryFunction &BF, ArrayRef<MCPhysReg> RegsToTrackInstsFor)
: BC(BF.getBinaryContext()), NumRegs(BC.MRI->getNumRegs()),
RegsToTrackInstsFor(RegsToTrackInstsFor) {}

virtual ~SrcSafetyAnalysis() {}

static std::shared_ptr<SrcSafetyAnalysis>
create(BinaryFunction &BF, MCPlusBuilder::AllocatorIdTy AllocId,
const std::vector<MCPhysReg> &RegsToTrackInstsFor);
ArrayRef<MCPhysReg> RegsToTrackInstsFor);

virtual void run() = 0;
virtual ErrorOr<const SrcState &>
getStateBefore(const MCInst &Inst) const = 0;
virtual const SrcState &getStateBefore(const MCInst &Inst) const = 0;

protected:
BinaryContext &BC;
Expand Down Expand Up @@ -347,7 +345,7 @@ class SrcSafetyAnalysis {
}

BitVector getClobberedRegs(const MCInst &Point) const {
BitVector Clobbered(NumRegs, false);
BitVector Clobbered(NumRegs);
// Assume a call can clobber all registers, including callee-saved
// registers. There's a good chance that callee-saved registers will be
// saved on the stack at some point during execution of the callee.
Expand Down Expand Up @@ -409,8 +407,7 @@ class SrcSafetyAnalysis {
// FirstCheckerInst should belong to the same basic block (see the
// assertion in DataflowSrcSafetyAnalysis::run()), meaning it was
// deterministically processed a few steps before this instruction.
const SrcState &StateBeforeChecker =
getStateBefore(*FirstCheckerInst).get();
const SrcState &StateBeforeChecker = getStateBefore(*FirstCheckerInst);
if (StateBeforeChecker.SafeToDerefRegs[CheckedReg])
Regs.push_back(CheckedReg);
}
Expand Down Expand Up @@ -520,13 +517,10 @@ class SrcSafetyAnalysis {
public:
std::vector<MCInstReference>
getLastClobberingInsts(const MCInst &Inst, BinaryFunction &BF,
const ArrayRef<MCPhysReg> UsedDirtyRegs) const {
ArrayRef<MCPhysReg> UsedDirtyRegs) const {
if (RegsToTrackInstsFor.empty())
return {};
auto MaybeState = getStateBefore(Inst);
if (!MaybeState)
llvm_unreachable("Expected state to be present");
const SrcState &S = *MaybeState;
const SrcState &S = getStateBefore(Inst);
// Due to aliasing registers, multiple registers may have been tracked.
std::set<const MCInst *> LastWritingInsts;
for (MCPhysReg TrackedReg : UsedDirtyRegs) {
Expand All @@ -537,7 +531,7 @@ class SrcSafetyAnalysis {
for (const MCInst *Inst : LastWritingInsts) {
MCInstReference Ref = MCInstReference::get(Inst, BF);
assert(Ref && "Expected Inst to be found");
Result.push_back(MCInstReference(Ref));
Result.push_back(Ref);
}
return Result;
}
Expand All @@ -557,11 +551,11 @@ class DataflowSrcSafetyAnalysis
public:
DataflowSrcSafetyAnalysis(BinaryFunction &BF,
MCPlusBuilder::AllocatorIdTy AllocId,
const std::vector<MCPhysReg> &RegsToTrackInstsFor)
ArrayRef<MCPhysReg> RegsToTrackInstsFor)
: SrcSafetyAnalysis(BF, RegsToTrackInstsFor), DFParent(BF, AllocId) {}

ErrorOr<const SrcState &> getStateBefore(const MCInst &Inst) const override {
return DFParent::getStateBefore(Inst);
const SrcState &getStateBefore(const MCInst &Inst) const override {
return DFParent::getStateBefore(Inst).get();
}

void run() override {
Expand Down Expand Up @@ -674,7 +668,7 @@ class CFGUnawareSrcSafetyAnalysis : public SrcSafetyAnalysis {
public:
CFGUnawareSrcSafetyAnalysis(BinaryFunction &BF,
MCPlusBuilder::AllocatorIdTy AllocId,
const std::vector<MCPhysReg> &RegsToTrackInstsFor)
ArrayRef<MCPhysReg> RegsToTrackInstsFor)
: SrcSafetyAnalysis(BF, RegsToTrackInstsFor), BF(BF), AllocId(AllocId) {
StateAnnotationIndex =
BC.MIB->getOrCreateAnnotationIndex("CFGUnawareSrcSafetyAnalysis");
Expand Down Expand Up @@ -708,7 +702,7 @@ class CFGUnawareSrcSafetyAnalysis : public SrcSafetyAnalysis {
}
}

ErrorOr<const SrcState &> getStateBefore(const MCInst &Inst) const override {
const SrcState &getStateBefore(const MCInst &Inst) const override {
return BC.MIB->getAnnotationAs<SrcState>(Inst, StateAnnotationIndex);
}

Expand All @@ -718,7 +712,7 @@ class CFGUnawareSrcSafetyAnalysis : public SrcSafetyAnalysis {
std::shared_ptr<SrcSafetyAnalysis>
SrcSafetyAnalysis::create(BinaryFunction &BF,
MCPlusBuilder::AllocatorIdTy AllocId,
const std::vector<MCPhysReg> &RegsToTrackInstsFor) {
ArrayRef<MCPhysReg> RegsToTrackInstsFor) {
if (BF.hasCFG())
return std::make_shared<DataflowSrcSafetyAnalysis>(BF, AllocId,
RegsToTrackInstsFor);
Expand Down Expand Up @@ -825,7 +819,7 @@ Analysis::findGadgets(BinaryFunction &BF,

BinaryContext &BC = BF.getBinaryContext();
iterateOverInstrs(BF, [&](MCInstReference Inst) {
const SrcState &S = *Analysis->getStateBefore(Inst);
const SrcState &S = Analysis->getStateBefore(Inst);

// If non-empty state was never propagated from the entry basic block
// to Inst, assume it to be unreachable and report a warning.
Expand Down
Loading