Skip to content

Commit e7d837d

Browse files
Allow optimizing mask conversions on x64 as well (#110195)
* Allow optimizing mask conversions on x64 as well * Ensure the right operand is accessed on xarch * Minimally handle CndSel as part of optimizing mask conversions * Add some additional comments and clean up the logic a bit * Apply formatting patch
1 parent 1db85e8 commit e7d837d

File tree

4 files changed

+149
-46
lines changed

4 files changed

+149
-46
lines changed

src/coreclr/jit/gentree.cpp

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26786,7 +26786,55 @@ bool GenTree::OperIsHWIntrinsic(NamedIntrinsic intrinsicId) const
2678626786
{
2678726787
if (OperIsHWIntrinsic())
2678826788
{
26789-
return AsHWIntrinsic()->GetHWIntrinsicId() == intrinsicId;
26789+
return AsHWIntrinsic()->OperIsHWIntrinsic(intrinsicId);
26790+
}
26791+
return false;
26792+
}
26793+
26794+
//------------------------------------------------------------------------
26795+
// OperIsConvertMaskToVector: Is this a ConvertMaskToVector hwintrinsic
26796+
//
26797+
// Return Value:
26798+
// true if the node is a ConvertMaskToVector hwintrinsic
26799+
// otherwise; false
26800+
//
26801+
bool GenTree::OperIsConvertMaskToVector() const
26802+
{
26803+
if (OperIsHWIntrinsic())
26804+
{
26805+
return AsHWIntrinsic()->OperIsConvertMaskToVector();
26806+
}
26807+
return false;
26808+
}
26809+
26810+
//------------------------------------------------------------------------
26811+
// OperIsConvertVectorToMask: Is this a ConvertVectorToMask hwintrinsic
26812+
//
26813+
// Return Value:
26814+
// true if the node is a ConvertVectorToMask hwintrinsic
26815+
// otherwise; false
26816+
//
26817+
bool GenTree::OperIsConvertVectorToMask() const
26818+
{
26819+
if (OperIsHWIntrinsic())
26820+
{
26821+
return AsHWIntrinsic()->OperIsConvertVectorToMask();
26822+
}
26823+
return false;
26824+
}
26825+
26826+
//------------------------------------------------------------------------
26827+
// OperIsVectorConditionalSelect: Is this a vector ConditionalSelect hwintrinsic
26828+
//
26829+
// Return Value:
26830+
// true if the node is a vector ConditionalSelect hwintrinsic
26831+
// otherwise; false
26832+
//
26833+
bool GenTree::OperIsVectorConditionalSelect() const
26834+
{
26835+
if (OperIsHWIntrinsic())
26836+
{
26837+
return AsHWIntrinsic()->OperIsVectorConditionalSelect();
2679026838
}
2679126839
return false;
2679226840
}
@@ -30678,8 +30726,6 @@ bool GenTree::CanDivOrModPossiblyOverflow(Compiler* comp) const
3067830726
#if defined(FEATURE_HW_INTRINSICS)
3067930727
GenTree* Compiler::gtFoldExprHWIntrinsic(GenTreeHWIntrinsic* tree)
3068030728
{
30681-
assert(tree->OperIsHWIntrinsic());
30682-
3068330729
if (!opts.Tier0OptimizationEnabled())
3068430730
{
3068530731
return tree;

src/coreclr/jit/gentree.h

Lines changed: 42 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1665,32 +1665,9 @@ struct GenTree
16651665
}
16661666

16671667
bool OperIsHWIntrinsic(NamedIntrinsic intrinsicId) const;
1668-
1669-
bool OperIsConvertMaskToVector() const
1670-
{
1671-
#if defined(FEATURE_HW_INTRINSICS)
1672-
#if defined(TARGET_XARCH)
1673-
return OperIsHWIntrinsic(NI_EVEX_ConvertMaskToVector);
1674-
#elif defined(TARGET_ARM64)
1675-
return OperIsHWIntrinsic(NI_Sve_ConvertMaskToVector);
1676-
#endif // !TARGET_XARCH && !TARGET_ARM64
1677-
#else
1678-
return false;
1679-
#endif // FEATURE_HW_INTRINSICS
1680-
}
1681-
1682-
bool OperIsConvertVectorToMask() const
1683-
{
1684-
#if defined(FEATURE_HW_INTRINSICS)
1685-
#if defined(TARGET_XARCH)
1686-
return OperIsHWIntrinsic(NI_EVEX_ConvertVectorToMask);
1687-
#elif defined(TARGET_ARM64)
1688-
return OperIsHWIntrinsic(NI_Sve_ConvertVectorToMask);
1689-
#endif // !TARGET_XARCH && !TARGET_ARM64
1690-
#else
1691-
return false;
1692-
#endif // FEATURE_HW_INTRINSICS
1693-
}
1668+
bool OperIsConvertMaskToVector() const;
1669+
bool OperIsConvertVectorToMask() const;
1670+
bool OperIsVectorConditionalSelect() const;
16941671

16951672
// This is here for cleaner GT_LONG #ifdefs.
16961673
static bool OperIsLong(genTreeOps gtOper)
@@ -6583,6 +6560,45 @@ struct GenTreeHWIntrinsic : public GenTreeJitIntrinsic
65836560
bool OperIsBitwiseHWIntrinsic() const;
65846561
bool OperIsEmbRoundingEnabled() const;
65856562

6563+
bool OperIsHWIntrinsic(NamedIntrinsic intrinsicId) const
6564+
{
6565+
return GetHWIntrinsicId() == intrinsicId;
6566+
}
6567+
6568+
bool OperIsConvertMaskToVector() const
6569+
{
6570+
#if defined(TARGET_XARCH)
6571+
return OperIsHWIntrinsic(NI_EVEX_ConvertMaskToVector);
6572+
#elif defined(TARGET_ARM64)
6573+
return OperIsHWIntrinsic(NI_Sve_ConvertMaskToVector);
6574+
#else
6575+
return false;
6576+
#endif
6577+
}
6578+
6579+
bool OperIsConvertVectorToMask() const
6580+
{
6581+
#if defined(TARGET_XARCH)
6582+
return OperIsHWIntrinsic(NI_EVEX_ConvertVectorToMask);
6583+
#elif defined(TARGET_ARM64)
6584+
return OperIsHWIntrinsic(NI_Sve_ConvertVectorToMask);
6585+
#else
6586+
return false;
6587+
#endif
6588+
}
6589+
6590+
bool OperIsVectorConditionalSelect() const
6591+
{
6592+
#if defined(TARGET_XARCH)
6593+
return OperIsHWIntrinsic(NI_Vector128_ConditionalSelect) || OperIsHWIntrinsic(NI_Vector256_ConditionalSelect) ||
6594+
OperIsHWIntrinsic(NI_Vector512_ConditionalSelect);
6595+
#elif defined(TARGET_ARM64)
6596+
return OperIsHWIntrinsic(NI_AdvSimd_BitwiseSelect) || OperIsHWIntrinsic(NI_Sve_ConditionalSelect);
6597+
#else
6598+
return false;
6599+
#endif
6600+
}
6601+
65866602
bool OperRequiresAsgFlag() const;
65876603
bool OperRequiresCallFlag() const;
65886604
bool OperRequiresGlobRefFlag() const;

src/coreclr/jit/lsrabuild.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2766,12 +2766,12 @@ void LinearScan::buildIntervals()
27662766
{
27672767
calleeSaveCount = CNT_CALLEE_ENREG;
27682768
}
2769-
#if (defined(TARGET_XARCH) || defined(TARGET_ARM64)) && defined(FEATURE_SIMD)
2769+
#if defined(FEATURE_MASKED_HW_INTRINSICS)
27702770
else if (varTypeUsesMaskReg(interval->registerType))
27712771
{
27722772
calleeSaveCount = CNT_CALLEE_SAVED_MASK;
27732773
}
2774-
#endif // (TARGET_XARCH || TARGET_ARM64) && FEATURE_SIMD
2774+
#endif // FEATURE_MASKED_HW_INTRINSICS
27752775
else
27762776
{
27772777
assert(varTypeUsesFloatReg(interval->registerType));

src/coreclr/jit/optimizemaskconversions.cpp

Lines changed: 56 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
#include "jitpch.h"
55

6-
#if defined(TARGET_ARM64)
6+
#if defined(FEATURE_MASKED_HW_INTRINSICS)
77

88
struct MaskConversionsWeight
99
{
@@ -19,8 +19,13 @@ struct MaskConversionsWeight
1919
// Conversion of mask to vector is one instruction.
2020
static constexpr const weight_t costOfConvertMaskToVector = 1.0;
2121

22+
#if defined(TARGET_ARM64)
2223
// Conversion of vector to mask is two instructions.
2324
static constexpr const weight_t costOfConvertVectorToMask = 2.0;
25+
#else
26+
// Conversion of vector to mask is one instructions.
27+
static constexpr const weight_t costOfConvertVectorToMask = 1.0;
28+
#endif
2429

2530
// The simd types of the Lcl Store after conversion to vector.
2631
CorInfoType simdBaseJitType = CORINFO_TYPE_UNDEF;
@@ -136,6 +141,7 @@ class MaskConversionsCheckVisitor final : public GenTreeVisitor<MaskConversionsC
136141
switch ((*use)->OperGet())
137142
{
138143
case GT_STORE_LCL_VAR:
144+
{
139145
isLocalStore = true;
140146

141147
// Look for:
@@ -147,19 +153,48 @@ class MaskConversionsCheckVisitor final : public GenTreeVisitor<MaskConversionsC
147153
hasConversion = true;
148154
}
149155
break;
156+
}
150157

151158
case GT_LCL_VAR:
159+
{
152160
isLocalUse = true;
153161

154162
// Look for:
155-
// user:ConvertVectorToMask(use:LCL_VAR(x)))
163+
// user: ConvertVectorToMask(use:LCL_VAR(x)))
164+
// -or-
165+
// user: ConditionalSelect(use:LCL_VAR(x), y, z)
156166

157-
if (user->OperIsConvertVectorToMask())
167+
if (user->OperIsHWIntrinsic())
158168
{
159-
convertOp = user->AsHWIntrinsic();
160-
hasConversion = true;
169+
GenTreeHWIntrinsic* hwintrin = user->AsHWIntrinsic();
170+
NamedIntrinsic ni = hwintrin->GetHWIntrinsicId();
171+
172+
if (hwintrin->OperIsConvertVectorToMask())
173+
{
174+
convertOp = user->AsHWIntrinsic();
175+
hasConversion = true;
176+
}
177+
else if (hwintrin->OperIsVectorConditionalSelect())
178+
{
179+
// We don't actually have a convert here, but we do have a case where
180+
// the mask is being used in a ConditionalSelect and therefore can be
181+
// consumed directly as a mask. While the IR shows TYP_SIMD, it gets
182+
// handled in lowering as part of the general embedded-mask support.
183+
184+
// We notably don't check that op2->isEmbeddedMaskingCompatibleHWIntrinsic()
185+
// because we can still consume the mask directly in such cases. We'll just
186+
// emit `vblendmps zmm1 {k1}, zmm2, zmm3` instead of containing the CndSel
187+
// as part of something like `vaddps zmm1 {k1}, zmm2, zmm3`
188+
189+
if (hwintrin->Op(1) == (*use))
190+
{
191+
convertOp = user->AsHWIntrinsic();
192+
hasConversion = true;
193+
}
194+
}
161195
}
162196
break;
197+
}
163198

164199
default:
165200
break;
@@ -254,6 +289,12 @@ class MaskConversionsUpdateVisitor final : public GenTreeVisitor<MaskConversions
254289

255290
Compiler::fgWalkResult PostOrderVisit(GenTree** use, GenTree* user)
256291
{
292+
#if defined(TARGET_ARM64)
293+
static constexpr const int ConvertVectorToMaskValueOp = 2;
294+
#else
295+
static constexpr const int ConvertVectorToMaskValueOp = 1;
296+
#endif
297+
257298
GenTreeLclVarCommon* lclOp = nullptr;
258299
bool isLocalStore = false;
259300
bool isLocalUse = false;
@@ -276,11 +317,12 @@ class MaskConversionsUpdateVisitor final : public GenTreeVisitor<MaskConversions
276317
isLocalStore = true;
277318
addConversion = true;
278319
}
279-
else if ((*use)->OperIsConvertVectorToMask() && (*use)->AsHWIntrinsic()->Op(2)->OperIs(GT_LCL_VAR))
320+
else if ((*use)->OperIsConvertVectorToMask() &&
321+
(*use)->AsHWIntrinsic()->Op(ConvertVectorToMaskValueOp)->OperIs(GT_LCL_VAR))
280322
{
281323
// Found
282324
// user(use:ConvertVectorToMask(LCL_VAR(x)))
283-
lclOp = (*use)->AsHWIntrinsic()->Op(2)->AsLclVarCommon();
325+
lclOp = (*use)->AsHWIntrinsic()->Op(ConvertVectorToMaskValueOp)->AsLclVarCommon();
284326
isLocalUse = true;
285327
removeConversion = true;
286328
}
@@ -393,7 +435,7 @@ class MaskConversionsUpdateVisitor final : public GenTreeVisitor<MaskConversions
393435
MaskConversionsWeightTable* weightsTable;
394436
};
395437

396-
#endif // TARGET_ARM64
438+
#endif // FEATURE_MASKED_HW_INTRINSICS
397439

398440
//------------------------------------------------------------------------
399441
// fgOptimizeMaskConversions: Allow locals to be of Mask type
@@ -445,7 +487,7 @@ class MaskConversionsUpdateVisitor final : public GenTreeVisitor<MaskConversions
445487
//
446488
PhaseStatus Compiler::fgOptimizeMaskConversions()
447489
{
448-
#if defined(TARGET_ARM64)
490+
#if defined(FEATURE_MASKED_HW_INTRINSICS)
449491

450492
if (opts.OptimizationDisabled())
451493
{
@@ -476,10 +518,10 @@ PhaseStatus Compiler::fgOptimizeMaskConversions()
476518
{
477519
for (Statement* const stmt : block->Statements())
478520
{
479-
// Only check statements where there is a local of type TYP_SIMD16/TYP_MASK.
521+
// Only check statements where there is a local of type TYP_SIMD/TYP_MASK.
480522
for (GenTreeLclVarCommon* lcl : stmt->LocalsTreeList())
481523
{
482-
if (lcl->TypeIs(TYP_SIMD16, TYP_MASK))
524+
if (varTypeIsSIMDOrMask(lcl))
483525
{
484526
// Parse the entire statement.
485527
MaskConversionsCheckVisitor ev(this, block->getBBWeight(this), &weightsTable);
@@ -504,10 +546,10 @@ PhaseStatus Compiler::fgOptimizeMaskConversions()
504546
{
505547
for (Statement* const stmt : block->Statements())
506548
{
507-
// Only check statements where there is a local of type TYP_SIMD16/TYP_MASK.
549+
// Only check statements where there is a local of type TYP_SIMD/TYP_MASK.
508550
for (GenTreeLclVarCommon* lcl : stmt->LocalsTreeList())
509551
{
510-
if (lcl->TypeIs(TYP_SIMD16, TYP_MASK))
552+
if (varTypeIsSIMDOrMask(lcl))
511553
{
512554
// Parse the entire statement.
513555
MaskConversionsUpdateVisitor ev(this, stmt, &weightsTable);
@@ -524,8 +566,7 @@ PhaseStatus Compiler::fgOptimizeMaskConversions()
524566
}
525567

526568
return PhaseStatus::MODIFIED_EVERYTHING;
527-
528569
#else
529570
return PhaseStatus::MODIFIED_NOTHING;
530-
#endif // TARGET_ARM64
571+
#endif // FEATURE_MASKED_HW_INTRINSICS
531572
}

0 commit comments

Comments
 (0)