-
Notifications
You must be signed in to change notification settings - Fork 13.4k
[BOLT] Gadget scanner: factor out utility code #131895
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[BOLT] Gadget scanner: factor out utility code #131895
Conversation
Factor out the code for mapping from physical registers to consecutive array indexes. Introduce helper functions to print instructions and registers to prevent mixing of analysis logic and implementation details of debug output. Removed the debug printing from `Gadget::generateReport`, as it doesn't seem to add important information to what was already printed in the report itself.
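@llvm/pr-subscribers-bolt Author: Anatoly Trosinenko (atrosinenko) Changes: Factor out the code for mapping from physical registers to consecutive array indexes. Introduce helper functions to print instructions and registers to prevent mixing of analysis logic and implementation details of debug output. Removed the debug printing from `Gadget::generateReport`, as it doesn't seem to add important information to what was already printed in the report itself. Full diff: https://github.com/llvm/llvm-project/pull/131895.diff 1 File Affected: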
diff --git a/bolt/lib/Passes/NonPacProtectedRetAnalysis.cpp b/bolt/lib/Passes/NonPacProtectedRetAnalysis.cpp
index dc7cb275f5664..77a16379a14b9 100644
--- a/bolt/lib/Passes/NonPacProtectedRetAnalysis.cpp
+++ b/bolt/lib/Passes/NonPacProtectedRetAnalysis.cpp
@@ -14,6 +14,7 @@
#include "bolt/Passes/NonPacProtectedRetAnalysis.h"
#include "bolt/Core/ParallelUtilities.h"
#include "bolt/Passes/DataflowAnalysis.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/MC/MCInst.h"
#include "llvm/Support/Format.h"
@@ -58,6 +59,71 @@ raw_ostream &operator<<(raw_ostream &OS, const MCInstReference &Ref) {
namespace NonPacProtectedRetAnalysis {
+static void traceInst(const BinaryContext &BC, StringRef Label,
+ const MCInst &MI) {
+ dbgs() << " " << Label << ": ";
+ BC.printInstruction(dbgs(), MI);
+}
+
+static void traceReg(const BinaryContext &BC, StringRef Label,
+ ErrorOr<MCPhysReg> Reg) {
+ dbgs() << " " << Label << ": ";
+ if (Reg.getError())
+ dbgs() << "(error)";
+ else if (*Reg == BC.MIB->getNoRegister())
+ dbgs() << "(none)";
+ else
+ dbgs() << BC.MRI->getName(*Reg);
+ dbgs() << "\n";
+}
+
+static void traceRegMask(const BinaryContext &BC, StringRef Label,
+ BitVector Mask) {
+ dbgs() << " " << Label << ": ";
+ RegStatePrinter(BC).print(dbgs(), Mask);
+ dbgs() << "\n";
+}
+
+// This class represents mapping from arbitrary physical registers to
+// consecutive array indexes.
+class TrackedRegisters {
+ static const uint16_t NoIndex = -1;
+ const std::vector<MCPhysReg> Registers;
+ std::vector<uint16_t> RegToIndexMapping;
+
+ static size_t getMappingSize(const std::vector<MCPhysReg> &RegsToTrack) {
+ if (RegsToTrack.empty())
+ return 0;
+ return 1 + *llvm::max_element(RegsToTrack);
+ }
+
+public:
+ TrackedRegisters(const std::vector<MCPhysReg> &RegsToTrack)
+ : Registers(RegsToTrack),
+ RegToIndexMapping(getMappingSize(RegsToTrack), NoIndex) {
+ for (unsigned I = 0; I < RegsToTrack.size(); ++I)
+ RegToIndexMapping[RegsToTrack[I]] = I;
+ }
+
+ const ArrayRef<MCPhysReg> getRegisters() const { return Registers; }
+
+ size_t getNumTrackedRegisters() const { return Registers.size(); }
+
+ bool empty() const { return Registers.empty(); }
+
+ bool isTracked(MCPhysReg Reg) const {
+ bool IsTracked = (unsigned)Reg < RegToIndexMapping.size() &&
+ RegToIndexMapping[Reg] != NoIndex;
+ assert(IsTracked == llvm::is_contained(Registers, Reg));
+ return IsTracked;
+ }
+
+ unsigned getIndex(MCPhysReg Reg) const {
+ assert(isTracked(Reg) && "Register is not tracked");
+ return RegToIndexMapping[Reg];
+ }
+};
+
// The security property that is checked is:
// When a register is used as the address to jump to in a return instruction,
// that register must either:
@@ -169,52 +235,34 @@ class PacRetAnalysis
PacRetAnalysis(BinaryFunction &BF, MCPlusBuilder::AllocatorIdTy AllocId,
const std::vector<MCPhysReg> &RegsToTrackInstsFor)
: Parent(BF, AllocId), NumRegs(BF.getBinaryContext().MRI->getNumRegs()),
- RegsToTrackInstsFor(RegsToTrackInstsFor),
- TrackingLastInsts(!RegsToTrackInstsFor.empty()),
- Reg2StateIdx(RegsToTrackInstsFor.empty()
- ? 0
- : *llvm::max_element(RegsToTrackInstsFor) + 1,
- -1) {
- for (unsigned I = 0; I < RegsToTrackInstsFor.size(); ++I)
- Reg2StateIdx[RegsToTrackInstsFor[I]] = I;
- }
+ RegsToTrackInstsFor(RegsToTrackInstsFor) {}
virtual ~PacRetAnalysis() {}
protected:
const unsigned NumRegs;
/// RegToTrackInstsFor is the set of registers for which the dataflow analysis
/// must compute which the last set of instructions writing to it are.
- const std::vector<MCPhysReg> RegsToTrackInstsFor;
- const bool TrackingLastInsts;
- /// Reg2StateIdx maps Register to the index in the vector used in State to
- /// track which instructions last wrote to this register.
- std::vector<uint16_t> Reg2StateIdx;
+ const TrackedRegisters RegsToTrackInstsFor;
SmallPtrSet<const MCInst *, 4> &lastWritingInsts(State &S,
MCPhysReg Reg) const {
- assert(Reg < Reg2StateIdx.size());
- assert(isTrackingReg(Reg));
- return S.LastInstWritingReg[Reg2StateIdx[Reg]];
+ unsigned Index = RegsToTrackInstsFor.getIndex(Reg);
+ return S.LastInstWritingReg[Index];
}
const SmallPtrSet<const MCInst *, 4> &lastWritingInsts(const State &S,
MCPhysReg Reg) const {
- assert(Reg < Reg2StateIdx.size());
- assert(isTrackingReg(Reg));
- return S.LastInstWritingReg[Reg2StateIdx[Reg]];
- }
-
- bool isTrackingReg(MCPhysReg Reg) const {
- return llvm::is_contained(RegsToTrackInstsFor, Reg);
+ unsigned Index = RegsToTrackInstsFor.getIndex(Reg);
+ return S.LastInstWritingReg[Index];
}
void preflight() {}
State getStartingStateAtBB(const BinaryBasicBlock &BB) {
- return State(NumRegs, RegsToTrackInstsFor.size());
+ return State(NumRegs, RegsToTrackInstsFor.getNumTrackedRegisters());
}
State getStartingStateAtPoint(const MCInst &Point) {
- return State(NumRegs, RegsToTrackInstsFor.size());
+ return State(NumRegs, RegsToTrackInstsFor.getNumTrackedRegisters());
}
void doConfluence(State &StateOut, const State &StateIn) {
@@ -275,7 +323,7 @@ class PacRetAnalysis
Next.NonAutClobRegs |= Written;
// Keep track of this instruction if it writes to any of the registers we
// need to track that for:
- for (MCPhysReg Reg : RegsToTrackInstsFor)
+ for (MCPhysReg Reg : RegsToTrackInstsFor.getRegisters())
if (Written[Reg])
lastWritingInsts(Next, Reg) = {&Point};
@@ -287,7 +335,7 @@ class PacRetAnalysis
// https://github.com/llvm/llvm-project/pull/122304#discussion_r1939515516
Next.NonAutClobRegs.reset(
BC.MIB->getAliases(*AutReg, /*OnlySmaller=*/true));
- if (TrackingLastInsts && isTrackingReg(*AutReg))
+ if (RegsToTrackInstsFor.isTracked(*AutReg))
lastWritingInsts(Next, *AutReg).clear();
}
@@ -306,7 +354,7 @@ class PacRetAnalysis
std::vector<MCInstReference>
getLastClobberingInsts(const MCInst Ret, BinaryFunction &BF,
const BitVector &UsedDirtyRegs) const {
- if (!TrackingLastInsts)
+ if (RegsToTrackInstsFor.empty())
return {};
auto MaybeState = getStateAt(Ret);
if (!MaybeState)
@@ -355,28 +403,18 @@ Analysis::computeDfState(PacRetAnalysis &PRA, BinaryFunction &BF,
}
MCPhysReg RetReg = *MaybeRetReg;
LLVM_DEBUG({
- dbgs() << " Found RET inst: ";
- BC.printInstruction(dbgs(), Inst);
- dbgs() << " RetReg: " << BC.MRI->getName(RetReg)
- << "; authenticatesReg: "
- << BC.MIB->isAuthenticationOfReg(Inst, RetReg) << "\n";
+ traceInst(BC, "Found RET inst", Inst);
+ traceReg(BC, "RetReg", RetReg);
+ traceReg(BC, "Authenticated reg", BC.MIB->getAuthenticatedReg(Inst));
});
if (BC.MIB->isAuthenticationOfReg(Inst, RetReg))
break;
BitVector UsedDirtyRegs = PRA.getStateAt(Inst)->NonAutClobRegs;
- LLVM_DEBUG({
- dbgs() << " NonAutClobRegs at Ret: ";
- RegStatePrinter RSP(BC);
- RSP.print(dbgs(), UsedDirtyRegs);
- dbgs() << "\n";
- });
+ LLVM_DEBUG(
+ { traceRegMask(BC, "NonAutClobRegs at Ret", UsedDirtyRegs); });
UsedDirtyRegs &= BC.MIB->getAliases(RetReg, /*OnlySmaller=*/true);
- LLVM_DEBUG({
- dbgs() << " Intersection with RetReg: ";
- RegStatePrinter RSP(BC);
- RSP.print(dbgs(), UsedDirtyRegs);
- dbgs() << "\n";
- });
+ LLVM_DEBUG(
+ { traceRegMask(BC, "Intersection with RetReg", UsedDirtyRegs); });
if (UsedDirtyRegs.any()) {
// This return instruction needs to be reported
Result.Diagnostics.push_back(std::make_shared<Gadget>(
@@ -472,12 +510,6 @@ void Gadget::generateReport(raw_ostream &OS, const BinaryContext &BC) const {
OS << " " << (I + 1) << ". ";
BC.printInstruction(OS, InstRef, InstRef.getAddress(), BF);
};
- LLVM_DEBUG({
- dbgs() << " .. OverWritingRetRegInst:\n";
- for (MCInstReference Ref : OverwritingRetRegInst) {
- dbgs() << " " << Ref << "\n";
- }
- });
if (OverwritingRetRegInst.size() == 1) {
const MCInstReference OverwInst = OverwritingRetRegInst[0];
assert(OverwInst.ParentKind == MCInstReference::BasicBlockParent);
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you, this is a very nice cleanup!
Apart from my other minor comments, I'm wondering if it would be useful to have one regression test that checks the debug output.
I think we typically don't do that in LLVM, but when we're doing a mostly NFC change of how debug printing works, ideally, there should be a test to show that we're not changing it (or only changing it in the way we intended)?
I wonder if you have an opinion about adding a small regression test to document/test what the "trace" output is expected to look like?
Added a basic test for debug output - this looks like a useful overview for what gadget scanner does. This helped identify a few issues with debug output in the next patches. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great, thanks!
@atrosinenko I've landed 10624e6 to fix warnings from this PR. Thanks! |
@kazutakahirata Thank you! |
Factor out the code for mapping from physical registers to consecutive
array indexes.
Introduce helper functions to print instructions and registers to
prevent mixing of analysis logic and implementation details of debug
output.
Removed the debug printing from `Gadget::generateReport`, as it doesn't
seem to add important information to what was already printed in the
report itself.