Skip to content

Commit 9761e5d

Browse files
committed
[BOLT] Gadget scanner: factor out utility code
Factor out the code for mapping from physical registers to consecutive array indexes. Introduce helper functions to print instructions and registers to prevent mixing of analysis logic and implementation details of debug output. Removed the debug printing from `Gadget::generateReport`, as it doesn't seem to add important information to what was already printed in the report itself.
1 parent c53caae commit 9761e5d

File tree

1 file changed

+84
-52
lines changed

1 file changed

+84
-52
lines changed

bolt/lib/Passes/NonPacProtectedRetAnalysis.cpp

+84-52
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "bolt/Passes/NonPacProtectedRetAnalysis.h"
1515
#include "bolt/Core/ParallelUtilities.h"
1616
#include "bolt/Passes/DataflowAnalysis.h"
17+
#include "llvm/ADT/STLExtras.h"
1718
#include "llvm/ADT/SmallSet.h"
1819
#include "llvm/MC/MCInst.h"
1920
#include "llvm/Support/Format.h"
@@ -58,6 +59,71 @@ raw_ostream &operator<<(raw_ostream &OS, const MCInstReference &Ref) {
5859

5960
namespace NonPacProtectedRetAnalysis {
6061

62+
static void traceInst(const BinaryContext &BC, StringRef Label,
63+
const MCInst &MI) {
64+
dbgs() << " " << Label << ": ";
65+
BC.printInstruction(dbgs(), MI);
66+
}
67+
68+
static void traceReg(const BinaryContext &BC, StringRef Label,
69+
ErrorOr<MCPhysReg> Reg) {
70+
dbgs() << " " << Label << ": ";
71+
if (Reg.getError())
72+
dbgs() << "(error)";
73+
else if (*Reg == BC.MIB->getNoRegister())
74+
dbgs() << "(none)";
75+
else
76+
dbgs() << BC.MRI->getName(*Reg);
77+
dbgs() << "\n";
78+
}
79+
80+
static void traceRegMask(const BinaryContext &BC, StringRef Label,
81+
BitVector Mask) {
82+
dbgs() << " " << Label << ": ";
83+
RegStatePrinter(BC).print(dbgs(), Mask);
84+
dbgs() << "\n";
85+
}
86+
87+
// This class represents mapping from arbitrary physical registers to
88+
// consecutive array indexes.
89+
class TrackedRegisters {
90+
static const uint16_t NoIndex = -1;
91+
const std::vector<MCPhysReg> Registers;
92+
std::vector<uint16_t> RegToIndexMapping;
93+
94+
static size_t getMappingSize(const std::vector<MCPhysReg> &RegsToTrack) {
95+
if (RegsToTrack.empty())
96+
return 0;
97+
return 1 + *llvm::max_element(RegsToTrack);
98+
}
99+
100+
public:
101+
TrackedRegisters(const std::vector<MCPhysReg> &RegsToTrack)
102+
: Registers(RegsToTrack),
103+
RegToIndexMapping(getMappingSize(RegsToTrack), NoIndex) {
104+
for (unsigned I = 0; I < RegsToTrack.size(); ++I)
105+
RegToIndexMapping[RegsToTrack[I]] = I;
106+
}
107+
108+
const ArrayRef<MCPhysReg> getRegisters() const { return Registers; }
109+
110+
size_t getNumTrackedRegisters() const { return Registers.size(); }
111+
112+
bool empty() const { return Registers.empty(); }
113+
114+
bool isTracked(MCPhysReg Reg) const {
115+
bool IsTracked = (unsigned)Reg < RegToIndexMapping.size() &&
116+
RegToIndexMapping[Reg] != NoIndex;
117+
assert(IsTracked == llvm::is_contained(Registers, Reg));
118+
return IsTracked;
119+
}
120+
121+
unsigned getIndex(MCPhysReg Reg) const {
122+
assert(isTracked(Reg) && "Register is not tracked");
123+
return RegToIndexMapping[Reg];
124+
}
125+
};
126+
61127
// The security property that is checked is:
62128
// When a register is used as the address to jump to in a return instruction,
63129
// that register must either:
@@ -169,52 +235,34 @@ class PacRetAnalysis
169235
PacRetAnalysis(BinaryFunction &BF, MCPlusBuilder::AllocatorIdTy AllocId,
170236
const std::vector<MCPhysReg> &RegsToTrackInstsFor)
171237
: Parent(BF, AllocId), NumRegs(BF.getBinaryContext().MRI->getNumRegs()),
172-
RegsToTrackInstsFor(RegsToTrackInstsFor),
173-
TrackingLastInsts(!RegsToTrackInstsFor.empty()),
174-
Reg2StateIdx(RegsToTrackInstsFor.empty()
175-
? 0
176-
: *llvm::max_element(RegsToTrackInstsFor) + 1,
177-
-1) {
178-
for (unsigned I = 0; I < RegsToTrackInstsFor.size(); ++I)
179-
Reg2StateIdx[RegsToTrackInstsFor[I]] = I;
180-
}
238+
RegsToTrackInstsFor(RegsToTrackInstsFor) {}
181239
virtual ~PacRetAnalysis() {}
182240

183241
protected:
184242
const unsigned NumRegs;
185243
/// RegToTrackInstsFor is the set of registers for which the dataflow analysis
186244
/// must compute which the last set of instructions writing to it are.
187-
const std::vector<MCPhysReg> RegsToTrackInstsFor;
188-
const bool TrackingLastInsts;
189-
/// Reg2StateIdx maps Register to the index in the vector used in State to
190-
/// track which instructions last wrote to this register.
191-
std::vector<uint16_t> Reg2StateIdx;
245+
const TrackedRegisters RegsToTrackInstsFor;
192246

193247
SmallPtrSet<const MCInst *, 4> &lastWritingInsts(State &S,
194248
MCPhysReg Reg) const {
195-
assert(Reg < Reg2StateIdx.size());
196-
assert(isTrackingReg(Reg));
197-
return S.LastInstWritingReg[Reg2StateIdx[Reg]];
249+
unsigned Index = RegsToTrackInstsFor.getIndex(Reg);
250+
return S.LastInstWritingReg[Index];
198251
}
199252
const SmallPtrSet<const MCInst *, 4> &lastWritingInsts(const State &S,
200253
MCPhysReg Reg) const {
201-
assert(Reg < Reg2StateIdx.size());
202-
assert(isTrackingReg(Reg));
203-
return S.LastInstWritingReg[Reg2StateIdx[Reg]];
204-
}
205-
206-
bool isTrackingReg(MCPhysReg Reg) const {
207-
return llvm::is_contained(RegsToTrackInstsFor, Reg);
254+
unsigned Index = RegsToTrackInstsFor.getIndex(Reg);
255+
return S.LastInstWritingReg[Index];
208256
}
209257

210258
void preflight() {}
211259

212260
State getStartingStateAtBB(const BinaryBasicBlock &BB) {
213-
return State(NumRegs, RegsToTrackInstsFor.size());
261+
return State(NumRegs, RegsToTrackInstsFor.getNumTrackedRegisters());
214262
}
215263

216264
State getStartingStateAtPoint(const MCInst &Point) {
217-
return State(NumRegs, RegsToTrackInstsFor.size());
265+
return State(NumRegs, RegsToTrackInstsFor.getNumTrackedRegisters());
218266
}
219267

220268
void doConfluence(State &StateOut, const State &StateIn) {
@@ -275,7 +323,7 @@ class PacRetAnalysis
275323
Next.NonAutClobRegs |= Written;
276324
// Keep track of this instruction if it writes to any of the registers we
277325
// need to track that for:
278-
for (MCPhysReg Reg : RegsToTrackInstsFor)
326+
for (MCPhysReg Reg : RegsToTrackInstsFor.getRegisters())
279327
if (Written[Reg])
280328
lastWritingInsts(Next, Reg) = {&Point};
281329

@@ -287,7 +335,7 @@ class PacRetAnalysis
287335
// https://github.com/llvm/llvm-project/pull/122304#discussion_r1939515516
288336
Next.NonAutClobRegs.reset(
289337
BC.MIB->getAliases(*AutReg, /*OnlySmaller=*/true));
290-
if (TrackingLastInsts && isTrackingReg(*AutReg))
338+
if (RegsToTrackInstsFor.isTracked(*AutReg))
291339
lastWritingInsts(Next, *AutReg).clear();
292340
}
293341

@@ -306,7 +354,7 @@ class PacRetAnalysis
306354
std::vector<MCInstReference>
307355
getLastClobberingInsts(const MCInst Ret, BinaryFunction &BF,
308356
const BitVector &UsedDirtyRegs) const {
309-
if (!TrackingLastInsts)
357+
if (RegsToTrackInstsFor.empty())
310358
return {};
311359
auto MaybeState = getStateAt(Ret);
312360
if (!MaybeState)
@@ -355,28 +403,18 @@ Analysis::computeDfState(PacRetAnalysis &PRA, BinaryFunction &BF,
355403
}
356404
MCPhysReg RetReg = *MaybeRetReg;
357405
LLVM_DEBUG({
358-
dbgs() << " Found RET inst: ";
359-
BC.printInstruction(dbgs(), Inst);
360-
dbgs() << " RetReg: " << BC.MRI->getName(RetReg)
361-
<< "; authenticatesReg: "
362-
<< BC.MIB->isAuthenticationOfReg(Inst, RetReg) << "\n";
406+
traceInst(BC, "Found RET inst", Inst);
407+
traceReg(BC, "RetReg", RetReg);
408+
traceReg(BC, "Authenticated reg", BC.MIB->getAuthenticatedReg(Inst));
363409
});
364410
if (BC.MIB->isAuthenticationOfReg(Inst, RetReg))
365411
break;
366412
BitVector UsedDirtyRegs = PRA.getStateAt(Inst)->NonAutClobRegs;
367-
LLVM_DEBUG({
368-
dbgs() << " NonAutClobRegs at Ret: ";
369-
RegStatePrinter RSP(BC);
370-
RSP.print(dbgs(), UsedDirtyRegs);
371-
dbgs() << "\n";
372-
});
413+
LLVM_DEBUG(
414+
{ traceRegMask(BC, "NonAutClobRegs at Ret", UsedDirtyRegs); });
373415
UsedDirtyRegs &= BC.MIB->getAliases(RetReg, /*OnlySmaller=*/true);
374-
LLVM_DEBUG({
375-
dbgs() << " Intersection with RetReg: ";
376-
RegStatePrinter RSP(BC);
377-
RSP.print(dbgs(), UsedDirtyRegs);
378-
dbgs() << "\n";
379-
});
416+
LLVM_DEBUG(
417+
{ traceRegMask(BC, "Intersection with RetReg", UsedDirtyRegs); });
380418
if (UsedDirtyRegs.any()) {
381419
// This return instruction needs to be reported
382420
Result.Diagnostics.push_back(std::make_shared<Gadget>(
@@ -472,12 +510,6 @@ void Gadget::generateReport(raw_ostream &OS, const BinaryContext &BC) const {
472510
OS << " " << (I + 1) << ". ";
473511
BC.printInstruction(OS, InstRef, InstRef.getAddress(), BF);
474512
};
475-
LLVM_DEBUG({
476-
dbgs() << " .. OverWritingRetRegInst:\n";
477-
for (MCInstReference Ref : OverwritingRetRegInst) {
478-
dbgs() << " " << Ref << "\n";
479-
}
480-
});
481513
if (OverwritingRetRegInst.size() == 1) {
482514
const MCInstReference OverwInst = OverwritingRetRegInst[0];
483515
assert(OverwInst.ParentKind == MCInstReference::BasicBlockParent);

0 commit comments

Comments
 (0)