@@ -8636,16 +8636,6 @@ static void analyzeCallOperands(const AArch64TargetLowering &TLI,
8636
8636
}
8637
8637
}
8638
8638
8639
- static SMECallAttrs
8640
- getSMECallAttrs(const Function &Function,
8641
- const TargetLowering::CallLoweringInfo &CLI) {
8642
- if (CLI.CB)
8643
- return SMECallAttrs(*CLI.CB);
8644
- if (auto *ES = dyn_cast<ExternalSymbolSDNode>(CLI.Callee))
8645
- return SMECallAttrs(SMEAttrs(Function), SMEAttrs(ES->getSymbol()));
8646
- return SMECallAttrs(SMEAttrs(Function), SMEAttrs(SMEAttrs::Normal));
8647
- }
8648
-
8649
8639
bool AArch64TargetLowering::isEligibleForTailCallOptimization(
8650
8640
const CallLoweringInfo &CLI) const {
8651
8641
CallingConv::ID CalleeCC = CLI.CallConv;
@@ -8664,10 +8654,12 @@ bool AArch64TargetLowering::isEligibleForTailCallOptimization(
8664
8654
8665
8655
// SME Streaming functions are not eligible for TCO as they may require
8666
8656
// the streaming mode or ZA to be restored after returning from the call.
8667
- SMECallAttrs CallAttrs = getSMECallAttrs(CallerF, CLI);
8668
- if (CallAttrs.requiresSMChange() || CallAttrs.requiresLazySave() ||
8669
- CallAttrs.requiresPreservingAllZAState() ||
8670
- CallAttrs.caller().hasStreamingBody())
8657
+ SMEAttrs CallerAttrs(MF.getFunction());
8658
+ auto CalleeAttrs = CLI.CB ? SMEAttrs(*CLI.CB) : SMEAttrs(SMEAttrs::Normal);
8659
+ if (CallerAttrs.requiresSMChange(CalleeAttrs) ||
8660
+ CallerAttrs.requiresLazySave(CalleeAttrs) ||
8661
+ CallerAttrs.requiresPreservingAllZAState(CalleeAttrs) ||
8662
+ CallerAttrs.hasStreamingBody())
8671
8663
return false;
8672
8664
8673
8665
// Functions using the C or Fast calling convention that have an SVE signature
@@ -8959,13 +8951,14 @@ static SDValue emitSMEStateSaveRestore(const AArch64TargetLowering &TLI,
8959
8951
return TLI.LowerCallTo(CLI).second;
8960
8952
}
8961
8953
8962
- static unsigned getSMCondition(const SMECallAttrs &CallAttrs) {
8963
- if (!CallAttrs.caller().hasStreamingCompatibleInterface() ||
8964
- CallAttrs.caller().hasStreamingBody())
8954
+ static unsigned getSMCondition(const SMEAttrs &CallerAttrs,
8955
+ const SMEAttrs &CalleeAttrs) {
8956
+ if (!CallerAttrs.hasStreamingCompatibleInterface() ||
8957
+ CallerAttrs.hasStreamingBody())
8965
8958
return AArch64SME::Always;
8966
- if (CallAttrs.callee() .hasNonStreamingInterface())
8959
+ if (CalleeAttrs .hasNonStreamingInterface())
8967
8960
return AArch64SME::IfCallerIsStreaming;
8968
- if (CallAttrs.callee() .hasStreamingInterface())
8961
+ if (CalleeAttrs .hasStreamingInterface())
8969
8962
return AArch64SME::IfCallerIsNonStreaming;
8970
8963
8971
8964
llvm_unreachable("Unsupported attributes");
@@ -9098,7 +9091,11 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
9098
9091
}
9099
9092
9100
9093
// Determine whether we need any streaming mode changes.
9101
- SMECallAttrs CallAttrs = getSMECallAttrs(MF.getFunction(), CLI);
9094
+ SMEAttrs CalleeAttrs, CallerAttrs(MF.getFunction());
9095
+ if (CLI.CB)
9096
+ CalleeAttrs = SMEAttrs(*CLI.CB);
9097
+ else if (auto *ES = dyn_cast<ExternalSymbolSDNode>(CLI.Callee))
9098
+ CalleeAttrs = SMEAttrs(ES->getSymbol());
9102
9099
9103
9100
auto DescribeCallsite =
9104
9101
[&](OptimizationRemarkAnalysis &R) -> OptimizationRemarkAnalysis & {
@@ -9113,8 +9110,9 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
9113
9110
return R;
9114
9111
};
9115
9112
9116
- bool RequiresLazySave = CallAttrs.requiresLazySave();
9117
- bool RequiresSaveAllZA = CallAttrs.requiresPreservingAllZAState();
9113
+ bool RequiresLazySave = CallerAttrs.requiresLazySave(CalleeAttrs);
9114
+ bool RequiresSaveAllZA =
9115
+ CallerAttrs.requiresPreservingAllZAState(CalleeAttrs);
9118
9116
if (RequiresLazySave) {
9119
9117
const TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
9120
9118
MachinePointerInfo MPI =
@@ -9142,18 +9140,18 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
9142
9140
return DescribeCallsite(R) << " sets up a lazy save for ZA";
9143
9141
});
9144
9142
} else if (RequiresSaveAllZA) {
9145
- assert(!CallAttrs.callee() .hasSharedZAInterface() &&
9143
+ assert(!CalleeAttrs .hasSharedZAInterface() &&
9146
9144
"Cannot share state that may not exist");
9147
9145
Chain = emitSMEStateSaveRestore(*this, DAG, FuncInfo, DL, Chain,
9148
9146
/*IsSave=*/true);
9149
9147
}
9150
9148
9151
9149
SDValue PStateSM;
9152
- bool RequiresSMChange = CallAttrs .requiresSMChange();
9150
+ bool RequiresSMChange = CallerAttrs .requiresSMChange(CalleeAttrs );
9153
9151
if (RequiresSMChange) {
9154
- if (CallAttrs.caller() .hasStreamingInterfaceOrBody())
9152
+ if (CallerAttrs .hasStreamingInterfaceOrBody())
9155
9153
PStateSM = DAG.getConstant(1, DL, MVT::i64);
9156
- else if (CallAttrs.caller() .hasNonStreamingInterface())
9154
+ else if (CallerAttrs .hasNonStreamingInterface())
9157
9155
PStateSM = DAG.getConstant(0, DL, MVT::i64);
9158
9156
else
9159
9157
PStateSM = getRuntimePStateSM(DAG, Chain, DL, MVT::i64);
@@ -9170,7 +9168,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
9170
9168
9171
9169
SDValue ZTFrameIdx;
9172
9170
MachineFrameInfo &MFI = MF.getFrameInfo();
9173
- bool ShouldPreserveZT0 = CallAttrs .requiresPreservingZT0();
9171
+ bool ShouldPreserveZT0 = CallerAttrs .requiresPreservingZT0(CalleeAttrs );
9174
9172
9175
9173
// If the caller has ZT0 state which will not be preserved by the callee,
9176
9174
// spill ZT0 before the call.
@@ -9186,7 +9184,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
9186
9184
9187
9185
// If caller shares ZT0 but the callee is not shared ZA, we need to stop
9188
9186
// PSTATE.ZA before the call if there is no lazy-save active.
9189
- bool DisableZA = CallAttrs .requiresDisablingZABeforeCall();
9187
+ bool DisableZA = CallerAttrs .requiresDisablingZABeforeCall(CalleeAttrs );
9190
9188
assert((!DisableZA || !RequiresLazySave) &&
9191
9189
"Lazy-save should have PSTATE.SM=1 on entry to the function");
9192
9190
@@ -9468,9 +9466,9 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
9468
9466
InGlue = Chain.getValue(1);
9469
9467
}
9470
9468
9471
- SDValue NewChain =
9472
- changeStreamingMode( DAG, DL, CallAttrs.callee(). hasStreamingInterface(),
9473
- Chain, InGlue, getSMCondition(CallAttrs ), PStateSM);
9469
+ SDValue NewChain = changeStreamingMode(
9470
+ DAG, DL, CalleeAttrs. hasStreamingInterface(), Chain, InGlue ,
9471
+ getSMCondition(CallerAttrs, CalleeAttrs ), PStateSM);
9474
9472
Chain = NewChain.getValue(0);
9475
9473
InGlue = NewChain.getValue(1);
9476
9474
}
@@ -9649,8 +9647,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
9649
9647
if (RequiresSMChange) {
9650
9648
assert(PStateSM && "Expected a PStateSM to be set");
9651
9649
Result = changeStreamingMode(
9652
- DAG, DL, !CallAttrs.callee() .hasStreamingInterface(), Result, InGlue,
9653
- getSMCondition(CallAttrs ), PStateSM);
9650
+ DAG, DL, !CalleeAttrs .hasStreamingInterface(), Result, InGlue,
9651
+ getSMCondition(CallerAttrs, CalleeAttrs ), PStateSM);
9654
9652
9655
9653
if (!Subtarget->isTargetDarwin() || Subtarget->hasSVE()) {
9656
9654
InGlue = Result.getValue(1);
@@ -9660,7 +9658,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
9660
9658
}
9661
9659
}
9662
9660
9663
- if (CallAttrs .requiresEnablingZAAfterCall())
9661
+ if (CallerAttrs .requiresEnablingZAAfterCall(CalleeAttrs ))
9664
9662
// Unconditionally resume ZA.
9665
9663
Result = DAG.getNode(
9666
9664
AArch64ISD::SMSTART, DL, MVT::Other, Result,
@@ -28520,10 +28518,12 @@ bool AArch64TargetLowering::fallBackToDAGISel(const Instruction &Inst) const {
28520
28518
28521
28519
// Checks to allow the use of SME instructions
28522
28520
if (auto *Base = dyn_cast<CallBase>(&Inst)) {
28523
- auto CallAttrs = SMECallAttrs(*Base);
28524
- if (CallAttrs.requiresSMChange() || CallAttrs.requiresLazySave() ||
28525
- CallAttrs.requiresPreservingZT0() ||
28526
- CallAttrs.requiresPreservingAllZAState())
28521
+ auto CallerAttrs = SMEAttrs(*Inst.getFunction());
28522
+ auto CalleeAttrs = SMEAttrs(*Base);
28523
+ if (CallerAttrs.requiresSMChange(CalleeAttrs) ||
28524
+ CallerAttrs.requiresLazySave(CalleeAttrs) ||
28525
+ CallerAttrs.requiresPreservingZT0(CalleeAttrs) ||
28526
+ CallerAttrs.requiresPreservingAllZAState(CalleeAttrs))
28527
28527
return true;
28528
28528
}
28529
28529
return false;
0 commit comments