Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
171 changes: 171 additions & 0 deletions llvm/lib/Transforms/Vectorize/VectorCombine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ class VectorCombine {
bool foldShuffleOfSelects(Instruction &I);
bool foldShuffleOfCastops(Instruction &I);
bool foldShuffleOfShuffles(Instruction &I);
bool foldShufflesOfLengthChangingShuffles(Instruction &I);
bool foldShuffleOfIntrinsics(Instruction &I);
bool foldShuffleToIdentity(Instruction &I);
bool foldShuffleFromReductions(Instruction &I);
Expand Down Expand Up @@ -2877,6 +2878,174 @@ bool VectorCombine::foldShuffleOfShuffles(Instruction &I) {
return true;
}

/// Try to convert a chain of length-preserving shuffles that are fed by
/// length-changing shuffles from the same source, e.g. a chain of length 3:
///
/// "shuffle (shuffle (shuffle x, (shuffle y, undef)),
/// (shuffle y, undef)),
// (shuffle y, undef)"
///
/// into a single shuffle fed by a length-changing shuffle:
///
/// "shuffle x, (shuffle y, undef)"
///
/// Such chains arise e.g. from folding extract/insert sequences.
bool VectorCombine::foldShufflesOfLengthChangingShuffles(Instruction &I) {
FixedVectorType *TrunkType = dyn_cast<FixedVectorType>(I.getType());
if (!TrunkType)
return false;

unsigned ChainLength = 0;
SmallVector<int> Mask;
SmallVector<int> YMask;
InstructionCost OldCost = 0;
InstructionCost NewCost = 0;
Value *Trunk = &I;
unsigned NumTrunkElts = TrunkType->getNumElements();
FixedVectorType *YType = nullptr;
Value *Y = nullptr;

for (;;) {
// Match the current trunk against (commutations of) the pattern
// "shuffle trunk', (shuffle y, undef)"
ArrayRef<int> OuterMask;
Value *OuterV0, *OuterV1;
if (ChainLength != 0 && !Trunk->hasOneUse())
break;
if (!match(Trunk, m_Shuffle(m_Value(OuterV0), m_Value(OuterV1),
m_Mask(OuterMask))))
break;
if (OuterV0->getType() != TrunkType) {
// This shuffle is not length-preserving, so it cannot be part of the
// chain.
break;
}

ArrayRef<int> InnerMask0, InnerMask1;
Value *A0, *A1, *B0, *B1;
bool Match0 =
match(OuterV0, m_Shuffle(m_Value(A0), m_Value(B0), m_Mask(InnerMask0)));
bool Match1 =
match(OuterV1, m_Shuffle(m_Value(A1), m_Value(B1), m_Mask(InnerMask1)));
bool Match0Leaf = Match0 && A0->getType() != I.getType();
bool Match1Leaf = Match1 && A1->getType() != I.getType();
if (Match0Leaf == Match1Leaf) {
// Only handle the case of exactly one leaf in each step. The "two leaves"
// case is handled by foldShuffleOfShuffles.
break;
}

SmallVector<int> CommutedOuterMask;
if (Match0Leaf) {
std::swap(OuterV0, OuterV1);
std::swap(InnerMask0, InnerMask1);
std::swap(A0, A1);
std::swap(B0, B1);
llvm::append_range(CommutedOuterMask, OuterMask);
for (int &M : CommutedOuterMask) {
if (M == PoisonMaskElem)
continue;
if (M < (int)NumTrunkElts)
M += NumTrunkElts;
else
M -= NumTrunkElts;
}
OuterMask = CommutedOuterMask;
}
if (!OuterV1->hasOneUse())
break;

if (!isa<UndefValue>(A1)) {
if (!Y)
Y = A1;
else if (Y != A1)
break;
}
if (!isa<UndefValue>(B1)) {
if (!Y)
Y = B1;
else if (Y != B1)
break;
}

InstructionCost LocalOldCost =
TTI.getInstructionCost(cast<User>(Trunk), CostKind) +
TTI.getInstructionCost(cast<User>(OuterV1), CostKind);

// Handle the initial (start of chain) case.
if (!ChainLength) {
YType = cast<FixedVectorType>(A1->getType());
Mask.assign(OuterMask);
YMask.assign(InnerMask1);
OldCost = NewCost = LocalOldCost;
Trunk = OuterV0;
ChainLength++;
continue;
}

// For the non-root case, first attempt to combine masks.
SmallVector<int> NewYMask(YMask);
bool Valid = true;
for (auto [CombinedM, LeafM] : llvm::zip(NewYMask, InnerMask1)) {
if (LeafM == -1 || CombinedM == LeafM)
continue;
if (CombinedM == -1) {
CombinedM = LeafM;
} else {
Valid = false;
break;
}
}
if (!Valid)
break;

SmallVector<int> NewMask;
NewMask.reserve(NumTrunkElts);
for (int M : Mask) {
if (M < 0 || M >= static_cast<int>(NumTrunkElts))
NewMask.push_back(M);
else
NewMask.push_back(OuterMask[M]);
}

// Break the chain if adding this new step complicates the shuffles such
// that it would increase the new cost by more than the old cost of this
// step.
InstructionCost LocalNewCost =
TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, TrunkType,
YType, NewYMask, CostKind) +
TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, TrunkType,
TrunkType, NewMask, CostKind);

if (LocalNewCost >= NewCost && LocalOldCost < LocalNewCost - NewCost)
break;

LLVM_DEBUG({
if (ChainLength == 1) {
dbgs() << "Found chain of shuffles fed by length-changing shuffles: "
<< I << '\n';
}
dbgs() << " next chain link: " << *Trunk << '\n'
<< " old cost: " << (OldCost + LocalOldCost)
<< " new cost: " << LocalNewCost << '\n';
});

Mask = NewMask;
YMask = NewYMask;
OldCost += LocalOldCost;
NewCost = LocalNewCost;
Trunk = OuterV0;
ChainLength++;
}
if (ChainLength <= 1)
return false;

Value *Leaf = Builder.CreateShuffleVector(Y, PoisonValue::get(YType), YMask);
Value *Root = Builder.CreateShuffleVector(Trunk, Leaf, Mask);
replaceValue(I, *Root);
return true;
}

/// Try to convert
/// "shuffle (intrinsic), (intrinsic)" into "intrinsic (shuffle), (shuffle)".
bool VectorCombine::foldShuffleOfIntrinsics(Instruction &I) {
Expand Down Expand Up @@ -4718,6 +4887,8 @@ bool VectorCombine::run() {
return true;
if (foldShuffleOfShuffles(I))
return true;
if (foldShufflesOfLengthChangingShuffles(I))
return true;
if (foldShuffleOfIntrinsics(I))
return true;
if (foldSelectShuffle(I))
Expand Down
Loading
Loading