Skip to content

Commit 0487d1e

Browse files
committed
[BOLT] Gadget scanner: use more appropriate types (NFC)
* use more flexible `const ArrayRef<T>` and `StringRef` types instead of `const std::vector<T> &` and `const std::string &`, correspondingly, for function arguments * return plain `const SrcState &` instead of `ErrorOr<const SrcState &>` from `SrcSafetyAnalysis::getStateBefore`, as absent state is not handled gracefully by any caller
1 parent 48a2836 commit 0487d1e

File tree

2 files changed

+19
-28
lines changed

2 files changed

+19
-28
lines changed

bolt/include/bolt/Passes/PAuthGadgetScanner.h

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
#include "bolt/Core/BinaryContext.h"
1313
#include "bolt/Core/BinaryFunction.h"
1414
#include "bolt/Passes/BinaryPasses.h"
15-
#include "llvm/ADT/SmallSet.h"
1615
#include "llvm/Support/raw_ostream.h"
1716
#include <memory>
1817

@@ -197,9 +196,6 @@ raw_ostream &operator<<(raw_ostream &OS, const MCInstReference &);
197196

198197
namespace PAuthGadgetScanner {
199198

200-
class SrcSafetyAnalysis;
201-
struct SrcState;
202-
203199
/// Description of a gadget kind that can be detected. Intended to be
204200
/// statically allocated to be attached to reports by reference.
205201
class GadgetKind {
@@ -208,7 +204,7 @@ class GadgetKind {
208204
public:
209205
GadgetKind(const char *Description) : Description(Description) {}
210206

211-
const StringRef getDescription() const { return Description; }
207+
StringRef getDescription() const { return Description; }
212208
};
213209

214210
/// Base report located at some instruction, without any additional information.
@@ -259,7 +255,7 @@ struct GadgetReport : public Report {
259255
/// Report with a free-form message attached.
260256
struct GenericReport : public Report {
261257
std::string Text;
262-
GenericReport(MCInstReference Location, const std::string &Text)
258+
GenericReport(MCInstReference Location, StringRef Text)
263259
: Report(Location), Text(Text) {}
264260
virtual void generateReport(raw_ostream &OS,
265261
const BinaryContext &BC) const override;

bolt/lib/Passes/PAuthGadgetScanner.cpp

Lines changed: 17 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -91,14 +91,14 @@ class TrackedRegisters {
9191
const std::vector<MCPhysReg> Registers;
9292
std::vector<uint16_t> RegToIndexMapping;
9393

94-
static size_t getMappingSize(const std::vector<MCPhysReg> &RegsToTrack) {
94+
static size_t getMappingSize(const ArrayRef<MCPhysReg> RegsToTrack) {
9595
if (RegsToTrack.empty())
9696
return 0;
9797
return 1 + *llvm::max_element(RegsToTrack);
9898
}
9999

100100
public:
101-
TrackedRegisters(const std::vector<MCPhysReg> &RegsToTrack)
101+
TrackedRegisters(const ArrayRef<MCPhysReg> RegsToTrack)
102102
: Registers(RegsToTrack),
103103
RegToIndexMapping(getMappingSize(RegsToTrack), NoIndex) {
104104
for (unsigned I = 0; I < RegsToTrack.size(); ++I)
@@ -234,7 +234,7 @@ struct SrcState {
234234

235235
static void printLastInsts(
236236
raw_ostream &OS,
237-
const std::vector<SmallPtrSet<const MCInst *, 4>> &LastInstWritingReg) {
237+
const ArrayRef<SmallPtrSet<const MCInst *, 4>> LastInstWritingReg) {
238238
OS << "Insts: ";
239239
for (unsigned I = 0; I < LastInstWritingReg.size(); ++I) {
240240
auto &Set = LastInstWritingReg[I];
@@ -295,19 +295,18 @@ void SrcStatePrinter::print(raw_ostream &OS, const SrcState &S) const {
295295
class SrcSafetyAnalysis {
296296
public:
297297
SrcSafetyAnalysis(BinaryFunction &BF,
298-
const std::vector<MCPhysReg> &RegsToTrackInstsFor)
298+
const ArrayRef<MCPhysReg> RegsToTrackInstsFor)
299299
: BC(BF.getBinaryContext()), NumRegs(BC.MRI->getNumRegs()),
300300
RegsToTrackInstsFor(RegsToTrackInstsFor) {}
301301

302302
virtual ~SrcSafetyAnalysis() {}
303303

304304
static std::shared_ptr<SrcSafetyAnalysis>
305305
create(BinaryFunction &BF, MCPlusBuilder::AllocatorIdTy AllocId,
306-
const std::vector<MCPhysReg> &RegsToTrackInstsFor);
306+
const ArrayRef<MCPhysReg> RegsToTrackInstsFor);
307307

308308
virtual void run() = 0;
309-
virtual ErrorOr<const SrcState &>
310-
getStateBefore(const MCInst &Inst) const = 0;
309+
virtual const SrcState &getStateBefore(const MCInst &Inst) const = 0;
311310

312311
protected:
313312
BinaryContext &BC;
@@ -347,7 +346,7 @@ class SrcSafetyAnalysis {
347346
}
348347

349348
BitVector getClobberedRegs(const MCInst &Point) const {
350-
BitVector Clobbered(NumRegs, false);
349+
BitVector Clobbered(NumRegs);
351350
// Assume a call can clobber all registers, including callee-saved
352351
// registers. There's a good chance that callee-saved registers will be
353352
// saved on the stack at some point during execution of the callee.
@@ -409,8 +408,7 @@ class SrcSafetyAnalysis {
409408
// FirstCheckerInst should belong to the same basic block (see the
410409
// assertion in DataflowSrcSafetyAnalysis::run()), meaning it was
411410
// deterministically processed a few steps before this instruction.
412-
const SrcState &StateBeforeChecker =
413-
getStateBefore(*FirstCheckerInst).get();
411+
const SrcState &StateBeforeChecker = getStateBefore(*FirstCheckerInst);
414412
if (StateBeforeChecker.SafeToDerefRegs[CheckedReg])
415413
Regs.push_back(CheckedReg);
416414
}
@@ -523,10 +521,7 @@ class SrcSafetyAnalysis {
523521
const ArrayRef<MCPhysReg> UsedDirtyRegs) const {
524522
if (RegsToTrackInstsFor.empty())
525523
return {};
526-
auto MaybeState = getStateBefore(Inst);
527-
if (!MaybeState)
528-
llvm_unreachable("Expected state to be present");
529-
const SrcState &S = *MaybeState;
524+
const SrcState &S = getStateBefore(Inst);
530525
// Due to aliasing registers, multiple registers may have been tracked.
531526
std::set<const MCInst *> LastWritingInsts;
532527
for (MCPhysReg TrackedReg : UsedDirtyRegs) {
@@ -537,7 +532,7 @@ class SrcSafetyAnalysis {
537532
for (const MCInst *Inst : LastWritingInsts) {
538533
MCInstReference Ref = MCInstReference::get(Inst, BF);
539534
assert(Ref && "Expected Inst to be found");
540-
Result.push_back(MCInstReference(Ref));
535+
Result.push_back(Ref);
541536
}
542537
return Result;
543538
}
@@ -557,11 +552,11 @@ class DataflowSrcSafetyAnalysis
557552
public:
558553
DataflowSrcSafetyAnalysis(BinaryFunction &BF,
559554
MCPlusBuilder::AllocatorIdTy AllocId,
560-
const std::vector<MCPhysReg> &RegsToTrackInstsFor)
555+
const ArrayRef<MCPhysReg> RegsToTrackInstsFor)
561556
: SrcSafetyAnalysis(BF, RegsToTrackInstsFor), DFParent(BF, AllocId) {}
562557

563-
ErrorOr<const SrcState &> getStateBefore(const MCInst &Inst) const override {
564-
return DFParent::getStateBefore(Inst);
558+
const SrcState &getStateBefore(const MCInst &Inst) const override {
559+
return DFParent::getStateBefore(Inst).get();
565560
}
566561

567562
void run() override {
@@ -674,7 +669,7 @@ class CFGUnawareSrcSafetyAnalysis : public SrcSafetyAnalysis {
674669
public:
675670
CFGUnawareSrcSafetyAnalysis(BinaryFunction &BF,
676671
MCPlusBuilder::AllocatorIdTy AllocId,
677-
const std::vector<MCPhysReg> &RegsToTrackInstsFor)
672+
const ArrayRef<MCPhysReg> RegsToTrackInstsFor)
678673
: SrcSafetyAnalysis(BF, RegsToTrackInstsFor), BF(BF), AllocId(AllocId) {
679674
StateAnnotationIndex =
680675
BC.MIB->getOrCreateAnnotationIndex("CFGUnawareSrcSafetyAnalysis");
@@ -708,7 +703,7 @@ class CFGUnawareSrcSafetyAnalysis : public SrcSafetyAnalysis {
708703
}
709704
}
710705

711-
ErrorOr<const SrcState &> getStateBefore(const MCInst &Inst) const override {
706+
const SrcState &getStateBefore(const MCInst &Inst) const override {
712707
return BC.MIB->getAnnotationAs<SrcState>(Inst, StateAnnotationIndex);
713708
}
714709

@@ -718,7 +713,7 @@ class CFGUnawareSrcSafetyAnalysis : public SrcSafetyAnalysis {
718713
std::shared_ptr<SrcSafetyAnalysis>
719714
SrcSafetyAnalysis::create(BinaryFunction &BF,
720715
MCPlusBuilder::AllocatorIdTy AllocId,
721-
const std::vector<MCPhysReg> &RegsToTrackInstsFor) {
716+
const ArrayRef<MCPhysReg> RegsToTrackInstsFor) {
722717
if (BF.hasCFG())
723718
return std::make_shared<DataflowSrcSafetyAnalysis>(BF, AllocId,
724719
RegsToTrackInstsFor);
@@ -825,7 +820,7 @@ Analysis::findGadgets(BinaryFunction &BF,
825820

826821
BinaryContext &BC = BF.getBinaryContext();
827822
iterateOverInstrs(BF, [&](MCInstReference Inst) {
828-
const SrcState &S = *Analysis->getStateBefore(Inst);
823+
const SrcState &S = Analysis->getStateBefore(Inst);
829824

830825
// If non-empty state was never propagated from the entry basic block
831826
// to Inst, assume it to be unreachable and report a warning.

0 commit comments

Comments
 (0)