Skip to content

Commit 703b479

Browse files
authored
Revert "[AArch64][SME] Split SMECallAttrs out of SMEAttrs" (#138664)
Reverts #137239. This broke implementing SME ABI routines in C/C++ (used for some stubs); see: https://lab.llvm.org/buildbot/#/builders/94/builds/6859
1 parent 488cb24 commit 703b479

File tree

8 files changed

+211
-269
lines changed

8 files changed

+211
-269
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

+38-38
Original file line numberDiff line numberDiff line change
@@ -8636,16 +8636,6 @@ static void analyzeCallOperands(const AArch64TargetLowering &TLI,
86368636
}
86378637
}
86388638

8639-
static SMECallAttrs
8640-
getSMECallAttrs(const Function &Function,
8641-
const TargetLowering::CallLoweringInfo &CLI) {
8642-
if (CLI.CB)
8643-
return SMECallAttrs(*CLI.CB);
8644-
if (auto *ES = dyn_cast<ExternalSymbolSDNode>(CLI.Callee))
8645-
return SMECallAttrs(SMEAttrs(Function), SMEAttrs(ES->getSymbol()));
8646-
return SMECallAttrs(SMEAttrs(Function), SMEAttrs(SMEAttrs::Normal));
8647-
}
8648-
86498639
bool AArch64TargetLowering::isEligibleForTailCallOptimization(
86508640
const CallLoweringInfo &CLI) const {
86518641
CallingConv::ID CalleeCC = CLI.CallConv;
@@ -8664,10 +8654,12 @@ bool AArch64TargetLowering::isEligibleForTailCallOptimization(
86648654

86658655
// SME Streaming functions are not eligible for TCO as they may require
86668656
// the streaming mode or ZA to be restored after returning from the call.
8667-
SMECallAttrs CallAttrs = getSMECallAttrs(CallerF, CLI);
8668-
if (CallAttrs.requiresSMChange() || CallAttrs.requiresLazySave() ||
8669-
CallAttrs.requiresPreservingAllZAState() ||
8670-
CallAttrs.caller().hasStreamingBody())
8657+
SMEAttrs CallerAttrs(MF.getFunction());
8658+
auto CalleeAttrs = CLI.CB ? SMEAttrs(*CLI.CB) : SMEAttrs(SMEAttrs::Normal);
8659+
if (CallerAttrs.requiresSMChange(CalleeAttrs) ||
8660+
CallerAttrs.requiresLazySave(CalleeAttrs) ||
8661+
CallerAttrs.requiresPreservingAllZAState(CalleeAttrs) ||
8662+
CallerAttrs.hasStreamingBody())
86718663
return false;
86728664

86738665
// Functions using the C or Fast calling convention that have an SVE signature
@@ -8959,13 +8951,14 @@ static SDValue emitSMEStateSaveRestore(const AArch64TargetLowering &TLI,
89598951
return TLI.LowerCallTo(CLI).second;
89608952
}
89618953

8962-
static unsigned getSMCondition(const SMECallAttrs &CallAttrs) {
8963-
if (!CallAttrs.caller().hasStreamingCompatibleInterface() ||
8964-
CallAttrs.caller().hasStreamingBody())
8954+
static unsigned getSMCondition(const SMEAttrs &CallerAttrs,
8955+
const SMEAttrs &CalleeAttrs) {
8956+
if (!CallerAttrs.hasStreamingCompatibleInterface() ||
8957+
CallerAttrs.hasStreamingBody())
89658958
return AArch64SME::Always;
8966-
if (CallAttrs.callee().hasNonStreamingInterface())
8959+
if (CalleeAttrs.hasNonStreamingInterface())
89678960
return AArch64SME::IfCallerIsStreaming;
8968-
if (CallAttrs.callee().hasStreamingInterface())
8961+
if (CalleeAttrs.hasStreamingInterface())
89698962
return AArch64SME::IfCallerIsNonStreaming;
89708963

89718964
llvm_unreachable("Unsupported attributes");
@@ -9098,7 +9091,11 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
90989091
}
90999092

91009093
// Determine whether we need any streaming mode changes.
9101-
SMECallAttrs CallAttrs = getSMECallAttrs(MF.getFunction(), CLI);
9094+
SMEAttrs CalleeAttrs, CallerAttrs(MF.getFunction());
9095+
if (CLI.CB)
9096+
CalleeAttrs = SMEAttrs(*CLI.CB);
9097+
else if (auto *ES = dyn_cast<ExternalSymbolSDNode>(CLI.Callee))
9098+
CalleeAttrs = SMEAttrs(ES->getSymbol());
91029099

91039100
auto DescribeCallsite =
91049101
[&](OptimizationRemarkAnalysis &R) -> OptimizationRemarkAnalysis & {
@@ -9113,8 +9110,9 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
91139110
return R;
91149111
};
91159112

9116-
bool RequiresLazySave = CallAttrs.requiresLazySave();
9117-
bool RequiresSaveAllZA = CallAttrs.requiresPreservingAllZAState();
9113+
bool RequiresLazySave = CallerAttrs.requiresLazySave(CalleeAttrs);
9114+
bool RequiresSaveAllZA =
9115+
CallerAttrs.requiresPreservingAllZAState(CalleeAttrs);
91189116
if (RequiresLazySave) {
91199117
const TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
91209118
MachinePointerInfo MPI =
@@ -9142,18 +9140,18 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
91429140
return DescribeCallsite(R) << " sets up a lazy save for ZA";
91439141
});
91449142
} else if (RequiresSaveAllZA) {
9145-
assert(!CallAttrs.callee().hasSharedZAInterface() &&
9143+
assert(!CalleeAttrs.hasSharedZAInterface() &&
91469144
"Cannot share state that may not exist");
91479145
Chain = emitSMEStateSaveRestore(*this, DAG, FuncInfo, DL, Chain,
91489146
/*IsSave=*/true);
91499147
}
91509148

91519149
SDValue PStateSM;
9152-
bool RequiresSMChange = CallAttrs.requiresSMChange();
9150+
bool RequiresSMChange = CallerAttrs.requiresSMChange(CalleeAttrs);
91539151
if (RequiresSMChange) {
9154-
if (CallAttrs.caller().hasStreamingInterfaceOrBody())
9152+
if (CallerAttrs.hasStreamingInterfaceOrBody())
91559153
PStateSM = DAG.getConstant(1, DL, MVT::i64);
9156-
else if (CallAttrs.caller().hasNonStreamingInterface())
9154+
else if (CallerAttrs.hasNonStreamingInterface())
91579155
PStateSM = DAG.getConstant(0, DL, MVT::i64);
91589156
else
91599157
PStateSM = getRuntimePStateSM(DAG, Chain, DL, MVT::i64);
@@ -9170,7 +9168,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
91709168

91719169
SDValue ZTFrameIdx;
91729170
MachineFrameInfo &MFI = MF.getFrameInfo();
9173-
bool ShouldPreserveZT0 = CallAttrs.requiresPreservingZT0();
9171+
bool ShouldPreserveZT0 = CallerAttrs.requiresPreservingZT0(CalleeAttrs);
91749172

91759173
// If the caller has ZT0 state which will not be preserved by the callee,
91769174
// spill ZT0 before the call.
@@ -9186,7 +9184,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
91869184

91879185
// If caller shares ZT0 but the callee is not shared ZA, we need to stop
91889186
// PSTATE.ZA before the call if there is no lazy-save active.
9189-
bool DisableZA = CallAttrs.requiresDisablingZABeforeCall();
9187+
bool DisableZA = CallerAttrs.requiresDisablingZABeforeCall(CalleeAttrs);
91909188
assert((!DisableZA || !RequiresLazySave) &&
91919189
"Lazy-save should have PSTATE.SM=1 on entry to the function");
91929190

@@ -9468,9 +9466,9 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
94689466
InGlue = Chain.getValue(1);
94699467
}
94709468

9471-
SDValue NewChain =
9472-
changeStreamingMode(DAG, DL, CallAttrs.callee().hasStreamingInterface(),
9473-
Chain, InGlue, getSMCondition(CallAttrs), PStateSM);
9469+
SDValue NewChain = changeStreamingMode(
9470+
DAG, DL, CalleeAttrs.hasStreamingInterface(), Chain, InGlue,
9471+
getSMCondition(CallerAttrs, CalleeAttrs), PStateSM);
94749472
Chain = NewChain.getValue(0);
94759473
InGlue = NewChain.getValue(1);
94769474
}
@@ -9649,8 +9647,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
96499647
if (RequiresSMChange) {
96509648
assert(PStateSM && "Expected a PStateSM to be set");
96519649
Result = changeStreamingMode(
9652-
DAG, DL, !CallAttrs.callee().hasStreamingInterface(), Result, InGlue,
9653-
getSMCondition(CallAttrs), PStateSM);
9650+
DAG, DL, !CalleeAttrs.hasStreamingInterface(), Result, InGlue,
9651+
getSMCondition(CallerAttrs, CalleeAttrs), PStateSM);
96549652

96559653
if (!Subtarget->isTargetDarwin() || Subtarget->hasSVE()) {
96569654
InGlue = Result.getValue(1);
@@ -9660,7 +9658,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
96609658
}
96619659
}
96629660

