@@ -1241,46 +1241,123 @@ InstructionCost GCNTTIImpl::getShuffleCost(TTI::ShuffleKind Kind,
12411241 (ScalarSize == 16 || ScalarSize == 8 )) {
12421242 // Larger vector widths may require additional instructions, but are
12431243 // typically cheaper than scalarized versions.
1244- unsigned NumVectorElts = cast<FixedVectorType>(SrcTy)->getNumElements ();
1245- unsigned RequestedElts =
1246- count_if (Mask, [](int MaskElt) { return MaskElt != -1 ; });
1244+ //
1245+ // We assume that shuffling at a register granularity can be done for free.
1246+ // This is not true for vectors fed into memory instructions, but it is
1247+ // effectively true for all other shuffling. The emphasis of the logic here
1248+ // is to assist generic transforms in cleaning up / canonicalizing those
1249+ // shuffles.
1250+
1251+ // With op_sel, VOP3P instructions can freely access the low half or high
1252+ // half of a register, so any swizzle of two elements is free.
1253+ if (auto *SrcVecTy = dyn_cast<FixedVectorType>(SrcTy)) {
1254+ unsigned NumSrcElts = SrcVecTy->getNumElements ();
1255+ if (ST->hasVOP3PInsts () && ScalarSize == 16 && NumSrcElts == 2 &&
1256+ (Kind == TTI::SK_Broadcast || Kind == TTI::SK_Reverse ||
1257+ Kind == TTI::SK_PermuteSingleSrc))
1258+ return 0 ;
1259+ }
1260+
12471261 unsigned EltsPerReg = 32 / ScalarSize;
1248- if (RequestedElts == 0 )
1249- return 0 ;
12501262 switch (Kind) {
12511263 case TTI::SK_Broadcast:
1264+ // A single v_perm_b32 can be re-used for all destination registers.
1265+ return 1 ;
12521266 case TTI::SK_Reverse:
1253- case TTI::SK_PermuteSingleSrc: {
1254- // With op_sel VOP3P instructions freely can access the low half or high
1255- // half of a register, so any swizzle of two elements is free.
1256- if (ST->hasVOP3PInsts () && ScalarSize == 16 && NumVectorElts == 2 )
1257- return 0 ;
1258- unsigned NumPerms = alignTo (RequestedElts, EltsPerReg) / EltsPerReg;
1259- // SK_Broadcast just reuses the same mask
1260- unsigned NumPermMasks = Kind == TTI::SK_Broadcast ? 1 : NumPerms;
1261- return NumPerms + NumPermMasks;
1262- }
1267+ // One instruction per register.
1268+ if (auto *DstVecTy = dyn_cast<FixedVectorType>(DstTy))
1269+ return divideCeil (DstVecTy->getNumElements (), EltsPerReg);
1270+ return InstructionCost::getInvalid ();
12631271 case TTI::SK_ExtractSubvector:
1272+ if (Index % EltsPerReg == 0 )
1273+ return 0 ; // Shuffling at register granularity
1274+ if (auto *DstVecTy = dyn_cast<FixedVectorType>(DstTy))
1275+ return divideCeil (DstVecTy->getNumElements (), EltsPerReg);
1276+ return InstructionCost::getInvalid ();
12641277 case TTI::SK_InsertSubvector: {
1265- // Even aligned accesses are free
1266- if (!(Index % 2 ))
1267- return 0 ;
1268- // Insert/extract subvectors only require shifts / extract code to get the
1269- // relevant bits
1270- return alignTo (RequestedElts, EltsPerReg) / EltsPerReg;
1278+ auto *DstVecTy = dyn_cast<FixedVectorType>(DstTy);
1279+ if (!DstVecTy)
1280+ return InstructionCost::getInvalid ();
1281+ unsigned NumDstElts = DstVecTy->getNumElements ();
1282+ unsigned NumInsertElts = cast<FixedVectorType>(SubTp)->getNumElements ();
1283+ unsigned EndIndex = Index + NumInsertElts;
1284+ unsigned BeginSubIdx = Index % EltsPerReg;
1285+ unsigned EndSubIdx = EndIndex % EltsPerReg;
1286+ unsigned Cost = 0 ;
1287+
1288+ if (BeginSubIdx != 0 ) {
1289+ // Need to shift the inserted vector into place. The cost is the number
1290+ // of destination registers overlapped by the inserted vector.
1291+ Cost = divideCeil (EndIndex, EltsPerReg) - (Index / EltsPerReg);
1292+ }
1293+
1294+ // If the last register overlap is partial, there may be three source
1295+ // registers feeding into it; that takes an extra instruction.
1296+ if (EndIndex < NumDstElts && BeginSubIdx < EndSubIdx)
1297+ Cost += 1 ;
1298+
1299+ return Cost;
12711300 }
1272- case TTI::SK_PermuteTwoSrc:
1273- case TTI::SK_Splice:
1274- case TTI::SK_Select: {
1275- unsigned NumPerms = alignTo (RequestedElts, EltsPerReg) / EltsPerReg;
1276- // SK_Select just reuses the same mask
1277- unsigned NumPermMasks = Kind == TTI::SK_Select ? 1 : NumPerms;
1278- return NumPerms + NumPermMasks;
1301+ case TTI::SK_Splice: {
1302+ auto *DstVecTy = dyn_cast<FixedVectorType>(DstTy);
1303+ if (!DstVecTy)
1304+ return InstructionCost::getInvalid ();
1305+ unsigned NumElts = DstVecTy->getNumElements ();
1306+ assert (NumElts == cast<FixedVectorType>(SrcTy)->getNumElements ());
1307+ // Determine the sub-region of the result vector that requires
1308+ // sub-register shuffles / mixing.
1309+ unsigned EltsFromLHS = NumElts - Index;
1310+ bool LHSIsAligned = (Index % EltsPerReg) == 0 ;
1311+ bool RHSIsAligned = (EltsFromLHS % EltsPerReg) == 0 ;
1312+ if (LHSIsAligned && RHSIsAligned)
1313+ return 0 ;
1314+ if (LHSIsAligned && !RHSIsAligned)
1315+ return divideCeil (NumElts, EltsPerReg) - (EltsFromLHS / EltsPerReg);
1316+ if (!LHSIsAligned && RHSIsAligned)
1317+ return divideCeil (EltsFromLHS, EltsPerReg);
1318+ return divideCeil (NumElts, EltsPerReg);
12791319 }
1280-
12811320 default :
12821321 break ;
12831322 }
1323+
1324+ if (!Mask.empty ()) {
1325+ unsigned NumSrcElts = cast<FixedVectorType>(SrcTy)->getNumElements ();
1326+
1327+ // Generically estimate the cost by assuming that each destination
1328+ // register is derived from sources via v_perm_b32 instructions if it
1329+ // can't be copied as-is.
1330+ //
1331+ // For each destination register, derive the cost of obtaining it based
1332+ // on the number of source registers that feed into it.
1333+ unsigned Cost = 0 ;
1334+ for (unsigned DstIdx = 0 ; DstIdx < Mask.size (); DstIdx += EltsPerReg) {
1335+ SmallVector<int , 4 > Regs;
1336+ bool Aligned = true ;
1337+ for (unsigned I = 0 ; I < EltsPerReg && DstIdx + I < Mask.size (); ++I) {
1338+ int SrcIdx = Mask[DstIdx + I];
1339+ if (SrcIdx == -1 )
1340+ continue ;
1341+ int Reg;
1342+ if (SrcIdx < (int )NumSrcElts) {
1343+ Reg = SrcIdx / EltsPerReg;
1344+ if (SrcIdx % EltsPerReg != I)
1345+ Aligned = false ;
1346+ } else {
1347+ Reg = NumSrcElts + (SrcIdx - NumSrcElts) / EltsPerReg;
1348+ if ((SrcIdx - NumSrcElts) % EltsPerReg != I)
1349+ Aligned = false ;
1350+ }
1351+ if (!llvm::is_contained (Regs, Reg))
1352+ Regs.push_back (Reg);
1353+ }
1354+ if (Regs.size () >= 2 )
1355+ Cost += Regs.size () - 1 ;
1356+ else if (!Aligned)
1357+ Cost += 1 ;
1358+ }
1359+ return Cost;
1360+ }
12841361 }
12851362
12861363 return BaseT::getShuffleCost (Kind, DstTy, SrcTy, Mask, CostKind, Index,
0 commit comments