Skip to content

[BOLT] Gadget scanner: refine class names and debug output (NFC) #135073

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions bolt/include/bolt/Passes/PAuthGadgetScanner.h
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,8 @@ raw_ostream &operator<<(raw_ostream &OS, const MCInstReference &);

namespace PAuthGadgetScanner {

class PacRetAnalysis;
struct State;
class SrcSafetyAnalysis;
struct SrcState;

/// Description of a gadget kind that can be detected. Intended to be
/// statically allocated to be attached to reports by reference.
Expand Down
103 changes: 53 additions & 50 deletions bolt/lib/Passes/PAuthGadgetScanner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ class TrackedRegisters {
/// * RET (which is implicitly RET X30) is a protected return if and only if
/// X30 is safe-to-dereference - the state computed for sub- and
/// super-registers is not inspected.
struct State {
struct SrcState {
/// A BitVector containing the registers that are either safe at function
/// entry and were not clobbered yet, or those not clobbered since being
/// authenticated.
Expand All @@ -186,12 +186,12 @@ struct State {
std::vector<SmallPtrSet<const MCInst *, 4>> LastInstWritingReg;

/// Construct an empty state.
State() {}
SrcState() {}

State(unsigned NumRegs, unsigned NumRegsToTrack)
SrcState(unsigned NumRegs, unsigned NumRegsToTrack)
: SafeToDerefRegs(NumRegs), LastInstWritingReg(NumRegsToTrack) {}

State &merge(const State &StateIn) {
SrcState &merge(const SrcState &StateIn) {
if (StateIn.empty())
return *this;
if (empty())
Expand All @@ -208,11 +208,11 @@ struct State {
/// neither safe, nor unsafe ones.
bool empty() const { return SafeToDerefRegs.empty(); }

bool operator==(const State &RHS) const {
bool operator==(const SrcState &RHS) const {
return SafeToDerefRegs == RHS.SafeToDerefRegs &&
LastInstWritingReg == RHS.LastInstWritingReg;
}
bool operator!=(const State &RHS) const { return !((*this) == RHS); }
bool operator!=(const SrcState &RHS) const { return !((*this) == RHS); }
};

static void printLastInsts(
Expand All @@ -228,8 +228,8 @@ static void printLastInsts(
}
}

raw_ostream &operator<<(raw_ostream &OS, const State &S) {
OS << "pacret-state<";
raw_ostream &operator<<(raw_ostream &OS, const SrcState &S) {
OS << "src-state<";
if (S.empty()) {
OS << "empty";
} else {
Expand All @@ -240,18 +240,18 @@ raw_ostream &operator<<(raw_ostream &OS, const State &S) {
return OS;
}

class PacStatePrinter {
class SrcStatePrinter {
public:
void print(raw_ostream &OS, const State &State) const;
explicit PacStatePrinter(const BinaryContext &BC) : BC(BC) {}
void print(raw_ostream &OS, const SrcState &State) const;
explicit SrcStatePrinter(const BinaryContext &BC) : BC(BC) {}

private:
const BinaryContext &BC;
};

void PacStatePrinter::print(raw_ostream &OS, const State &S) const {
void SrcStatePrinter::print(raw_ostream &OS, const SrcState &S) const {
RegStatePrinter RegStatePrinter(BC);
OS << "pacret-state<";
OS << "src-state<";
if (S.empty()) {
assert(S.SafeToDerefRegs.empty());
assert(S.LastInstWritingReg.empty());
Expand All @@ -265,71 +265,71 @@ void PacStatePrinter::print(raw_ostream &OS, const State &S) const {
OS << ">";
}

class PacRetAnalysis
: public DataflowAnalysis<PacRetAnalysis, State, /*Backward=*/false,
PacStatePrinter> {
class SrcSafetyAnalysis
: public DataflowAnalysis<SrcSafetyAnalysis, SrcState, /*Backward=*/false,
SrcStatePrinter> {
using Parent =
DataflowAnalysis<PacRetAnalysis, State, false, PacStatePrinter>;
DataflowAnalysis<SrcSafetyAnalysis, SrcState, false, SrcStatePrinter>;
friend Parent;

public:
PacRetAnalysis(BinaryFunction &BF, MCPlusBuilder::AllocatorIdTy AllocId,
const std::vector<MCPhysReg> &RegsToTrackInstsFor)
SrcSafetyAnalysis(BinaryFunction &BF, MCPlusBuilder::AllocatorIdTy AllocId,
const std::vector<MCPhysReg> &RegsToTrackInstsFor)
: Parent(BF, AllocId), NumRegs(BF.getBinaryContext().MRI->getNumRegs()),
RegsToTrackInstsFor(RegsToTrackInstsFor) {}
virtual ~PacRetAnalysis() {}
virtual ~SrcSafetyAnalysis() {}

protected:
const unsigned NumRegs;
/// RegToTrackInstsFor is the set of registers for which the dataflow analysis
/// must compute which the last set of instructions writing to it are.
const TrackedRegisters RegsToTrackInstsFor;

SmallPtrSet<const MCInst *, 4> &lastWritingInsts(State &S,
SmallPtrSet<const MCInst *, 4> &lastWritingInsts(SrcState &S,
MCPhysReg Reg) const {
unsigned Index = RegsToTrackInstsFor.getIndex(Reg);
return S.LastInstWritingReg[Index];
}
const SmallPtrSet<const MCInst *, 4> &lastWritingInsts(const State &S,
const SmallPtrSet<const MCInst *, 4> &lastWritingInsts(const SrcState &S,
MCPhysReg Reg) const {
unsigned Index = RegsToTrackInstsFor.getIndex(Reg);
return S.LastInstWritingReg[Index];
}

void preflight() {}

State createEntryState() {
State S(NumRegs, RegsToTrackInstsFor.getNumTrackedRegisters());
SrcState createEntryState() {
SrcState S(NumRegs, RegsToTrackInstsFor.getNumTrackedRegisters());
for (MCPhysReg Reg : BC.MIB->getTrustedLiveInRegs())
S.SafeToDerefRegs |= BC.MIB->getAliases(Reg, /*OnlySmaller=*/true);
return S;
}

State getStartingStateAtBB(const BinaryBasicBlock &BB) {
SrcState getStartingStateAtBB(const BinaryBasicBlock &BB) {
if (BB.isEntryPoint())
return createEntryState();

return State();
return SrcState();
}

State getStartingStateAtPoint(const MCInst &Point) { return State(); }
SrcState getStartingStateAtPoint(const MCInst &Point) { return SrcState(); }

void doConfluence(State &StateOut, const State &StateIn) {
PacStatePrinter P(BC);
void doConfluence(SrcState &StateOut, const SrcState &StateIn) {
SrcStatePrinter P(BC);
LLVM_DEBUG({
dbgs() << " PacRetAnalysis::Confluence(\n";
dbgs() << " State 1: ";
dbgs() << " SrcSafetyAnalysis::Confluence(\n";
dbgs() << " State 1: ";
P.print(dbgs(), StateOut);
dbgs() << "\n";
dbgs() << " State 2: ";
dbgs() << " State 2: ";
P.print(dbgs(), StateIn);
dbgs() << ")\n";
});

StateOut.merge(StateIn);

LLVM_DEBUG({
dbgs() << " merged state: ";
dbgs() << " merged state: ";
P.print(dbgs(), StateOut);
dbgs() << "\n";
});
Expand All @@ -354,7 +354,7 @@ class PacRetAnalysis
// Returns all registers that can be treated as if they are written by an
// authentication instruction.
SmallVector<MCPhysReg> getRegsMadeSafeToDeref(const MCInst &Point,
const State &Cur) const {
const SrcState &Cur) const {
SmallVector<MCPhysReg> Regs;
const MCPhysReg NoReg = BC.MIB->getNoRegister();

Expand All @@ -378,10 +378,10 @@ class PacRetAnalysis
return Regs;
}

State computeNext(const MCInst &Point, const State &Cur) {
PacStatePrinter P(BC);
SrcState computeNext(const MCInst &Point, const SrcState &Cur) {
SrcStatePrinter P(BC);
LLVM_DEBUG({
dbgs() << " PacRetAnalysis::ComputeNext(";
dbgs() << " SrcSafetyAnalysis::ComputeNext(";
BC.InstPrinter->printInst(&const_cast<MCInst &>(Point), 0, "", *BC.STI,
dbgs());
dbgs() << ", ";
Expand All @@ -395,7 +395,7 @@ class PacRetAnalysis
if (Cur.empty()) {
LLVM_DEBUG(
{ dbgs() << "Skipping computeNext(Point, Cur) as Cur is empty.\n"; });
return State();
return SrcState();
}

// First, compute various properties of the instruction, taking the state
Expand All @@ -406,7 +406,7 @@ class PacRetAnalysis
getRegsMadeSafeToDeref(Point, Cur);

// Then, compute the state after this instruction is executed.
State Next = Cur;
SrcState Next = Cur;

Next.SafeToDerefRegs.reset(Clobbered);
// Keep track of this instruction if it writes to any of the registers we
Expand All @@ -430,15 +430,15 @@ class PacRetAnalysis
}

LLVM_DEBUG({
dbgs() << " .. result: (";
dbgs() << " .. result: (";
P.print(dbgs(), Next);
dbgs() << ")\n";
});

return Next;
}

StringRef getAnnotationName() const { return StringRef("PacRetAnalysis"); }
StringRef getAnnotationName() const { return StringRef("SrcSafetyAnalysis"); }

public:
std::vector<MCInstReference>
Expand All @@ -448,8 +448,8 @@ class PacRetAnalysis
return {};
auto MaybeState = getStateBefore(Inst);
if (!MaybeState)
llvm_unreachable("Expected State to be present");
const State &S = *MaybeState;
llvm_unreachable("Expected state to be present");
const SrcState &S = *MaybeState;
// Due to aliasing registers, multiple registers may have been tracked.
std::set<const MCInst *> LastWritingInsts;
for (MCPhysReg TrackedReg : UsedDirtyRegs) {
Expand All @@ -468,7 +468,7 @@ class PacRetAnalysis

static std::shared_ptr<Report>
shouldReportReturnGadget(const BinaryContext &BC, const MCInstReference &Inst,
const State &S) {
const SrcState &S) {
static const GadgetKind RetKind("non-protected ret found");
if (!BC.MIB->isReturn(Inst))
return nullptr;
Expand Down Expand Up @@ -496,7 +496,7 @@ shouldReportReturnGadget(const BinaryContext &BC, const MCInstReference &Inst,

static std::shared_ptr<Report>
shouldReportCallGadget(const BinaryContext &BC, const MCInstReference &Inst,
const State &S) {
const SrcState &S) {
static const GadgetKind CallKind("non-protected call found");
if (!BC.MIB->isIndirectCall(Inst) && !BC.MIB->isIndirectBranch(Inst))
return nullptr;
Expand Down Expand Up @@ -524,18 +524,19 @@ Analysis::findGadgets(BinaryFunction &BF,
MCPlusBuilder::AllocatorIdTy AllocatorId) {
FunctionAnalysisResult Result;

PacRetAnalysis PRA(BF, AllocatorId, {});
SrcSafetyAnalysis PRA(BF, AllocatorId, {});
LLVM_DEBUG({ dbgs() << "Running src register safety analysis...\n"; });
PRA.run();
LLVM_DEBUG({
dbgs() << " After PacRetAnalysis:\n";
dbgs() << "After src register safety analysis:\n";
BF.dump();
});

BinaryContext &BC = BF.getBinaryContext();
for (BinaryBasicBlock &BB : BF) {
for (int64_t I = 0, E = BB.size(); I < E; ++I) {
MCInstReference Inst(&BB, I);
const State &S = *PRA.getStateBefore(Inst);
const SrcState &S = *PRA.getStateBefore(Inst);

// If non-empty state was never propagated from the entry basic block
// to Inst, assume it to be unreachable and report a warning.
Expand Down Expand Up @@ -570,10 +571,12 @@ void Analysis::computeDetailedInfo(BinaryFunction &BF,
std::vector<MCPhysReg> RegsToTrackVec(RegsToTrack.begin(), RegsToTrack.end());

// Re-compute the analysis with register tracking.
PacRetAnalysis PRWIA(BF, AllocatorId, RegsToTrackVec);
SrcSafetyAnalysis PRWIA(BF, AllocatorId, RegsToTrackVec);
LLVM_DEBUG(
{ dbgs() << "\nRunning detailed src register safety analysis...\n"; });
PRWIA.run();
LLVM_DEBUG({
dbgs() << " After detailed PacRetAnalysis:\n";
dbgs() << "After detailed src register safety analysis:\n";
BF.dump();
});

Expand Down
Loading