9663-
if (CallAttrs.requiresEnablingZAAfterCall())
9661+
if (CallerAttrs.requiresEnablingZAAfterCall(CalleeAttrs))
96649662
// Unconditionally resume ZA.
96659663
Result = DAG.getNode(
96669664
AArch64ISD::SMSTART, DL, MVT::Other, Result,
@@ -28520,10 +28518,12 @@ bool AArch64TargetLowering::fallBackToDAGISel(const Instruction &Inst) const {
2852028518

2852128519
// Checks to allow the use of SME instructions
2852228520
if (auto *Base = dyn_cast<CallBase>(&Inst)) {
28523-
auto CallAttrs = SMECallAttrs(*Base);
28524-
if (CallAttrs.requiresSMChange() || CallAttrs.requiresLazySave() ||
28525-
CallAttrs.requiresPreservingZT0() ||
28526-
CallAttrs.requiresPreservingAllZAState())
28521+
auto CallerAttrs = SMEAttrs(*Inst.getFunction());
28522+
auto CalleeAttrs = SMEAttrs(*Base);
28523+
if (CallerAttrs.requiresSMChange(CalleeAttrs) ||
28524+
CallerAttrs.requiresLazySave(CalleeAttrs) ||
28525+
CallerAttrs.requiresPreservingZT0(CalleeAttrs) ||
28526+
CallerAttrs.requiresPreservingAllZAState(CalleeAttrs))
2852728527
return true;
2852828528
}
2852928529
return false;

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp

+12-13
Original file line numberDiff line numberDiff line change
@@ -268,21 +268,22 @@ const FeatureBitset AArch64TTIImpl::InlineInverseFeatures = {
268268

269269
bool AArch64TTIImpl::areInlineCompatible(const Function *Caller,
270270
const Function *Callee) const {
271-
SMECallAttrs CallAttrs(*Caller, *Callee);
271+
SMEAttrs CallerAttrs(*Caller), CalleeAttrs(*Callee);
272272

273273
// When inlining, we should consider the body of the function, not the
274274
// interface.
275-
if (CallAttrs.callee().hasStreamingBody()) {
276-
CallAttrs.callee().set(SMEAttrs::SM_Compatible, false);
277-
CallAttrs.callee().set(SMEAttrs::SM_Enabled, true);
275+
if (CalleeAttrs.hasStreamingBody()) {
276+
CalleeAttrs.set(SMEAttrs::SM_Compatible, false);
277+
CalleeAttrs.set(SMEAttrs::SM_Enabled, true);
278278
}
279279

280-
if (CallAttrs.callee().isNewZA() || CallAttrs.callee().isNewZT0())
280+
if (CalleeAttrs.isNewZA() || CalleeAttrs.isNewZT0())
281281
return false;
282282

283-
if (CallAttrs.requiresLazySave() || CallAttrs.requiresSMChange() ||
284-
CallAttrs.requiresPreservingZT0() ||
285-
CallAttrs.requiresPreservingAllZAState()) {
283+
if (CallerAttrs.requiresLazySave(CalleeAttrs) ||
284+
CallerAttrs.requiresSMChange(CalleeAttrs) ||
285+
CallerAttrs.requiresPreservingZT0(CalleeAttrs) ||
286+
CallerAttrs.requiresPreservingAllZAState(CalleeAttrs)) {
286287
if (hasPossibleIncompatibleOps(Callee))
287288
return false;
288289
}
@@ -348,14 +349,12 @@ AArch64TTIImpl::getInlineCallPenalty(const Function *F, const CallBase &Call,
348349
// streaming-mode change, and the call to G from F would also require a
349350
// streaming-mode change, then there is benefit to do the streaming-mode
350351
// change only once and avoid inlining of G into F.
351-
352352
SMEAttrs FAttrs(*F);
353-
SMECallAttrs CallAttrs(Call);
354-
355-
if (SMECallAttrs(FAttrs, CallAttrs.callee()).requiresSMChange()) {
353+
SMEAttrs CalleeAttrs(Call);
354+
if (FAttrs.requiresSMChange(CalleeAttrs)) {
356355
if (F == Call.getCaller()) // (1)
357356
return CallPenaltyChangeSM * DefaultCallPenalty;
358-
if (SMECallAttrs(FAttrs, CallAttrs.caller()).requiresSMChange()) // (2)
357+
if (FAttrs.requiresSMChange(SMEAttrs(*Call.getCaller()))) // (2)
359358
return InlineCallPenaltyChangeSM * DefaultCallPenalty;
360359
}
361360

llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp

+29-35
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,15 @@ void SMEAttrs::set(unsigned M, bool Enable) {
2727
"ZA_New and SME_ABI_Routine are mutually exclusive");
2828

2929
assert(
30-
(isNewZA() + isInZA() + isOutZA() + isInOutZA() + isPreservesZA()) <= 1 &&
30+
(!sharesZA() ||
31+
(isNewZA() ^ isInZA() ^ isInOutZA() ^ isOutZA() ^ isPreservesZA())) &&
3132
"Attributes 'aarch64_new_za', 'aarch64_in_za', 'aarch64_out_za', "
3233
"'aarch64_inout_za' and 'aarch64_preserves_za' are mutually exclusive");
3334

3435
// ZT0 Attrs
3536
assert(
36-
(isNewZT0() + isInZT0() + isOutZT0() + isInOutZT0() + isPreservesZT0()) <=
37-
1 &&
37+
(!sharesZT0() || (isNewZT0() ^ isInZT0() ^ isInOutZT0() ^ isOutZT0() ^
38+
isPreservesZT0())) &&
3839
"Attributes 'aarch64_new_zt0', 'aarch64_in_zt0', 'aarch64_out_zt0', "
3940
"'aarch64_inout_zt0' and 'aarch64_preserves_zt0' are mutually exclusive");
4041

@@ -43,6 +44,27 @@ void SMEAttrs::set(unsigned M, bool Enable) {
4344
"interface");
4445
}
4546

47+
SMEAttrs::SMEAttrs(const CallBase &CB) {
48+
*this = SMEAttrs(CB.getAttributes());
49+
if (auto *F = CB.getCalledFunction()) {
50+
set(SMEAttrs(*F).Bitmask | SMEAttrs(F->getName()).Bitmask);
51+
}
52+
}
53+
54+
SMEAttrs::SMEAttrs(StringRef FuncName) : Bitmask(0) {
55+
if (FuncName == "__arm_tpidr2_save" || FuncName == "__arm_sme_state")
56+
Bitmask |= (SMEAttrs::SM_Compatible | SMEAttrs::SME_ABI_Routine);
57+
if (FuncName == "__arm_tpidr2_restore")
58+
Bitmask |= SMEAttrs::SM_Compatible | encodeZAState(StateValue::In) |
59+
SMEAttrs::SME_ABI_Routine;
60+
if (FuncName == "__arm_sc_memcpy" || FuncName == "__arm_sc_memset" ||
61+
FuncName == "__arm_sc_memmove" || FuncName == "__arm_sc_memchr")
62+
Bitmask |= SMEAttrs::SM_Compatible;
63+
if (FuncName == "__arm_sme_save" || FuncName == "__arm_sme_restore" ||
64+
FuncName == "__arm_sme_state_size")
65+
Bitmask |= SMEAttrs::SM_Compatible | SMEAttrs::SME_ABI_Routine;
66+
}
67+
4668
SMEAttrs::SMEAttrs(const AttributeList &Attrs) {
4769
Bitmask = 0;
4870
if (Attrs.hasFnAttr("aarch64_pstate_sm_enabled"))
@@ -77,45 +99,17 @@ SMEAttrs::SMEAttrs(const AttributeList &Attrs) {
7799
Bitmask |= encodeZT0State(StateValue::New);
78100
}
79101

80-
void SMEAttrs::addKnownFunctionAttrs(StringRef FuncName) {
81-
unsigned KnownAttrs = SMEAttrs::Normal;
82-
if (FuncName == "__arm_tpidr2_save" || FuncName == "__arm_sme_state")
83-
KnownAttrs |= (SMEAttrs::SM_Compatible | SMEAttrs::SME_ABI_Routine);
84-
if (FuncName == "__arm_tpidr2_restore")
85-
KnownAttrs |= SMEAttrs::SM_Compatible | encodeZAState(StateValue::In) |
86-
SMEAttrs::SME_ABI_Routine;
87-
if (FuncName == "__arm_sc_memcpy" || FuncName == "__arm_sc_memset" ||
88-
FuncName == "__arm_sc_memmove" || FuncName == "__arm_sc_memchr")
89-
KnownAttrs |= SMEAttrs::SM_Compatible;
90-
if (FuncName == "__arm_sme_save" || FuncName == "__arm_sme_restore" ||
91-
FuncName == "__arm_sme_state_size")
92-
KnownAttrs |= SMEAttrs::SM_Compatible | SMEAttrs::SME_ABI_Routine;
93-
set(KnownAttrs, /*Enable=*/true);
94-
}
95-
96-
bool SMECallAttrs::requiresSMChange() const {
97-
if (callee().hasStreamingCompatibleInterface())
102+
bool SMEAttrs::requiresSMChange(const SMEAttrs &Callee) const {
103+
if (Callee.hasStreamingCompatibleInterface())
98104
return false;
99105

100106
// Both non-streaming
101-
if (caller().hasNonStreamingInterfaceAndBody() &&
102-
callee().hasNonStreamingInterface())
107+
if (hasNonStreamingInterfaceAndBody() && Callee.hasNonStreamingInterface())
103108
return false;
104109

105110
// Both streaming
106-
if (caller().hasStreamingInterfaceOrBody() &&
107-
callee().hasStreamingInterface())
111+
if (hasStreamingInterfaceOrBody() && Callee.hasStreamingInterface())
108112
return false;
109113

110114
return true;
111115
}
112-
113-
SMECallAttrs::SMECallAttrs(const CallBase &CB)
114-
: CallerFn(*CB.getFunction()), CalledFn(CB.getCalledFunction()),
115-
Callsite(CB.getAttributes()), IsIndirect(CB.isIndirectCall()) {
116-
// FIXME: We probably should not allow SME attributes on direct calls but
117-
// clang duplicates streaming mode attributes at each callsite.
118-
assert((IsIndirect ||
119-
((Callsite.withoutPerCallsiteFlags() | CalledFn) == CalledFn)) &&
120-
"SME attributes at callsite do not match declaration");
121-
}

0 commit comments

Comments
 (0)