Skip to content

Commit 51373db

Browse files
committed
[BOLT] Gadget scanner: use more appropriate types (NFC)
* use more flexible `const ArrayRef<T>` and `StringRef` types instead of `const std::vector<T> &` and `const std::string &`, correspondingly, for function arguments * return plain `const SrcState &` instead of `ErrorOr<const SrcState &>` from `SrcSafetyAnalysis::getStateBefore`, as absent state is not handled gracefully by any caller
1 parent 0398d2e commit 51373db

File tree

2 files changed

+19
-28
lines changed

2 files changed

+19
-28
lines changed

bolt/include/bolt/Passes/PAuthGadgetScanner.h

+2-6
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
#include "bolt/Core/BinaryContext.h"
1313
#include "bolt/Core/BinaryFunction.h"
1414
#include "bolt/Passes/BinaryPasses.h"
15-
#include "llvm/ADT/SmallSet.h"
1615
#include "llvm/Support/raw_ostream.h"
1716
#include <memory>
1817

@@ -199,9 +198,6 @@ raw_ostream &operator<<(raw_ostream &OS, const MCInstReference &);
199198

200199
namespace PAuthGadgetScanner {
201200

202-
class SrcSafetyAnalysis;
203-
struct SrcState;
204-
205201
/// Description of a gadget kind that can be detected. Intended to be
206202
/// statically allocated to be attached to reports by reference.
207203
class GadgetKind {
@@ -210,7 +206,7 @@ class GadgetKind {
210206
public:
211207
GadgetKind(const char *Description) : Description(Description) {}
212208

213-
const StringRef getDescription() const { return Description; }
209+
StringRef getDescription() const { return Description; }
214210
};
215211

216212
/// Base report located at some instruction, without any additional information.
@@ -261,7 +257,7 @@ struct GadgetReport : public Report {
261257
/// Report with a free-form message attached.
262258
struct GenericReport : public Report {
263259
std::string Text;
264-
GenericReport(MCInstReference Location, const std::string &Text)
260+
GenericReport(MCInstReference Location, StringRef Text)
265261
: Report(Location), Text(Text) {}
266262
virtual void generateReport(raw_ostream &OS,
267263
const BinaryContext &BC) const override;

bolt/lib/Passes/PAuthGadgetScanner.cpp

+17-22
Original file line numberDiff line numberDiff line change
@@ -91,14 +91,14 @@ class TrackedRegisters {
9191
const std::vector<MCPhysReg> Registers;
9292
std::vector<uint16_t> RegToIndexMapping;
9393

94-
static size_t getMappingSize(const std::vector<MCPhysReg> &RegsToTrack) {
94+
static size_t getMappingSize(const ArrayRef<MCPhysReg> RegsToTrack) {
9595
if (RegsToTrack.empty())
9696
return 0;
9797
return 1 + *llvm::max_element(RegsToTrack);
9898
}
9999

100100
public:
101-
TrackedRegisters(const std::vector<MCPhysReg> &RegsToTrack)
101+
TrackedRegisters(const ArrayRef<MCPhysReg> RegsToTrack)
102102
: Registers(RegsToTrack),
103103
RegToIndexMapping(getMappingSize(RegsToTrack), NoIndex) {
104104
for (unsigned I = 0; I < RegsToTrack.size(); ++I)
@@ -234,7 +234,7 @@ struct SrcState {
234234

235235
static void printLastInsts(
236236
raw_ostream &OS,
237-
const std::vector<SmallPtrSet<const MCInst *, 4>> &LastInstWritingReg) {
237+
const ArrayRef<SmallPtrSet<const MCInst *, 4>> LastInstWritingReg) {
238238
OS << "Insts: ";
239239
for (unsigned I = 0; I < LastInstWritingReg.size(); ++I) {
240240
auto &Set = LastInstWritingReg[I];
@@ -295,19 +295,18 @@ void SrcStatePrinter::print(raw_ostream &OS, const SrcState &S) const {
295295
class SrcSafetyAnalysis {
296296
public:
297297
SrcSafetyAnalysis(BinaryFunction &BF,
298-
const std::vector<MCPhysReg> &RegsToTrackInstsFor)
298+
const ArrayRef<MCPhysReg> RegsToTrackInstsFor)
299299
: BC(BF.getBinaryContext()), NumRegs(BC.MRI->getNumRegs()),
300300
RegsToTrackInstsFor(RegsToTrackInstsFor) {}
301301

302302
virtual ~SrcSafetyAnalysis() {}
303303

304304
static std::shared_ptr<SrcSafetyAnalysis>
305305
create(BinaryFunction &BF, MCPlusBuilder::AllocatorIdTy AllocId,
306-
const std::vector<MCPhysReg> &RegsToTrackInstsFor);
306+
const ArrayRef<MCPhysReg> RegsToTrackInstsFor);
307307

308308
virtual void run() = 0;
309-
virtual ErrorOr<const SrcState &>
310-
getStateBefore(const MCInst &Inst) const = 0;
309+
virtual const SrcState &getStateBefore(const MCInst &Inst) const = 0;
311310

312311
protected:
313312
BinaryContext &BC;
@@ -348,7 +347,7 @@ class SrcSafetyAnalysis {
348347
}
349348

350349
BitVector getClobberedRegs(const MCInst &Point) const {
351-
BitVector Clobbered(NumRegs, false);
350+
BitVector Clobbered(NumRegs);
352351
// Assume a call can clobber all registers, including callee-saved
353352
// registers. There's a good chance that callee-saved registers will be
354353
// saved on the stack at some point during execution of the callee.
@@ -409,8 +408,7 @@ class SrcSafetyAnalysis {
409408

410409
// FirstCheckerInst should belong to the same basic block, meaning
411410
// it was deterministically processed a few steps before this instruction.
412-
const SrcState &StateBeforeChecker =
413-
getStateBefore(*FirstCheckerInst).get();
411+
const SrcState &StateBeforeChecker = getStateBefore(*FirstCheckerInst);
414412
if (StateBeforeChecker.SafeToDerefRegs[CheckedReg])
415413
Regs.push_back(CheckedReg);
416414
}
@@ -523,10 +521,7 @@ class SrcSafetyAnalysis {
523521
const ArrayRef<MCPhysReg> UsedDirtyRegs) const {
524522
if (RegsToTrackInstsFor.empty())
525523
return {};
526-
auto MaybeState = getStateBefore(Inst);
527-
if (!MaybeState)
528-
llvm_unreachable("Expected state to be present");
529-
const SrcState &S = *MaybeState;
524+
const SrcState &S = getStateBefore(Inst);
530525
// Due to aliasing registers, multiple registers may have been tracked.
531526
std::set<const MCInst *> LastWritingInsts;
532527
for (MCPhysReg TrackedReg : UsedDirtyRegs) {
@@ -537,7 +532,7 @@ class SrcSafetyAnalysis {
537532
for (const MCInst *Inst : LastWritingInsts) {
538533
MCInstReference Ref = MCInstReference::get(Inst, BF);
539534
assert(Ref && "Expected Inst to be found");
540-
Result.push_back(MCInstReference(Ref));
535+
Result.push_back(Ref);
541536
}
542537
return Result;
543538
}
@@ -557,11 +552,11 @@ class DataflowSrcSafetyAnalysis
557552
public:
558553
DataflowSrcSafetyAnalysis(BinaryFunction &BF,
559554
MCPlusBuilder::AllocatorIdTy AllocId,
560-
const std::vector<MCPhysReg> &RegsToTrackInstsFor)
555+
const ArrayRef<MCPhysReg> RegsToTrackInstsFor)
561556
: SrcSafetyAnalysis(BF, RegsToTrackInstsFor), DFParent(BF, AllocId) {}
562557

563-
ErrorOr<const SrcState &> getStateBefore(const MCInst &Inst) const override {
564-
return DFParent::getStateBefore(Inst);
558+
const SrcState &getStateBefore(const MCInst &Inst) const override {
559+
return DFParent::getStateBefore(Inst).get();
565560
}
566561

567562
void run() override {
@@ -670,7 +665,7 @@ class CFGUnawareSrcSafetyAnalysis : public SrcSafetyAnalysis {
670665
public:
671666
CFGUnawareSrcSafetyAnalysis(BinaryFunction &BF,
672667
MCPlusBuilder::AllocatorIdTy AllocId,
673-
const std::vector<MCPhysReg> &RegsToTrackInstsFor)
668+
const ArrayRef<MCPhysReg> RegsToTrackInstsFor)
674669
: SrcSafetyAnalysis(BF, RegsToTrackInstsFor), BF(BF), AllocId(AllocId) {
675670
StateAnnotationIndex =
676671
BC.MIB->getOrCreateAnnotationIndex("CFGUnawareSrcSafetyAnalysis");
@@ -704,7 +699,7 @@ class CFGUnawareSrcSafetyAnalysis : public SrcSafetyAnalysis {
704699
}
705700
}
706701

707-
ErrorOr<const SrcState &> getStateBefore(const MCInst &Inst) const override {
702+
const SrcState &getStateBefore(const MCInst &Inst) const override {
708703
return BC.MIB->getAnnotationAs<SrcState>(Inst, StateAnnotationIndex);
709704
}
710705

@@ -714,7 +709,7 @@ class CFGUnawareSrcSafetyAnalysis : public SrcSafetyAnalysis {
714709
std::shared_ptr<SrcSafetyAnalysis>
715710
SrcSafetyAnalysis::create(BinaryFunction &BF,
716711
MCPlusBuilder::AllocatorIdTy AllocId,
717-
const std::vector<MCPhysReg> &RegsToTrackInstsFor) {
712+
const ArrayRef<MCPhysReg> RegsToTrackInstsFor) {
718713
if (BF.hasCFG())
719714
return std::make_shared<DataflowSrcSafetyAnalysis>(BF, AllocId,
720715
RegsToTrackInstsFor);
@@ -821,7 +816,7 @@ Analysis::findGadgets(BinaryFunction &BF,
821816

822817
BinaryContext &BC = BF.getBinaryContext();
823818
iterateOverInstrs(BF, [&](MCInstReference Inst) {
824-
const SrcState &S = *Analysis->getStateBefore(Inst);
819+
const SrcState &S = Analysis->getStateBefore(Inst);
825820

826821
// If non-empty state was never propagated from the entry basic block
827822
// to Inst, assume it to be unreachable and report a warning.

0 commit comments

Comments
 (0)