14
14
#include " bolt/Passes/NonPacProtectedRetAnalysis.h"
15
15
#include " bolt/Core/ParallelUtilities.h"
16
16
#include " bolt/Passes/DataflowAnalysis.h"
17
+ #include " llvm/ADT/STLExtras.h"
17
18
#include " llvm/ADT/SmallSet.h"
18
19
#include " llvm/MC/MCInst.h"
19
20
#include " llvm/Support/Format.h"
@@ -58,6 +59,71 @@ raw_ostream &operator<<(raw_ostream &OS, const MCInstReference &Ref) {
58
59
59
60
namespace NonPacProtectedRetAnalysis {
60
61
62
+ static void traceInst (const BinaryContext &BC, StringRef Label,
63
+ const MCInst &MI) {
64
+ dbgs () << " " << Label << " : " ;
65
+ BC.printInstruction (dbgs (), MI);
66
+ }
67
+
68
+ static void traceReg (const BinaryContext &BC, StringRef Label,
69
+ ErrorOr<MCPhysReg> Reg) {
70
+ dbgs () << " " << Label << " : " ;
71
+ if (Reg.getError ())
72
+ dbgs () << " (error)" ;
73
+ else if (*Reg == BC.MIB ->getNoRegister ())
74
+ dbgs () << " (none)" ;
75
+ else
76
+ dbgs () << BC.MRI ->getName (*Reg);
77
+ dbgs () << " \n " ;
78
+ }
79
+
80
+ static void traceRegMask (const BinaryContext &BC, StringRef Label,
81
+ BitVector Mask) {
82
+ dbgs () << " " << Label << " : " ;
83
+ RegStatePrinter (BC).print (dbgs (), Mask);
84
+ dbgs () << " \n " ;
85
+ }
86
+
87
+ // This class represents mapping from arbitrary physical registers to
88
+ // consecutive array indexes.
89
+ class TrackedRegisters {
90
+ static const uint16_t NoIndex = -1 ;
91
+ const std::vector<MCPhysReg> Registers;
92
+ std::vector<uint16_t > RegToIndexMapping;
93
+
94
+ static size_t getMappingSize (const std::vector<MCPhysReg> &RegsToTrack) {
95
+ if (RegsToTrack.empty ())
96
+ return 0 ;
97
+ return 1 + *llvm::max_element (RegsToTrack);
98
+ }
99
+
100
+ public:
101
+ TrackedRegisters (const std::vector<MCPhysReg> &RegsToTrack)
102
+ : Registers(RegsToTrack),
103
+ RegToIndexMapping (getMappingSize(RegsToTrack), NoIndex) {
104
+ for (unsigned I = 0 ; I < RegsToTrack.size (); ++I)
105
+ RegToIndexMapping[RegsToTrack[I]] = I;
106
+ }
107
+
108
+ const ArrayRef<MCPhysReg> getRegisters () const { return Registers; }
109
+
110
+ size_t getNumTrackedRegisters () const { return Registers.size (); }
111
+
112
+ bool empty () const { return Registers.empty (); }
113
+
114
+ bool isTracked (MCPhysReg Reg) const {
115
+ bool IsTracked = (unsigned )Reg < RegToIndexMapping.size () &&
116
+ RegToIndexMapping[Reg] != NoIndex;
117
+ assert (IsTracked == llvm::is_contained (Registers, Reg));
118
+ return IsTracked;
119
+ }
120
+
121
+ unsigned getIndex (MCPhysReg Reg) const {
122
+ assert (isTracked (Reg) && " Register is not tracked" );
123
+ return RegToIndexMapping[Reg];
124
+ }
125
+ };
126
+
61
127
// The security property that is checked is:
62
128
// When a register is used as the address to jump to in a return instruction,
63
129
// that register must either:
@@ -169,52 +235,34 @@ class PacRetAnalysis
169
235
PacRetAnalysis (BinaryFunction &BF, MCPlusBuilder::AllocatorIdTy AllocId,
170
236
const std::vector<MCPhysReg> &RegsToTrackInstsFor)
171
237
: Parent(BF, AllocId), NumRegs(BF.getBinaryContext().MRI->getNumRegs ()),
172
- RegsToTrackInstsFor(RegsToTrackInstsFor),
173
- TrackingLastInsts(!RegsToTrackInstsFor.empty()),
174
- Reg2StateIdx(RegsToTrackInstsFor.empty()
175
- ? 0
176
- : *llvm::max_element(RegsToTrackInstsFor) + 1,
177
- -1) {
178
- for (unsigned I = 0 ; I < RegsToTrackInstsFor.size (); ++I)
179
- Reg2StateIdx[RegsToTrackInstsFor[I]] = I;
180
- }
238
+ RegsToTrackInstsFor(RegsToTrackInstsFor) {}
181
239
virtual ~PacRetAnalysis () {}
182
240
183
241
protected:
184
242
const unsigned NumRegs;
185
243
// / RegToTrackInstsFor is the set of registers for which the dataflow analysis
186
244
// / must compute which the last set of instructions writing to it are.
187
- const std::vector<MCPhysReg> RegsToTrackInstsFor;
188
- const bool TrackingLastInsts;
189
- // / Reg2StateIdx maps Register to the index in the vector used in State to
190
- // / track which instructions last wrote to this register.
191
- std::vector<uint16_t > Reg2StateIdx;
245
+ const TrackedRegisters RegsToTrackInstsFor;
192
246
193
247
SmallPtrSet<const MCInst *, 4 > &lastWritingInsts (State &S,
194
248
MCPhysReg Reg) const {
195
- assert (Reg < Reg2StateIdx.size ());
196
- assert (isTrackingReg (Reg));
197
- return S.LastInstWritingReg [Reg2StateIdx[Reg]];
249
+ unsigned Index = RegsToTrackInstsFor.getIndex (Reg);
250
+ return S.LastInstWritingReg [Index];
198
251
}
199
252
const SmallPtrSet<const MCInst *, 4 > &lastWritingInsts (const State &S,
200
253
MCPhysReg Reg) const {
201
- assert (Reg < Reg2StateIdx.size ());
202
- assert (isTrackingReg (Reg));
203
- return S.LastInstWritingReg [Reg2StateIdx[Reg]];
204
- }
205
-
206
- bool isTrackingReg (MCPhysReg Reg) const {
207
- return llvm::is_contained (RegsToTrackInstsFor, Reg);
254
+ unsigned Index = RegsToTrackInstsFor.getIndex (Reg);
255
+ return S.LastInstWritingReg [Index];
208
256
}
209
257
210
258
void preflight () {}
211
259
212
260
State getStartingStateAtBB (const BinaryBasicBlock &BB) {
213
- return State (NumRegs, RegsToTrackInstsFor.size ());
261
+ return State (NumRegs, RegsToTrackInstsFor.getNumTrackedRegisters ());
214
262
}
215
263
216
264
State getStartingStateAtPoint (const MCInst &Point ) {
217
- return State (NumRegs, RegsToTrackInstsFor.size ());
265
+ return State (NumRegs, RegsToTrackInstsFor.getNumTrackedRegisters ());
218
266
}
219
267
220
268
void doConfluence (State &StateOut, const State &StateIn) {
@@ -275,7 +323,7 @@ class PacRetAnalysis
275
323
Next.NonAutClobRegs |= Written;
276
324
// Keep track of this instruction if it writes to any of the registers we
277
325
// need to track that for:
278
- for (MCPhysReg Reg : RegsToTrackInstsFor)
326
+ for (MCPhysReg Reg : RegsToTrackInstsFor. getRegisters () )
279
327
if (Written[Reg])
280
328
lastWritingInsts (Next, Reg) = {&Point };
281
329
@@ -287,7 +335,7 @@ class PacRetAnalysis
287
335
// https://github.com/llvm/llvm-project/pull/122304#discussion_r1939515516
288
336
Next.NonAutClobRegs .reset (
289
337
BC.MIB ->getAliases (*AutReg, /* OnlySmaller=*/ true ));
290
- if (TrackingLastInsts && isTrackingReg (*AutReg))
338
+ if (RegsToTrackInstsFor. isTracked (*AutReg))
291
339
lastWritingInsts (Next, *AutReg).clear ();
292
340
}
293
341
@@ -306,7 +354,7 @@ class PacRetAnalysis
306
354
std::vector<MCInstReference>
307
355
getLastClobberingInsts (const MCInst Ret, BinaryFunction &BF,
308
356
const BitVector &UsedDirtyRegs) const {
309
- if (!TrackingLastInsts )
357
+ if (RegsToTrackInstsFor. empty () )
310
358
return {};
311
359
auto MaybeState = getStateAt (Ret);
312
360
if (!MaybeState)
@@ -355,28 +403,18 @@ Analysis::computeDfState(PacRetAnalysis &PRA, BinaryFunction &BF,
355
403
}
356
404
MCPhysReg RetReg = *MaybeRetReg;
357
405
LLVM_DEBUG ({
358
- dbgs () << " Found RET inst: " ;
359
- BC.printInstruction (dbgs (), Inst);
360
- dbgs () << " RetReg: " << BC.MRI ->getName (RetReg)
361
- << " ; authenticatesReg: "
362
- << BC.MIB ->isAuthenticationOfReg (Inst, RetReg) << " \n " ;
406
+ traceInst (BC, " Found RET inst" , Inst);
407
+ traceReg (BC, " RetReg" , RetReg);
408
+ traceReg (BC, " Authenticated reg" , BC.MIB ->getAuthenticatedReg (Inst));
363
409
});
364
410
if (BC.MIB ->isAuthenticationOfReg (Inst, RetReg))
365
411
break ;
366
412
BitVector UsedDirtyRegs = PRA.getStateAt (Inst)->NonAutClobRegs ;
367
- LLVM_DEBUG ({
368
- dbgs () << " NonAutClobRegs at Ret: " ;
369
- RegStatePrinter RSP (BC);
370
- RSP.print (dbgs (), UsedDirtyRegs);
371
- dbgs () << " \n " ;
372
- });
413
+ LLVM_DEBUG (
414
+ { traceRegMask (BC, " NonAutClobRegs at Ret" , UsedDirtyRegs); });
373
415
UsedDirtyRegs &= BC.MIB ->getAliases (RetReg, /* OnlySmaller=*/ true );
374
- LLVM_DEBUG ({
375
- dbgs () << " Intersection with RetReg: " ;
376
- RegStatePrinter RSP (BC);
377
- RSP.print (dbgs (), UsedDirtyRegs);
378
- dbgs () << " \n " ;
379
- });
416
+ LLVM_DEBUG (
417
+ { traceRegMask (BC, " Intersection with RetReg" , UsedDirtyRegs); });
380
418
if (UsedDirtyRegs.any ()) {
381
419
// This return instruction needs to be reported
382
420
Result.Diagnostics .push_back (std::make_shared<Gadget>(
@@ -472,12 +510,6 @@ void Gadget::generateReport(raw_ostream &OS, const BinaryContext &BC) const {
472
510
OS << " " << (I + 1 ) << " . " ;
473
511
BC.printInstruction (OS, InstRef, InstRef.getAddress (), BF);
474
512
};
475
- LLVM_DEBUG ({
476
- dbgs () << " .. OverWritingRetRegInst:\n " ;
477
- for (MCInstReference Ref : OverwritingRetRegInst) {
478
- dbgs () << " " << Ref << " \n " ;
479
- }
480
- });
481
513
if (OverwritingRetRegInst.size () == 1 ) {
482
514
const MCInstReference OverwInst = OverwritingRetRegInst[0 ];
483
515
assert (OverwInst.ParentKind == MCInstReference::BasicBlockParent);
0 commit comments