Skip to content

Commit 9aa3fcd

Browse files
Sair Teamcopybara-github
Sair Team
authored andcommitted
Use llvm::cast/dyn_cast/isa since alternatives are deprecated in llvm/llvm-project#135556
PiperOrigin-RevId: 749049145
1 parent fb5b686 commit 9aa3fcd

11 files changed

+79
-75
lines changed

mapped_domain.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ mlir::LogicalResult MappedDomain::ResolveUnification(
6060
mlir::Operation *dim_op = dimension.value.defining_op().GetDuplicatedOp();
6161
if (isa<SairPlaceholderOp>(dim_op)) return mlir::success();
6262

63-
if (constraint.isa<MappingNoneExpr, MappingUnknownExpr>()) {
63+
if (llvm::isa<MappingNoneExpr, MappingUnknownExpr>(constraint)) {
6464
// If the dimension is new, extend the domain.
6565
constraint = MappingDimExpr::get(domain_.size(), context());
6666
assert(dimension.mapping.IsSurjective());

sair_attributes.cc

+45-37
Original file line numberDiff line numberDiff line change
@@ -74,15 +74,15 @@ llvm::SmallBitVector MappingExpr::DependencyMask(int domain_size) const {
7474
bool MappingExpr::HasNoneExprs() const {
7575
bool has_none_exprs = false;
7676
Walk([&](MappingExpr sub_expr) {
77-
has_none_exprs |= sub_expr.isa<MappingNoneExpr>();
77+
has_none_exprs |= llvm::isa<MappingNoneExpr>(sub_expr);
7878
});
7979
return has_none_exprs;
8080
}
8181

8282
bool MappingExpr::HasUnknownExprs() const {
8383
bool has_unknown_exprs = false;
8484
Walk([&](MappingExpr sub_expr) {
85-
has_unknown_exprs |= sub_expr.isa<MappingUnknownExpr>();
85+
has_unknown_exprs |= llvm::isa<MappingUnknownExpr>(sub_expr);
8686
});
8787
return has_unknown_exprs;
8888
}
@@ -109,10 +109,10 @@ int MappingExpr::MinDomainSize() const {
109109
// expression is `?` or `none`. Returns `nullptr` if unification fails.
110110
static MappingExpr ResolveNoneAndUnknownUnification(MappingExpr lhs,
111111
MappingExpr rhs) {
112-
if (lhs.isa<MappingNoneExpr>()) return rhs;
113-
if (rhs.isa<MappingNoneExpr>()) return lhs;
114-
if (lhs.isa<MappingUnknownExpr>()) return rhs;
115-
if (rhs.isa<MappingUnknownExpr>()) return lhs;
112+
if (llvm::isa<MappingNoneExpr>(lhs)) return rhs;
113+
if (llvm::isa<MappingNoneExpr>(rhs)) return lhs;
114+
if (llvm::isa<MappingUnknownExpr>(lhs)) return rhs;
115+
if (llvm::isa<MappingUnknownExpr>(rhs)) return lhs;
116116
return MappingExpr();
117117
}
118118

@@ -383,7 +383,7 @@ mlir::LogicalResult MappingStripeExpr::SetInverse(
383383
MappingExpr MappingStripeExpr::FindInInverse(
384384
llvm::ArrayRef<MappingExpr> inverse) const {
385385
auto operand_inverse = operand().FindInInverse(inverse);
386-
if (operand_inverse.isa<MappingUnknownExpr, MappingNoneExpr>()) {
386+
if (llvm::isa<MappingUnknownExpr, MappingNoneExpr>(operand_inverse)) {
387387
return operand_inverse;
388388
}
389389
auto unstripe_expr = llvm::cast<MappingUnStripeExpr>(operand_inverse);
@@ -545,7 +545,7 @@ MappingExpr MappingUnStripeExpr::Unify(
545545

546546
// If the last operand is `none` or `?`, we can replace it by an arbitrary
547547
// number of operands.
548-
if (min_operands.back().isa<MappingNoneExpr, MappingUnknownExpr>()) {
548+
if (llvm::isa<MappingNoneExpr, MappingUnknownExpr>(min_operands.back())) {
549549
min_operands = min_operands.drop_back();
550550
min_factors = min_factors.drop_back();
551551
}
@@ -568,7 +568,7 @@ MappingExpr MappingUnStripeExpr::FindInInverse(
568568
MappingExpr operand_inverse;
569569
for (int i = 0, e = operands().size(); i < e; ++i) {
570570
operand_inverse = operands()[i].FindInInverse(inverse);
571-
if (operand_inverse.isa<MappingUnknownExpr, MappingNoneExpr>()) continue;
571+
if (llvm::isa<MappingUnknownExpr, MappingNoneExpr>(operand_inverse)) continue;
572572
return llvm::cast<MappingStripeExpr>(operand_inverse).operand();
573573
}
574574
// Unstripe has at least one operand.
@@ -797,7 +797,7 @@ MappingAttr MappingAttr::MakeSurjective() const {
797797
new_exprs.reserve(size());
798798
for (MappingExpr expr : Dimensions()) {
799799
MappingExpr new_expr = expr.Map([&](MappingExpr sub_expr) -> MappingExpr {
800-
if (!sub_expr.isa<MappingNoneExpr>()) return sub_expr;
800+
if (!llvm::isa<MappingNoneExpr>(sub_expr)) return sub_expr;
801801
return MappingDimExpr::get(num_dimensions++, getContext());
802802
});
803803
new_exprs.push_back(new_expr);
@@ -810,7 +810,7 @@ MappingAttr MappingAttr::MakeFullySpecified() const {
810810
auto new_exprs =
811811
llvm::to_vector<4>(llvm::map_range(Dimensions(), [&](auto expr) {
812812
return expr.Map([&](MappingExpr sub_expr) -> MappingExpr {
813-
return sub_expr.isa<MappingUnknownExpr>() ? none : sub_expr;
813+
return llvm::isa<MappingUnknownExpr>(sub_expr) ? none : sub_expr;
814814
});
815815
}));
816816
return MappingAttr::get(getContext(), UseDomainSize(), new_exprs);
@@ -946,8 +946,8 @@ MappingAttr MappingAttr::UnifyUnknownExprs(MappingAttr other) const {
946946
for (auto [lhs, rhs] : llvm::zip(Dimensions(), other.Dimensions())) {
947947
MappingExpr unified =
948948
lhs.Unify(rhs, [](MappingExpr sub_lhs, MappingExpr sub_rhs) {
949-
if (sub_lhs.isa<MappingUnknownExpr>()) return sub_rhs;
950-
if (sub_rhs.isa<MappingUnknownExpr>()) return sub_lhs;
949+
if (llvm::isa<MappingUnknownExpr>(sub_lhs)) return sub_rhs;
950+
if (llvm::isa<MappingUnknownExpr>(sub_rhs)) return sub_lhs;
951951
return MappingExpr();
952952
});
953953
if (unified == nullptr) return nullptr;
@@ -1236,7 +1236,7 @@ static DomainShapeDim StripeAccessedShape(MappingStripeExpr expr,
12361236
static DomainShapeDim UnStripeAccessedShape(MappingUnStripeExpr expr,
12371237
DomainShapeDim inner_shape,
12381238
MappingAttr inverted_mapping) {
1239-
if (inner_shape.type().isa<DynRangeType>()) return inner_shape;
1239+
if (llvm::isa<DynRangeType>(inner_shape.type())) return inner_shape;
12401240
auto type = llvm::cast<StaticRangeType>(inner_shape.type());
12411241
int new_step = type.getStep() / expr.factors().front();
12421242
return DomainShapeDim(
@@ -1460,10 +1460,10 @@ bool LoopAttr::classof(mlir::Attribute attr) {
14601460
if (!derived) return false;
14611461

14621462
auto name = derived.get("name");
1463-
if (!name.isa_and_nonnull<mlir::StringAttr>()) return false;
1463+
if (!llvm::isa_and_nonnull<mlir::StringAttr>(name)) return false;
14641464

14651465
auto iter = derived.get("iter");
1466-
if (!iter.isa_and_nonnull<sair::MappingExpr>()) return false;
1466+
if (!llvm::isa_and_nonnull<sair::MappingExpr>(iter)) return false;
14671467

14681468
auto unroll = derived.get("unroll");
14691469
if (!unroll) return derived.size() == 2;
@@ -1481,23 +1481,25 @@ mlir::StringAttr LoopAttr::name() const {
14811481
auto derived = llvm::cast<mlir::DictionaryAttr>(*this);
14821482
auto name = derived.get("name");
14831483
assert(name && "attribute not found.");
1484-
assert(name.isa<mlir::StringAttr>() && "incorrect Attribute type found.");
1484+
assert(llvm::isa<mlir::StringAttr>(name) &&
1485+
"incorrect Attribute type found.");
14851486
return llvm::cast<mlir::StringAttr>(name);
14861487
}
14871488

14881489
MappingExpr LoopAttr::iter() const {
14891490
auto derived = llvm::cast<mlir::DictionaryAttr>(*this);
14901491
auto iter = derived.get("iter");
14911492
assert(iter && "attribute not found.");
1492-
assert(iter.isa<MappingExpr>() && "incorrect Attribute type found.");
1493+
assert(llvm::isa<MappingExpr>(iter) && "incorrect Attribute type found.");
14931494
return llvm::cast<MappingExpr>(iter);
14941495
}
14951496

14961497
mlir::IntegerAttr LoopAttr::unroll() const {
14971498
auto derived = llvm::cast<mlir::DictionaryAttr>(*this);
14981499
auto unroll = derived.get("unroll");
14991500
if (!unroll) return nullptr;
1500-
assert(unroll.isa<mlir::IntegerAttr>() && "incorrect Attribute type found.");
1501+
assert(llvm::isa<mlir::IntegerAttr>(unroll) &&
1502+
"incorrect Attribute type found.");
15011503
return llvm::cast<mlir::IntegerAttr>(unroll);
15021504
}
15031505

@@ -1531,19 +1533,19 @@ bool BufferAttr::classof(mlir::Attribute attr) {
15311533
int num_absent_attrs = 0;
15321534

15331535
auto space = derived.get("space");
1534-
if (!space.isa_and_nonnull<mlir::StringAttr>()) return false;
1536+
if (!llvm::isa_and_nonnull<mlir::StringAttr>(space)) return false;
15351537

15361538
auto name = derived.get("name");
15371539
if (!name) {
15381540
++num_absent_attrs;
1539-
} else if (!name.isa<mlir::StringAttr>()) {
1541+
} else if (!llvm::isa<mlir::StringAttr>(name)) {
15401542
return false;
15411543
}
15421544

15431545
auto layout = derived.get("layout");
15441546
if (!layout) {
15451547
++num_absent_attrs;
1546-
} else if (!layout.isa<NamedMappingAttr>()) {
1548+
} else if (!llvm::isa<NamedMappingAttr>(layout)) {
15471549
return false;
15481550
}
15491551

@@ -1554,23 +1556,25 @@ mlir::StringAttr BufferAttr::space() const {
15541556
auto derived = llvm::cast<mlir::DictionaryAttr>(*this);
15551557
auto space = derived.get("space");
15561558
assert(space && "attribute not found.");
1557-
assert(space.isa<mlir::StringAttr>() && "incorrect Attribute type found.");
1559+
assert(llvm::isa<mlir::StringAttr>(space) && "incorrect Attribute type found.");
15581560
return llvm::cast<mlir::StringAttr>(space);
15591561
}
15601562

15611563
mlir::StringAttr BufferAttr::name() const {
15621564
auto derived = llvm::cast<mlir::DictionaryAttr>(*this);
15631565
auto name = derived.get("name");
15641566
if (!name) return nullptr;
1565-
assert(name.isa<mlir::StringAttr>() && "incorrect Attribute type found.");
1567+
assert(llvm::isa<mlir::StringAttr>(name) &&
1568+
"incorrect Attribute type found.");
15661569
return llvm::cast<mlir::StringAttr>(name);
15671570
}
15681571

15691572
NamedMappingAttr BufferAttr::layout() const {
15701573
auto derived = llvm::cast<mlir::DictionaryAttr>(*this);
15711574
auto layout = derived.get("layout");
15721575
if (!layout) return nullptr;
1573-
assert(layout.isa<NamedMappingAttr>() && "incorrect Attribute type found.");
1576+
assert(llvm::isa<NamedMappingAttr>(layout) &&
1577+
"incorrect Attribute type found.");
15741578
return llvm::cast<NamedMappingAttr>(layout);
15751579
}
15761580

@@ -1640,7 +1644,7 @@ bool DecisionsAttr::classof(mlir::Attribute attr) {
16401644
auto loop_nest_attr = llvm::dyn_cast<mlir::ArrayAttr>(loop_nest);
16411645
if (!loop_nest_attr) return false;
16421646
if (llvm::any_of(loop_nest_attr, [](mlir::Attribute attr) {
1643-
return !attr.isa_and_nonnull<LoopAttr>();
1647+
return !llvm::isa_and_nonnull<LoopAttr>(attr);
16441648
})) {
16451649
return false;
16461650
}
@@ -1649,21 +1653,21 @@ bool DecisionsAttr::classof(mlir::Attribute attr) {
16491653
auto storage = derived.get("storage");
16501654
if (!storage) {
16511655
++num_absent_attrs;
1652-
} else if (!storage.isa<mlir::ArrayAttr>()) {
1656+
} else if (!llvm::isa<mlir::ArrayAttr>(storage)) {
16531657
return false;
16541658
}
16551659

16561660
auto expansion = derived.get("expansion");
16571661
if (!expansion) {
16581662
++num_absent_attrs;
1659-
} else if (!expansion.isa<mlir::StringAttr>()) {
1663+
} else if (!llvm::isa<mlir::StringAttr>(expansion)) {
16601664
return false;
16611665
}
16621666

16631667
auto copy_of = derived.get("copy_of");
16641668
if (!copy_of) {
16651669
++num_absent_attrs;
1666-
} else if (!copy_of.isa<CopyAttr, InstanceAttr, mlir::UnitAttr>()) {
1670+
} else if (!llvm::isa<CopyAttr, InstanceAttr, mlir::UnitAttr>(copy_of)) {
16671671
return false;
16681672
}
16691673

@@ -1673,8 +1677,8 @@ bool DecisionsAttr::classof(mlir::Attribute attr) {
16731677
} else {
16741678
auto operands_attr = llvm::dyn_cast<mlir::ArrayAttr>(operands);
16751679
if (llvm::any_of(operands_attr, [](mlir::Attribute attr) {
1676-
return !attr.isa_and_nonnull<CopyAttr, InstanceAttr,
1677-
mlir::UnitAttr>();
1680+
return !llvm::isa_and_nonnull<CopyAttr, InstanceAttr, mlir::UnitAttr>(
1681+
attr);
16781682
})) {
16791683
return false;
16801684
}
@@ -1687,7 +1691,7 @@ mlir::IntegerAttr DecisionsAttr::sequence() const {
16871691
auto derived = llvm::cast<mlir::DictionaryAttr>(*this);
16881692
auto sequence = derived.get("sequence");
16891693
if (!sequence) return nullptr;
1690-
assert(sequence.isa<mlir::IntegerAttr>() &&
1694+
assert(llvm::isa<mlir::IntegerAttr>(sequence) &&
16911695
"incorrect Attribute type found.");
16921696
return llvm::cast<mlir::IntegerAttr>(sequence);
16931697
}
@@ -1696,23 +1700,25 @@ mlir::ArrayAttr DecisionsAttr::loop_nest() const {
16961700
auto derived = llvm::cast<mlir::DictionaryAttr>(*this);
16971701
auto loop_nest = derived.get("loop_nest");
16981702
if (!loop_nest) return nullptr;
1699-
assert(loop_nest.isa<mlir::ArrayAttr>() && "incorrect Attribute type found.");
1703+
assert(llvm::isa<mlir::ArrayAttr>(loop_nest) &&
1704+
"incorrect Attribute type found.");
17001705
return llvm::cast<mlir::ArrayAttr>(loop_nest);
17011706
}
17021707

17031708
mlir::ArrayAttr DecisionsAttr::storage() const {
17041709
auto derived = llvm::cast<mlir::DictionaryAttr>(*this);
17051710
auto storage = derived.get("storage");
17061711
if (!storage) return nullptr;
1707-
assert(storage.isa<mlir::ArrayAttr>() && "incorrect Attribute type found.");
1712+
assert(llvm::isa<mlir::ArrayAttr>(storage) &&
1713+
"incorrect Attribute type found.");
17081714
return llvm::cast<mlir::ArrayAttr>(storage);
17091715
}
17101716

17111717
mlir::StringAttr DecisionsAttr::expansion() const {
17121718
auto derived = llvm::cast<mlir::DictionaryAttr>(*this);
17131719
auto expansion = derived.get("expansion");
17141720
if (!expansion) return nullptr;
1715-
assert(expansion.isa<mlir::StringAttr>() &&
1721+
assert(llvm::isa<mlir::StringAttr>(expansion) &&
17161722
"incorrect Attribute type found.");
17171723
return llvm::cast<mlir::StringAttr>(expansion);
17181724
}
@@ -1721,15 +1727,17 @@ mlir::Attribute DecisionsAttr::copy_of() const {
17211727
auto derived = llvm::cast<mlir::DictionaryAttr>(*this);
17221728
auto copy_of = derived.get("copy_of");
17231729
if (!copy_of) return nullptr;
1724-
assert(copy_of.isa<mlir::Attribute>() && "incorrect Attribute type found.");
1730+
assert(llvm::isa<mlir::Attribute>(copy_of) &&
1731+
"incorrect Attribute type found.");
17251732
return llvm::cast<mlir::Attribute>(copy_of);
17261733
}
17271734

17281735
mlir::ArrayAttr DecisionsAttr::operands() const {
17291736
auto derived = llvm::cast<mlir::DictionaryAttr>(*this);
17301737
auto operands = derived.get("operands");
17311738
if (!operands) return nullptr;
1732-
assert(operands.isa<mlir::ArrayAttr>() && "incorrect Attribute type found.");
1739+
assert(llvm::isa<mlir::ArrayAttr>(operands) &&
1740+
"incorrect Attribute type found.");
17331741
return llvm::cast<mlir::ArrayAttr>(operands);
17341742
}
17351743

sair_base.td

+8-11
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ def SairEmptyDomainShapeAttr :
158158
def SairResultDomainShapeAttr :
159159
DerivedAttr<"DomainShapeAttr", [{
160160
mlir::Type type = getOperation()->getResult(0).getType();
161-
return type.cast<ShapedType>().Shape();
161+
return llvm::cast<ShapedType>(type).Shape();
162162
}]> {
163163
let convertFromStorage = [{$_self}];
164164
}
@@ -248,7 +248,7 @@ def SairValue : Type<CPred<"isa<ValueType>($_self)">, "value">;
248248

249249
// Predicate that checks the element type of a Sair value.
250250
class SairElementTypePred<Type type>
251-
: SubstLeaves<"$_self", "$_self.cast<ValueType>().ElementType()",
251+
: SubstLeaves<"$_self", "llvm::cast<ValueType>($_self).ElementType()",
252252
type.predicate>;
253253

254254
// Type constraint for Sair values with a specific element type.
@@ -420,7 +420,7 @@ def SairOpInterface : OpInterface<"SairOp"> {
420420
"Returns lowering decisions for the given operation instance",
421421
"DecisionsAttr", "GetDecisions", (ins "int":$instance), [{}], [{
422422
mlir::ArrayAttr instances = *$_op.getInstances();
423-
return instances.getValue()[instance].cast<DecisionsAttr>();
423+
return llvm::cast<DecisionsAttr>(instances.getValue()[instance]);
424424
}]
425425
>,
426426
InterfaceMethod<
@@ -504,24 +504,23 @@ def SairValueProducerOp : OpInterface<"ValueProducerOp"> {
504504
"llvm::ArrayRef<mlir::Attribute>", "GetCopies", (ins "int":$result), [{}], [{
505505
auto all_copies = $_op.getCopiesAttr();
506506
if (all_copies == nullptr) return {};
507-
return all_copies.getValue()[result]
508-
.template cast<mlir::ArrayAttr>().getValue();
507+
return llvm::cast<mlir::ArrayAttr>(all_copies.getValue()[result]).getValue();
509508
}]>,
510509
InterfaceMethod<
511510
"Indicates if the operation has any copy set in its `copies` attribute`",
512511
"bool", "HasCopies", (ins), [{}], [{
513512
auto all_copies = $_op.getCopiesAttr();
514513
if (all_copies == nullptr) return false;
515514
return llvm::any_of(all_copies.getValue(), [](mlir::Attribute attr) {
516-
return !attr.cast<mlir::ArrayAttr>().empty();
515+
return !llvm::cast<mlir::ArrayAttr>(attr).empty();
517516
});
518517
}]>,
519518
InterfaceMethod<
520519
"Set decisions for the given copy of the given result.",
521520
"void", "SetCopy",
522521
(ins "int":$result, "int":$copy, "DecisionsAttr":$decisions), [{}], [{
523522
auto all_copies = llvm::to_vector<4>(*$_op.getCopies());
524-
auto result_copies_attr = all_copies[result].template cast<mlir::ArrayAttr>();
523+
auto result_copies_attr = llvm::cast<mlir::ArrayAttr>(all_copies[result]);
525524
auto result_copies = llvm::to_vector<4>(result_copies_attr.getValue());
526525

527526
result_copies[copy] = decisions;
@@ -582,10 +581,8 @@ def SairFromToMemRefOp : OpInterface<"FromToMemRefOp"> {
582581
InterfaceMethod<"Buffer name", "llvm::StringRef", "getBufferName">,
583582
InterfaceMethod<"Memref type", "mlir::MemRefType", "MemRefType", (ins),
584583
[{}], [{
585-
return $_op.MemRef()
586-
.GetType()
587-
.ElementType()
588-
.template cast<mlir::MemRefType>();
584+
return llvm::cast<mlir::MemRefType>(
585+
$_op.MemRef().GetType().ElementType());
589586
}]>,
590587
InterfaceMethod<"Mapping from value domain to layout", "MappingAttr",
591588
"Layout", (ins), [{}], [{

sair_dialect.cc

+3-3
Original file line numberDiff line numberDiff line change
@@ -286,9 +286,9 @@ namespace {
286286
// Accepts a ray stream so that it can be used from different flavors of
287287
// printers.
288288
void PrintMappingExpr(MappingExpr expr, llvm::raw_ostream &os) {
289-
if (expr.isa<MappingNoneExpr>()) {
289+
if (llvm::isa<MappingNoneExpr>(expr)) {
290290
os << MappingNoneExpr::kAttrName;
291-
} else if (expr.isa<MappingUnknownExpr>()) {
291+
} else if (llvm::isa<MappingUnknownExpr>(expr)) {
292292
os << MappingUnknownExpr::kAttrName;
293293
} else if (auto dim_expr = llvm::dyn_cast<MappingDimExpr>(expr)) {
294294
os << "d" << dim_expr.dimension();
@@ -325,7 +325,7 @@ void PrintDomainShapeDim(const DomainShapeDim &dimension,
325325
mlir::DialectAsmPrinter &os) {
326326
if (auto static_range = llvm::dyn_cast<StaticRangeType>(dimension.type())) {
327327
Print(static_range, os);
328-
} else if (dimension.type().isa<DynRangeType>()) {
328+
} else if (llvm::isa<DynRangeType>(dimension.type())) {
329329
os << DynRangeType::Name();
330330
} else {
331331
llvm_unreachable("unsupported dimension type");

0 commit comments

Comments
 (0)