diff --git a/examples/utils/print_typst.py b/examples/utils/print_typst.py index 20e21a6a..a2b3cc4f 100644 --- a/examples/utils/print_typst.py +++ b/examples/utils/print_typst.py @@ -18,8 +18,10 @@ @flyc.jit def visualize(): + mma_atom = fx.make_mma_atom(fx.rocdl.MFMA(16, 16, 16, fx.Float16)) + tiled_mma = fx.make_tiled_mma( - fx.make_mma_atom(fx.rocdl.MFMA(16, 16, 16, fx.Float16)), + mma_atom, fx.make_layout((1, 2, 1), (0, 1, 0)), fx.make_tile(16, 32, fx.make_layout((4, 4, 2), (1, 8, 4))), ) @@ -34,6 +36,7 @@ def visualize(): tiled_mma, ) + fx.utils.print_typst(mma_atom, file=OUTPUT_TYPST) fx.utils.print_typst(tiled_mma, file=OUTPUT_TYPST) fx.utils.print_typst(tiled_copy, file=OUTPUT_TYPST) fx.utils.print_typst(swizzle_layout, file=OUTPUT_TYPST) diff --git a/include/flydsl/Dialect/Fly/IR/FlyInterfaces.td b/include/flydsl/Dialect/Fly/IR/FlyInterfaces.td index dd1b83fc..0b0e13e6 100644 --- a/include/flydsl/Dialect/Fly/IR/FlyInterfaces.td +++ b/include/flydsl/Dialect/Fly/IR/FlyInterfaces.td @@ -60,7 +60,7 @@ def Fly_CopyOpTypeInterface : TypeInterface<"CopyOpTypeInterface", [Fly_MayStati ]; } -def Fly_MmaAtomTypeInterface : TypeInterface<"MmaAtomTypeInterface", [Fly_MayStaticTypeInterface]> { +def Fly_MmaOpTypeInterface : TypeInterface<"MmaOpTypeInterface", [Fly_MayStaticTypeInterface]> { let cppNamespace = "::mlir::fly"; let methods = [ InterfaceMethod<"", "::mlir::Attribute", "getThrLayout", (ins)>, diff --git a/include/flydsl/Dialect/Fly/IR/FlyTypeDefs.td b/include/flydsl/Dialect/Fly/IR/FlyTypeDefs.td index e5340a13..42ac167f 100644 --- a/include/flydsl/Dialect/Fly/IR/FlyTypeDefs.td +++ b/include/flydsl/Dialect/Fly/IR/FlyTypeDefs.td @@ -263,6 +263,33 @@ def Fly_CopyAtom : Fly_Type<"CopyAtom", "copy_atom", [ ]; } +def Fly_MmaAtom : Fly_Type<"MmaAtom", "mma_atom", [ + DeclareTypeInterfaceMethods +]> { + let parameters = (ins + "Type":$mmaOp + ); + let assemblyFormat = "`<` $mmaOp `>`"; + + let extraClassDeclaration = [{ + ::mlir::Attribute getThrLayout() const; + ::mlir::Attribute getShapeMNK() const; + ::mlir::Type getValTypeA() const; + ::mlir::Type getValTypeB() const; + ::mlir::Type getValTypeC() const; + ::mlir::Type getValTypeD() const; + ::mlir::Attribute getThrValLayoutA() const; + ::mlir::Attribute getThrValLayoutB() const; + ::mlir::Attribute getThrValLayoutC() const; + }]; + + let builders = [ + TypeBuilderWithInferredContext<(ins "Type":$mmaOp), [{ + return $_get(mmaOp.getContext(), mmaOp); + }]> + ]; +} + def Fly_CopyOpUniversalCopy : Fly_Type<"CopyOpUniversalCopy", "universal_copy", [ DeclareTypeInterfaceMethods, DeclareTypeInterfaceMethods @@ -273,9 +300,9 @@ def Fly_CopyOpUniversalCopy : Fly_Type<"CopyOpUniversalCopy", "universal_copy", let assemblyFormat = "`<` $bitSize `>`"; } -def Fly_MmaAtomUniversalFMA : Fly_Type<"MmaAtomUniversalFMA", "universal_fma", [ +def Fly_MmaOpUniversalFMA : Fly_Type<"MmaOpUniversalFMA", "universal_fma", [ DeclareTypeInterfaceMethods, - DeclareTypeInterfaceMethods + DeclareTypeInterfaceMethods ]> { let parameters = (ins "Type":$elemTy); diff --git a/include/flydsl/Dialect/Fly/Utils/TiledOpUtils.h b/include/flydsl/Dialect/Fly/Utils/TiledOpUtils.h index 91d3fb42..fa28e35a 100644 --- a/include/flydsl/Dialect/Fly/Utils/TiledOpUtils.h +++ b/include/flydsl/Dialect/Fly/Utils/TiledOpUtils.h @@ -68,7 +68,7 @@ Layout layoutTiledCopyThrValView(LayoutBuilder &builder, CopyAtomType co } template -Layout layoutTiledMmaThrValView(LayoutBuilder &builder, MmaAtomTypeInterface mmaAtom, +Layout layoutTiledMmaThrValView(LayoutBuilder &builder, MmaAtomType mmaAtom, LayoutAttr tiledShape2D, IntTupleAttr atomShape2D, LayoutAttr atomLayoutThrVal, TileAttr permutation2D, Layout trgLayout) { @@ -262,7 +262,7 @@ Layout layoutTiledCopyRetile(LayoutBuilder &builder, CopyAtomType copyAt } template -Layout layoutTiledMmaThrValOperandView(LayoutBuilder &builder, MmaAtomTypeInterface mmaAtom, +Layout layoutTiledMmaThrValOperandView(LayoutBuilder &builder, MmaAtomType mmaAtom, LayoutAttr atomLayoutMNK, TileAttr permutationMNK, MmaOperand operandId, Layout trgLayout) { auto *ctx = atomLayoutMNK.getContext(); diff --git a/include/flydsl/Dialect/FlyROCDL/IR/Dialect.td b/include/flydsl/Dialect/FlyROCDL/IR/Dialect.td index 37be1bf5..b1a5c4b9 100644 --- a/include/flydsl/Dialect/FlyROCDL/IR/Dialect.td +++ b/include/flydsl/Dialect/FlyROCDL/IR/Dialect.td @@ -21,10 +21,10 @@ def FlyROCDL_Dialect : Dialect { let useDefaultTypePrinterParser = 1; } -class FlyxROCL_MmaAtom traits = []> +class FlyxROCL_MmaOp traits = []> : TypeDef, - DeclareTypeInterfaceMethods + DeclareTypeInterfaceMethods ])> { let mnemonic = typeMnemonic; } diff --git a/include/flydsl/Dialect/FlyROCDL/IR/MmaAtom.td b/include/flydsl/Dialect/FlyROCDL/IR/MmaAtom.td index 6bd8fa44..8d020a0a 100644 --- a/include/flydsl/Dialect/FlyROCDL/IR/MmaAtom.td +++ b/include/flydsl/Dialect/FlyROCDL/IR/MmaAtom.td @@ -7,10 +7,10 @@ include "flydsl/Dialect/FlyROCDL/IR/Dialect.td" //===----------------------------------------------------------------------===// -// MmaAtom CDNA3 +// MmaOp CDNA3 //===----------------------------------------------------------------------===// -def FlyROCDL_MmaAtomCDNA3_MFMA : FlyxROCL_MmaAtom<"MmaAtomCDNA3_MFMA", "atom.cdna3.mfma", []> { +def FlyROCDL_MmaOpCDNA3_MFMA : FlyxROCL_MmaOp<"MmaOpCDNA3_MFMA", "cdna3.mfma", []> { let parameters = (ins "int32_t":$m, "int32_t":$n, @@ -30,16 +30,16 @@ def FlyROCDL_MmaAtomCDNA3_MFMA : FlyxROCL_MmaAtom<"MmaAtomCDNA3_MFMA", "atom.cdn } //===----------------------------------------------------------------------===// -// MmaAtom CDNA4 +// MmaOp CDNA4 //===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===// -// MmaAtom GFX1250 — WMMA wave32 +// MmaOp GFX1250 — WMMA wave32 //===----------------------------------------------------------------------===// -def FlyROCDL_MmaAtomGFX1250_WMMA : FlyxROCL_MmaAtom<"MmaAtomGFX1250_WMMA", "atom.gfx1250.wmma", []> { +def FlyROCDL_MmaOpGFX1250_WMMA : FlyxROCL_MmaOp<"MmaOpGFX1250_WMMA", "gfx1250.wmma", []> { let parameters = (ins "int32_t":$m, "int32_t":$n, diff --git a/lib/Bindings/Python/FlyExtension.cpp b/lib/Bindings/Python/FlyExtension.cpp index a432c7af..3561584d 100644 --- a/lib/Bindings/Python/FlyExtension.cpp +++ b/lib/Bindings/Python/FlyExtension.cpp @@ -471,28 +471,6 @@ struct PyCoordTensorType : PyConcreteType { } }; -// --------------------------------------------------------------------------- -// CopyOpUniversalCopyType -// --------------------------------------------------------------------------- -struct PyCopyOpUniversalCopyType : PyConcreteType { - FLYDSL_REGISTER_TYPE_BINDING(::mlir::fly::CopyOpUniversalCopyType, "CopyOpUniversalCopyType"); - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](int32_t bitSize, DefaultingPyMlirContext context) { - MLIRContext *ctx = unwrap(context.get()->get()); - return PyCopyOpUniversalCopyType(context->getRef(), - wrap(CopyOpUniversalCopyType::get(ctx, bitSize))); - }, - "bitSize"_a, nb::kw_only(), "context"_a = nb::none(), - "Create a CopyOpUniversalCopyType with bit size"); - - c.def_prop_ro("bit_size", - [](PyCopyOpUniversalCopyType &self) { return self.toCppType().getBitSize(); }); - } -}; - // --------------------------------------------------------------------------- // CopyAtomType // --------------------------------------------------------------------------- @@ -529,43 +507,37 @@ struct PyCopyAtomType : PyConcreteType { }; // --------------------------------------------------------------------------- -// MmaAtomUniversalFMAType +// MmaAtomType // --------------------------------------------------------------------------- -struct PyMmaAtomUniversalFMAType : PyConcreteType { - FLYDSL_REGISTER_TYPE_BINDING(::mlir::fly::MmaAtomUniversalFMAType, "MmaAtomUniversalFMAType"); +struct PyMmaAtomType : PyConcreteType { + FLYDSL_REGISTER_TYPE_BINDING(::mlir::fly::MmaAtomType, "MmaAtomType"); static void bindDerived(ClassTy &c) { c.def_static( "get", - [](PyType &elemTyObj, DefaultingPyMlirContext context) { - return PyMmaAtomUniversalFMAType(context->getRef(), - wrap(MmaAtomUniversalFMAType::get(unwrap(elemTyObj)))); + [](PyType &mmaOp, DefaultingPyMlirContext context) { + return PyMmaAtomType(context->getRef(), wrap(MmaAtomType::get(unwrap(mmaOp)))); }, - "elem_ty"_a, nb::kw_only(), "context"_a = nb::none(), - "Create a MmaAtomUniversalFMAType with element type"); + "mma_op"_a, nb::kw_only(), "context"_a = nb::none(), + "Create a MmaAtomType wrapping an MmaOpTypeInterface type"); - c.def_prop_ro("elem_ty", [](PyMmaAtomUniversalFMAType &self) -> MlirType { - return wrap(self.toCppType().getElemTy()); + c.def_prop_ro("mma_op", [](PyMmaAtomType &self) -> MlirType { + return wrap(self.toCppType().getMmaOp()); }); - c.def_prop_ro("thr_layout", [](PyMmaAtomUniversalFMAType &self) -> MlirType { - auto ty = cast(self.toCppType()); - return wrap(LayoutType::get(cast(ty.getThrLayout()))); + c.def_prop_ro("thr_layout", [](PyMmaAtomType &self) -> MlirType { + return wrap(LayoutType::get(cast(self.toCppType().getThrLayout()))); }); - c.def_prop_ro("shape_mnk", [](PyMmaAtomUniversalFMAType &self) -> MlirType { - auto ty = cast(self.toCppType()); - return wrap(IntTupleType::get(cast(ty.getShapeMNK()))); + c.def_prop_ro("shape_mnk", [](PyMmaAtomType &self) -> MlirType { + return wrap(IntTupleType::get(cast(self.toCppType().getShapeMNK()))); }); - c.def_prop_ro("tv_layout_a", [](PyMmaAtomUniversalFMAType &self) -> MlirType { - auto ty = cast(self.toCppType()); - return wrap(LayoutType::get(cast(ty.getThrValLayoutA()))); + c.def_prop_ro("tv_layout_a", [](PyMmaAtomType &self) -> MlirType { + return wrap(LayoutType::get(cast(self.toCppType().getThrValLayoutA()))); }); - c.def_prop_ro("tv_layout_b", [](PyMmaAtomUniversalFMAType &self) -> MlirType { - auto ty = cast(self.toCppType()); - return wrap(LayoutType::get(cast(ty.getThrValLayoutB()))); + c.def_prop_ro("tv_layout_b", [](PyMmaAtomType &self) -> MlirType { + return wrap(LayoutType::get(cast(self.toCppType().getThrValLayoutB()))); }); - c.def_prop_ro("tv_layout_c", [](PyMmaAtomUniversalFMAType &self) -> MlirType { - auto ty = cast(self.toCppType()); - return wrap(LayoutType::get(cast(ty.getThrValLayoutC()))); + c.def_prop_ro("tv_layout_c", [](PyMmaAtomType &self) -> MlirType { + return wrap(LayoutType::get(cast(self.toCppType().getThrValLayoutC()))); }); } }; @@ -621,34 +593,34 @@ struct PyTiledMmaType : PyConcreteType { }); c.def_prop_ro("tile_size_mnk", [](PyTiledMmaType &self) -> MlirType { auto ty = self.toCppType(); - auto mmaAtom = cast(ty.getMmaAtom()); + auto mmaAtom = cast(ty.getMmaAtom()); auto result = tiledMmaGetTileSizeMNK(mmaAtom, ty.getAtomLayout().getAttr(), ty.getPermutation().getAttr()); return wrap(IntTupleType::get(result)); }); c.def_prop_ro("thr_layout_vmnk", [](PyTiledMmaType &self) -> MlirType { auto ty = self.toCppType(); - auto mmaAtom = cast(ty.getMmaAtom()); + auto mmaAtom = cast(ty.getMmaAtom()); auto result = tiledMmaGetThrLayoutVMNK(mmaAtom, ty.getAtomLayout().getAttr()); return wrap(LayoutType::get(result)); }); c.def_prop_ro("tiled_tv_layout_a", [](PyTiledMmaType &self) -> MlirType { auto ty = self.toCppType(); - auto mmaAtom = cast(ty.getMmaAtom()); + auto mmaAtom = cast(ty.getMmaAtom()); auto result = tiledMmaGetTiledThrValLayout(mmaAtom, ty.getAtomLayout().getAttr(), ty.getPermutation().getAttr(), MmaOperand::A); return wrap(LayoutType::get(result)); }); c.def_prop_ro("tiled_tv_layout_b", [](PyTiledMmaType &self) -> MlirType { auto ty = self.toCppType(); - auto mmaAtom = cast(ty.getMmaAtom()); + auto mmaAtom = cast(ty.getMmaAtom()); auto result = tiledMmaGetTiledThrValLayout(mmaAtom, ty.getAtomLayout().getAttr(), ty.getPermutation().getAttr(), MmaOperand::B); return wrap(LayoutType::get(result)); }); c.def_prop_ro("tiled_tv_layout_c", [](PyTiledMmaType &self) -> MlirType { auto ty = self.toCppType(); - auto mmaAtom = cast(ty.getMmaAtom()); + auto mmaAtom = cast(ty.getMmaAtom()); auto result = tiledMmaGetTiledThrValLayout(mmaAtom, ty.getAtomLayout().getAttr(), ty.getPermutation().getAttr(), MmaOperand::C); return wrap(LayoutType::get(result)); @@ -656,6 +628,47 @@ struct PyTiledMmaType : PyConcreteType { } }; +// --------------------------------------------------------------------------- +// CopyOpUniversalCopyType +// --------------------------------------------------------------------------- +struct PyCopyOpUniversalCopyType : PyConcreteType { + FLYDSL_REGISTER_TYPE_BINDING(::mlir::fly::CopyOpUniversalCopyType, "CopyOpUniversalCopyType"); + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](int32_t bitSize, DefaultingPyMlirContext context) { + MLIRContext *ctx = unwrap(context.get()->get()); + return PyCopyOpUniversalCopyType(context->getRef(), + wrap(CopyOpUniversalCopyType::get(ctx, bitSize))); + }, + "bitSize"_a, nb::kw_only(), "context"_a = nb::none(), + "Create a CopyOpUniversalCopyType with bit size"); + } +}; + +// --------------------------------------------------------------------------- +// MmaOpUniversalFMAType +// --------------------------------------------------------------------------- +struct PyMmaOpUniversalFMAType : PyConcreteType { + FLYDSL_REGISTER_TYPE_BINDING(::mlir::fly::MmaOpUniversalFMAType, "MmaOpUniversalFMAType"); + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](PyType &elemTyObj, DefaultingPyMlirContext context) { + return PyMmaOpUniversalFMAType(context->getRef(), + wrap(MmaOpUniversalFMAType::get(unwrap(elemTyObj)))); + }, + "elem_ty"_a, nb::kw_only(), "context"_a = nb::none(), + "Create a MmaOpUniversalFMAType with element type"); + + c.def_prop_ro("elem_ty", [](PyMmaOpUniversalFMAType &self) -> MlirType { + return wrap(self.toCppType().getElemTy()); + }); + } +}; + } // namespace fly } // namespace MLIR_BINDINGS_PYTHON_DOMAIN } // namespace python @@ -729,11 +742,12 @@ NB_MODULE(_mlirDialectsFly, m) { ::mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::fly::PyPointerType::bind(m); ::mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::fly::PyMemRefType::bind(m); ::mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::fly::PyCoordTensorType::bind(m); - ::mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::fly::PyCopyOpUniversalCopyType::bind(m); ::mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::fly::PyCopyAtomType::bind(m); - ::mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::fly::PyMmaAtomUniversalFMAType::bind(m); + ::mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::fly::PyMmaAtomType::bind(m); ::mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::fly::PyTiledCopyType::bind(m); ::mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::fly::PyTiledMmaType::bind(m); + ::mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::fly::PyCopyOpUniversalCopyType::bind(m); + ::mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::fly::PyMmaOpUniversalFMAType::bind(m); m.def( "set_llvm_option_bool", @@ -743,12 +757,10 @@ NB_MODULE(_mlirDialectsFly, m) { if (rc == 1) throw std::runtime_error("Unknown LLVM option: " + name); if (rc == 2) - throw std::runtime_error("LLVM option '" + name + - "' is not a bool option"); + throw std::runtime_error("LLVM option '" + name + "' is not a bool option"); return oldValue; }, - "name"_a, "value"_a, - "Set an LLVM bool cl::opt at runtime; returns the previous value."); + "name"_a, "value"_a, "Set an LLVM bool cl::opt at runtime; returns the previous value."); m.def( "set_llvm_option_int", @@ -758,12 +770,10 @@ NB_MODULE(_mlirDialectsFly, m) { if (rc == 1) throw std::runtime_error("Unknown LLVM option: " + name); if (rc == 2) - throw std::runtime_error("LLVM option '" + name + - "' is not an int option"); + throw std::runtime_error("LLVM option '" + name + "' is not an int option"); return oldValue; }, - "name"_a, "value"_a, - "Set an LLVM int cl::opt at runtime; returns the previous value."); + "name"_a, "value"_a, "Set an LLVM int cl::opt at runtime; returns the previous value."); m.def( "set_llvm_option_str", @@ -773,12 +783,10 @@ NB_MODULE(_mlirDialectsFly, m) { if (rc == 1) throw std::runtime_error("Unknown LLVM option: " + name); if (rc == 2) - throw std::runtime_error("LLVM option '" + name + - "' is not a string option"); + throw std::runtime_error("LLVM option '" + name + "' is not a string option"); std::string result(oldValue ? oldValue : ""); flydslFreeLLVMOptionStr(oldValue); return result; }, - "name"_a, "value"_a, - "Set an LLVM string cl::opt at runtime; returns the previous value."); + "name"_a, "value"_a, "Set an LLVM string cl::opt at runtime; returns the previous value."); } diff --git a/lib/Bindings/Python/FlyROCDLExtension.cpp b/lib/Bindings/Python/FlyROCDLExtension.cpp index 022310e3..ed68ee22 100644 --- a/lib/Bindings/Python/FlyROCDLExtension.cpp +++ b/lib/Bindings/Python/FlyROCDLExtension.cpp @@ -20,109 +20,40 @@ namespace python { namespace MLIR_BINDINGS_PYTHON_DOMAIN { namespace fly_rocdl { -struct PyMmaAtomCDNA3_MFMAType : PyConcreteType { - FLYDSL_REGISTER_TYPE_BINDING(MmaAtomCDNA3_MFMAType, "MmaAtomCDNA3_MFMAType"); +struct PyMmaOpCDNA3_MFMAType : PyConcreteType { + FLYDSL_REGISTER_TYPE_BINDING(MmaOpCDNA3_MFMAType, "MmaOpCDNA3_MFMAType"); static void bindDerived(ClassTy &c) { c.def_static( "get", [](int32_t m, int32_t n, int32_t k, PyType &elemTyA, PyType &elemTyB, PyType &elemTyAcc, DefaultingPyMlirContext context) { - return PyMmaAtomCDNA3_MFMAType( - context->getRef(), - wrap(MmaAtomCDNA3_MFMAType::get(m, n, k, unwrap(elemTyA), unwrap(elemTyB), - unwrap(elemTyAcc)))); + return PyMmaOpCDNA3_MFMAType(context->getRef(), wrap(MmaOpCDNA3_MFMAType::get( + m, n, k, unwrap(elemTyA), + unwrap(elemTyB), unwrap(elemTyAcc)))); }, "m"_a, "n"_a, "k"_a, "elem_ty_a"_a, "elem_ty_b"_a, "elem_ty_acc"_a, nb::kw_only(), "context"_a = nb::none(), - "Create a MmaAtomCDNA3_MFMAType with m, n, k dimensions and element types"); - - c.def_prop_ro("m", [](PyMmaAtomCDNA3_MFMAType &self) { return self.toCppType().getM(); }); - c.def_prop_ro("n", [](PyMmaAtomCDNA3_MFMAType &self) { return self.toCppType().getN(); }); - c.def_prop_ro("k", [](PyMmaAtomCDNA3_MFMAType &self) { return self.toCppType().getK(); }); - c.def_prop_ro("elem_ty_a", [](PyMmaAtomCDNA3_MFMAType &self) -> MlirType { - return wrap(self.toCppType().getElemTyA()); - }); - c.def_prop_ro("elem_ty_b", [](PyMmaAtomCDNA3_MFMAType &self) -> MlirType { - return wrap(self.toCppType().getElemTyB()); - }); - c.def_prop_ro("elem_ty_acc", [](PyMmaAtomCDNA3_MFMAType &self) -> MlirType { - return wrap(self.toCppType().getElemTyAcc()); - }); - - c.def_prop_ro("thr_layout", [](PyMmaAtomCDNA3_MFMAType &self) -> MlirType { - auto ty = cast(self.toCppType()); - return wrap(LayoutType::get(cast(ty.getThrLayout()))); - }); - c.def_prop_ro("shape_mnk", [](PyMmaAtomCDNA3_MFMAType &self) -> MlirType { - auto ty = cast(self.toCppType()); - return wrap(IntTupleType::get(cast(ty.getShapeMNK()))); - }); - c.def_prop_ro("tv_layout_a", [](PyMmaAtomCDNA3_MFMAType &self) -> MlirType { - auto ty = cast(self.toCppType()); - return wrap(LayoutType::get(cast(ty.getThrValLayoutA()))); - }); - c.def_prop_ro("tv_layout_b", [](PyMmaAtomCDNA3_MFMAType &self) -> MlirType { - auto ty = cast(self.toCppType()); - return wrap(LayoutType::get(cast(ty.getThrValLayoutB()))); - }); - c.def_prop_ro("tv_layout_c", [](PyMmaAtomCDNA3_MFMAType &self) -> MlirType { - auto ty = cast(self.toCppType()); - return wrap(LayoutType::get(cast(ty.getThrValLayoutC()))); - }); + "Create a MmaOpCDNA3_MFMAType with m, n, k dimensions and element types"); } }; -struct PyMmaAtomGFX1250_WMMAType : PyConcreteType { - FLYDSL_REGISTER_TYPE_BINDING(MmaAtomGFX1250_WMMAType, "MmaAtomGFX1250_WMMAType"); +struct PyMmaOpGFX1250_WMMAType : PyConcreteType { + FLYDSL_REGISTER_TYPE_BINDING(MmaOpGFX1250_WMMAType, "MmaOpGFX1250_WMMAType"); static void bindDerived(ClassTy &c) { c.def_static( "get", [](int32_t m, int32_t n, int32_t k, PyType &elemTyA, PyType &elemTyB, PyType &elemTyAcc, DefaultingPyMlirContext context) { - return PyMmaAtomGFX1250_WMMAType( + return PyMmaOpGFX1250_WMMAType( context->getRef(), - wrap(MmaAtomGFX1250_WMMAType::get(m, n, k, unwrap(elemTyA), unwrap(elemTyB), - unwrap(elemTyAcc)))); + wrap(MmaOpGFX1250_WMMAType::get(m, n, k, unwrap(elemTyA), unwrap(elemTyB), + unwrap(elemTyAcc)))); }, "m"_a, "n"_a, "k"_a, "elem_ty_a"_a, "elem_ty_b"_a, "elem_ty_acc"_a, nb::kw_only(), "context"_a = nb::none(), - "Create a MmaAtomGFX1250_WMMAType with m, n, k dimensions and element types"); - - c.def_prop_ro("m", [](PyMmaAtomGFX1250_WMMAType &self) { return self.toCppType().getM(); }); - c.def_prop_ro("n", [](PyMmaAtomGFX1250_WMMAType &self) { return self.toCppType().getN(); }); - c.def_prop_ro("k", [](PyMmaAtomGFX1250_WMMAType &self) { return self.toCppType().getK(); }); - c.def_prop_ro("elem_ty_a", [](PyMmaAtomGFX1250_WMMAType &self) -> MlirType { - return wrap(self.toCppType().getElemTyA()); - }); - c.def_prop_ro("elem_ty_b", [](PyMmaAtomGFX1250_WMMAType &self) -> MlirType { - return wrap(self.toCppType().getElemTyB()); - }); - c.def_prop_ro("elem_ty_acc", [](PyMmaAtomGFX1250_WMMAType &self) -> MlirType { - return wrap(self.toCppType().getElemTyAcc()); - }); - - c.def_prop_ro("thr_layout", [](PyMmaAtomGFX1250_WMMAType &self) -> MlirType { - auto ty = cast(self.toCppType()); - return wrap(LayoutType::get(cast(ty.getThrLayout()))); - }); - c.def_prop_ro("shape_mnk", [](PyMmaAtomGFX1250_WMMAType &self) -> MlirType { - auto ty = cast(self.toCppType()); - return wrap(IntTupleType::get(cast(ty.getShapeMNK()))); - }); - c.def_prop_ro("tv_layout_a", [](PyMmaAtomGFX1250_WMMAType &self) -> MlirType { - auto ty = cast(self.toCppType()); - return wrap(LayoutType::get(cast(ty.getThrValLayoutA()))); - }); - c.def_prop_ro("tv_layout_b", [](PyMmaAtomGFX1250_WMMAType &self) -> MlirType { - auto ty = cast(self.toCppType()); - return wrap(LayoutType::get(cast(ty.getThrValLayoutB()))); - }); - c.def_prop_ro("tv_layout_c", [](PyMmaAtomGFX1250_WMMAType &self) -> MlirType { - auto ty = cast(self.toCppType()); - return wrap(LayoutType::get(cast(ty.getThrValLayoutC()))); - }); + "Create a MmaOpGFX1250_WMMAType with m, n, k dimensions and element types"); } }; @@ -139,9 +70,6 @@ struct PyCopyOpCDNA3BufferCopyType : PyConcreteType }, "bit_size"_a, nb::kw_only(), "context"_a = nb::none(), "Create a CopyOpCDNA3BufferCopyType with the given bit size"); - - c.def_prop_ro("bit_size", - [](PyCopyOpCDNA3BufferCopyType &self) { return self.toCppType().getBitSize(); }); } }; @@ -153,7 +81,7 @@ struct PyCopyOpCDNA3BufferCopyType : PyConcreteType NB_MODULE(_mlirDialectsFlyROCDL, m) { m.doc() = "MLIR Python FlyROCDL Extension"; - ::mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::fly_rocdl::PyMmaAtomCDNA3_MFMAType::bind(m); - ::mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::fly_rocdl::PyMmaAtomGFX1250_WMMAType::bind(m); + ::mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::fly_rocdl::PyMmaOpCDNA3_MFMAType::bind(m); + ::mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::fly_rocdl::PyMmaOpGFX1250_WMMAType::bind(m); ::mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::fly_rocdl::PyCopyOpCDNA3BufferCopyType::bind(m); } diff --git a/lib/Bindings/Python/TiledOpTraits.cpp b/lib/Bindings/Python/TiledOpTraits.cpp index f751ab83..d504235e 100644 --- a/lib/Bindings/Python/TiledOpTraits.cpp +++ b/lib/Bindings/Python/TiledOpTraits.cpp @@ -65,7 +65,7 @@ LayoutAttr tiledCopyGetTiledThrValLayoutDst(CopyAtomType copyAtom, LayoutAttr ti cast(copyAtom.getThrValLayoutDst())); } -LayoutAttr tiledMmaGetTiledThrValLayout(MmaAtomTypeInterface mmaAtom, LayoutAttr atomLayoutMNK, +LayoutAttr tiledMmaGetTiledThrValLayout(MmaAtomType mmaAtom, LayoutAttr atomLayoutMNK, TileAttr permutationMNK, MmaOperand operandId) { auto *ctx = atomLayoutMNK.getContext(); LayoutBuilder attrBuilder(ctx); @@ -165,7 +165,7 @@ LayoutAttr tiledMmaGetTiledThrValLayout(MmaAtomTypeInterface mmaAtom, LayoutAttr return LayoutAttr::get(finalShape, finalStride); } -IntTupleAttr tiledMmaGetTileSizeMNK(MmaAtomTypeInterface mmaAtom, LayoutAttr atomLayoutMNK, +IntTupleAttr tiledMmaGetTileSizeMNK(MmaAtomType mmaAtom, LayoutAttr atomLayoutMNK, TileAttr permutationMNK) { auto *ctx = atomLayoutMNK.getContext(); LayoutBuilder attrBuilder(ctx); @@ -187,7 +187,7 @@ IntTupleAttr tiledMmaGetTileSizeMNK(MmaAtomTypeInterface mmaAtom, LayoutAttr ato return IntTupleAttr::get(ArrayAttr::get(ctx, tileSizeElems)); } -LayoutAttr tiledMmaGetThrLayoutVMNK(MmaAtomTypeInterface mmaAtom, LayoutAttr atomLayoutMNK) { +LayoutAttr tiledMmaGetThrLayoutVMNK(MmaAtomType mmaAtom, LayoutAttr atomLayoutMNK) { auto *ctx = atomLayoutMNK.getContext(); LayoutBuilder attrBuilder(ctx); diff --git a/lib/Bindings/Python/TiledOpTraits.h b/lib/Bindings/Python/TiledOpTraits.h index 7aa45c43..ecf58a50 100644 --- a/lib/Bindings/Python/TiledOpTraits.h +++ b/lib/Bindings/Python/TiledOpTraits.h @@ -14,13 +14,13 @@ LayoutAttr tiledCopyGetTiledThrValLayoutSrc(CopyAtomType copyAtom, LayoutAttr ti LayoutAttr tiledCopyGetTiledThrValLayoutDst(CopyAtomType copyAtom, LayoutAttr tiledLayoutThrVal, TileAttr tileMN); -LayoutAttr tiledMmaGetTiledThrValLayout(MmaAtomTypeInterface mmaAtom, LayoutAttr atomLayoutMNK, +LayoutAttr tiledMmaGetTiledThrValLayout(MmaAtomType mmaAtom, LayoutAttr atomLayoutMNK, TileAttr permutationMNK, MmaOperand operandId); -IntTupleAttr tiledMmaGetTileSizeMNK(MmaAtomTypeInterface mmaAtom, LayoutAttr atomLayoutMNK, +IntTupleAttr tiledMmaGetTileSizeMNK(MmaAtomType mmaAtom, LayoutAttr atomLayoutMNK, TileAttr permutationMNK); -LayoutAttr tiledMmaGetThrLayoutVMNK(MmaAtomTypeInterface mmaAtom, LayoutAttr atomLayoutMNK); +LayoutAttr tiledMmaGetThrLayoutVMNK(MmaAtomType mmaAtom, LayoutAttr atomLayoutMNK); } // namespace mlir::fly diff --git a/lib/Conversion/FlyToROCDL/FlyToROCDL.cpp b/lib/Conversion/FlyToROCDL/FlyToROCDL.cpp index 09117ec5..317a1e69 100644 --- a/lib/Conversion/FlyToROCDL/FlyToROCDL.cpp +++ b/lib/Conversion/FlyToROCDL/FlyToROCDL.cpp @@ -334,8 +334,8 @@ class PtrLoadOpLowering : public OpConversionPattern { if (auto vecTy = dyn_cast(loadTy)) { auto swizzle = flyPtrTy.getSwizzle(); if (!swizzle.isTrivialSwizzle()) { - int64_t vecBytes = vecTy.getNumElements() * - vecTy.getElementType().getIntOrFloatBitWidth() / 8; + int64_t vecBytes = + vecTy.getNumElements() * vecTy.getElementType().getIntOrFloatBitWidth() / 8; int64_t baseBytes = int64_t{1} << swizzle.getBase(); if (baseBytes % vecBytes != 0) return rewriter.notifyMatchFailure( @@ -378,8 +378,8 @@ class PtrStoreOpLowering : public OpConversionPattern { if (auto vecTy = dyn_cast(value.getType())) { auto swizzle = flyPtrTy.getSwizzle(); if (!swizzle.isTrivialSwizzle()) { - int64_t vecBytes = vecTy.getNumElements() * - vecTy.getElementType().getIntOrFloatBitWidth() / 8; + int64_t vecBytes = + vecTy.getNumElements() * vecTy.getElementType().getIntOrFloatBitWidth() / 8; int64_t baseBytes = int64_t{1} << swizzle.getBase(); if (baseBytes % vecBytes != 0) return rewriter.notifyMatchFailure( @@ -536,10 +536,9 @@ class MmaAtomCallLowering : public OpConversionPattern { LogicalResult matchAndRewrite(MmaAtomCall op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - Type mmaAtomType = op.getMmaAtom().getType(); - if (!isa(mmaAtomType)) - return rewriter.notifyMatchFailure(op, - "expected MmaAtomTypeInterface type for mmaAtom operand"); + auto mmaAtomTy = dyn_cast(op.getMmaAtom().getType()); + if (!mmaAtomTy) + return rewriter.notifyMatchFailure(op, "expected MmaAtomType for mmaAtom operand"); Location loc = op.getLoc(); @@ -553,11 +552,11 @@ class MmaAtomCallLowering : public OpConversionPattern { !isa(bPtr.getType()) || !isa(cPtr.getType())) return rewriter.notifyMatchFailure(op, "expected llvm.ptr operands after type conversion"); - if (auto universalFma = dyn_cast(mmaAtomType)) + if (auto universalFma = dyn_cast(mmaAtomTy.getMmaOp())) return lowerUniversalFMA(op, rewriter, loc, universalFma, dPtr, aPtr, bPtr, cPtr); - else if (auto cdna3Mfma = dyn_cast(mmaAtomType)) + else if (auto cdna3Mfma = dyn_cast(mmaAtomTy.getMmaOp())) return lowerCDNA3MFMA(op, rewriter, loc, cdna3Mfma, dPtr, aPtr, bPtr, cPtr); - else if (auto gfx1250Wmma = dyn_cast(mmaAtomType)) + else if (auto gfx1250Wmma = dyn_cast(mmaAtomTy.getMmaOp())) return lowerGFX1250WMMA(op, rewriter, loc, gfx1250Wmma, dPtr, aPtr, bPtr, cPtr); return rewriter.notifyMatchFailure(op, "unsupported MmaAtom type"); @@ -565,8 +564,8 @@ class MmaAtomCallLowering : public OpConversionPattern { private: LogicalResult lowerUniversalFMA(MmaAtomCall op, ConversionPatternRewriter &rewriter, Location loc, - MmaAtomUniversalFMAType atomTy, Value dPtr, Value aPtr, - Value bPtr, Value cPtr) const { + MmaOpUniversalFMAType atomTy, Value dPtr, Value aPtr, Value bPtr, + Value cPtr) const { Type elemTy = atomTy.getElemTy(); Value a = LLVM::LoadOp::create(rewriter, loc, elemTy, aPtr); @@ -660,7 +659,7 @@ class MmaAtomCallLowering : public OpConversionPattern { } LogicalResult lowerCDNA3MFMA(MmaAtomCall op, ConversionPatternRewriter &rewriter, Location loc, - fly_rocdl::MmaAtomCDNA3_MFMAType atomTy, Value dPtr, Value aPtr, + fly_rocdl::MmaOpCDNA3_MFMAType atomTy, Value dPtr, Value aPtr, Value bPtr, Value cPtr) const { int32_t m = atomTy.getM(); int32_t n = atomTy.getN(); @@ -798,19 +797,19 @@ class MmaAtomCallLowering : public OpConversionPattern { Value res; if constexpr (Variant == WmmaVariant::ModsAllReuse) { res = WmmaOp::create(rewriter, loc, accTy, - /*signA=*/false, a, /*signB=*/false, b, - /*modC=*/(uint16_t)0, c) + /*signA=*/false, a, /*signB=*/false, b, + /*modC=*/(uint16_t)0, c) .getResult(); } else if constexpr (Variant == WmmaVariant::ModsC) { res = WmmaOp::create(rewriter, loc, accTy, a, b, - /*modC=*/(uint16_t)0, c, - /*reuseA=*/false, /*reuseB=*/false) + /*modC=*/(uint16_t)0, c, + /*reuseA=*/false, /*reuseB=*/false) .getResult(); } else { static_assert(Variant == WmmaVariant::ModsABClamp); res = WmmaOp::create(rewriter, loc, accTy, - /*signA=*/false, a, /*signB=*/false, b, c, - /*reuseA=*/false, /*reuseB=*/false, /*clamp=*/false) + /*signA=*/false, a, /*signB=*/false, b, c, + /*reuseA=*/false, /*reuseB=*/false, /*clamp=*/false) .getResult(); } LLVM::StoreOp::create(rewriter, loc, res, dPtr); @@ -819,7 +818,7 @@ class MmaAtomCallLowering : public OpConversionPattern { } LogicalResult lowerGFX1250WMMA(MmaAtomCall op, ConversionPatternRewriter &rewriter, Location loc, - fly_rocdl::MmaAtomGFX1250_WMMAType atomTy, Value dPtr, Value aPtr, + fly_rocdl::MmaOpGFX1250_WMMAType atomTy, Value dPtr, Value aPtr, Value bPtr, Value cPtr) const { int32_t m = atomTy.getM(); int32_t n = atomTy.getN(); @@ -842,8 +841,8 @@ class MmaAtomCallLowering : public OpConversionPattern { #define DISPATCH_WMMA(M_, K_, PRED, OP, VARIANT) \ if (m == M_ && n == M_ && k == K_ && (PRED)) \ - return emitWmma(op, rewriter, loc, abTyA, abTyB, accTy, \ - aPtr, bPtr, cPtr, dPtr); + return emitWmma(op, rewriter, loc, abTyA, abTyB, accTy, aPtr, \ + bPtr, cPtr, dPtr); #define DISPATCH_WMMA_FP8(K_, ACC_PRED, ACC_PREFIX) \ DISPATCH_WMMA(16, K_, isFP8(elemTyA) && isFP8(elemTyB) && ACC_PRED, \ @@ -1073,8 +1072,7 @@ class FlyROCDLClusterAttrPass void runOnOperation() override { getOperation()->walk([&](LLVM::LLVMFuncOp func) { - auto clusterAttr = - func->getAttrOfType("rocdl.cluster_dims"); + auto clusterAttr = func->getAttrOfType("rocdl.cluster_dims"); if (!clusterAttr) return; diff --git a/lib/Dialect/Fly/CMakeLists.txt b/lib/Dialect/Fly/CMakeLists.txt index 38194075..3d7c6695 100644 --- a/lib/Dialect/Fly/CMakeLists.txt +++ b/lib/Dialect/Fly/CMakeLists.txt @@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRFlyDialect IR/FlyDialect.cpp IR/FlyOps.cpp IR/FlyTypeDefs.cpp + IR/FlyUniversalOps.cpp IR/FlyAttrDefs.cpp Utils/IntUtils.cpp Utils/IntTupleUtils.cpp diff --git a/lib/Dialect/Fly/IR/FlyOps.cpp b/lib/Dialect/Fly/IR/FlyOps.cpp index 268dc4b8..7ad60328 100644 --- a/lib/Dialect/Fly/IR/FlyOps.cpp +++ b/lib/Dialect/Fly/IR/FlyOps.cpp @@ -1554,11 +1554,11 @@ FLY_INFER_RETURN_TYPES(TiledMmaPartitionOp) { "TiledMmaPartitionOp: expected IntTupleType for operand #2, got ", operands[2].getType()); - auto mmaAtom = dyn_cast(tiledMmaTy.getMmaAtom()); + auto mmaAtom = dyn_cast(tiledMmaTy.getMmaAtom()); if (!mmaAtom) return emitOptionalError( location, - "TiledMmaPartitionOp: TiledMmaType's mma atom does not implement MmaAtomTypeInterface"); + "TiledMmaPartitionOp: TiledMmaType's mma atom is not a MmaAtomType"); LayoutAttr atomLayout = tiledMmaTy.getAtomLayout().getAttr(); TileAttr permutationMNK = tiledMmaTy.getPermutation().getAttr(); @@ -1602,10 +1602,10 @@ FLY_INFER_RETURN_TYPES(TiledMmaPartitionShapeOp) { "TiledMmaPartitionShapeOp: expected IntTupleType for operand #1, got ", operands[1].getType()); - auto mmaAtom = dyn_cast(tiledMmaTy.getMmaAtom()); + auto mmaAtom = dyn_cast(tiledMmaTy.getMmaAtom()); if (!mmaAtom) - return emitOptionalError(location, "TiledMmaPartitionShapeOp: TiledMmaType's mma atom does not " - "implement MmaAtomTypeInterface"); + return emitOptionalError(location, "TiledMmaPartitionShapeOp: TiledMmaType's mma atom is not " + "a MmaAtomType"); LayoutAttr atomLayout = tiledMmaTy.getAtomLayout().getAttr(); TileAttr permutationMNK = tiledMmaTy.getPermutation().getAttr(); @@ -1641,11 +1641,11 @@ FLY_INFER_RETURN_TYPES(MmaMakeFragmentOp) { "MmaMakeFragmentOp: expected MemRefType for operand #1, got ", operands[1].getType()); - auto mmaAtom = dyn_cast(tiledMmaTy.getMmaAtom()); + auto mmaAtom = dyn_cast(tiledMmaTy.getMmaAtom()); if (!mmaAtom) return emitOptionalError( location, - "MmaMakeFragmentOp: TiledMmaType's mma atom does not implement MmaAtomTypeInterface"); + "MmaMakeFragmentOp: TiledMmaType's mma atom is not a MmaAtomType"); Type elemTy; switch (operandId) { diff --git a/lib/Dialect/Fly/IR/FlyTypeDefs.cpp b/lib/Dialect/Fly/IR/FlyTypeDefs.cpp index 5126d6ed..1cc1ec32 100644 --- a/lib/Dialect/Fly/IR/FlyTypeDefs.cpp +++ b/lib/Dialect/Fly/IR/FlyTypeDefs.cpp @@ -4,6 +4,7 @@ #include "mlir/IR/DialectImplementation.h" #include "flydsl/Dialect/Fly/IR/FlyDialect.h" +#include "flydsl/Dialect/Fly/Utils/IntTupleUtils.h" #include "flydsl/Dialect/Fly/Utils/LayoutUtils.h" #include "flydsl/Dialect/Fly/Utils/NormalForm.h" @@ -372,35 +373,12 @@ void MemRefType::print(AsmPrinter &printer) const { printer << ">"; } -#include "flydsl/Dialect/Fly/Utils/ThrValLayoutMacro.h.inc" - TileType TiledMmaType::getDefaultPermutationMNK(MLIRContext *ctx) { Attribute noneVal = IntAttr::getNone(ctx); SmallVector elems(3, noneVal); return TileType::get(ctx, TileAttr::get(ArrayAttr::get(ctx, elems))); } -bool CopyOpUniversalCopyType::isStatic() const { return true; } - -Value CopyOpUniversalCopyType::rebuildStaticValue(OpBuilder &builder, Location loc, - Value currentValue) const { - if (currentValue && isa(currentValue.getDefiningOp())) - return nullptr; - return MakeCopyAtomOp::create(builder, loc, CopyAtomType::get(*this, getBitSize()), getBitSize()); -} - -Attribute CopyOpUniversalCopyType::getThrLayout() const { return FxLayout(FxC(1), FxC(1)); } - -Attribute CopyOpUniversalCopyType::getThrBitLayoutSrc() const { - return FxLayout(FxShape(FxC(1), FxC(getBitSize())), FxStride(FxC(1), FxC(1))); -} -Attribute CopyOpUniversalCopyType::getThrBitLayoutDst() const { - return FxLayout(FxShape(FxC(1), FxC(getBitSize())), FxStride(FxC(1), FxC(1))); -} -Attribute CopyOpUniversalCopyType::getThrBitLayoutRef() const { - return FxLayout(FxShape(FxC(1), FxC(getBitSize())), FxStride(FxC(1), FxC(1))); -} - bool CopyAtomType::isStatic() const { auto copyOp = dyn_cast(getCopyOp()); if (!copyOp) @@ -429,66 +407,37 @@ Attribute CopyAtomType::getThrValLayoutRef() { return layoutRecast(builder, cast(copyOp.getThrBitLayoutRef()), 1, getValBits()); } -bool MmaAtomUniversalFMAType::isStatic() const { return true; } +bool MmaAtomType::isStatic() const { + auto mmaOp = dyn_cast(getMmaOp()); + if (!mmaOp) + return false; + return mmaOp.isStatic(); +} -Value MmaAtomUniversalFMAType::rebuildStaticValue(OpBuilder &builder, Location loc, - Value currentValue) const { +Value MmaAtomType::rebuildStaticValue(OpBuilder &builder, Location loc, Value currentValue) const { if (currentValue && isa(currentValue.getDefiningOp())) return nullptr; return MakeMmaAtomOp::create(builder, loc, Type(*this)); } -Attribute MmaAtomUniversalFMAType::getShapeMNK() const { - return IntTupleAttr::get(ArrayAttr::get(getContext(), {FxC(1), FxC(1), FxC(1)})); -} - -Attribute MmaAtomUniversalFMAType::getThrLayout() const { return FxLayout(FxC(1), FxC(1)); } - -Type MmaAtomUniversalFMAType::getValTypeA() const { return getElemTy(); } -Type MmaAtomUniversalFMAType::getValTypeB() const { return getElemTy(); } -Type MmaAtomUniversalFMAType::getValTypeC() const { return getElemTy(); } -Type MmaAtomUniversalFMAType::getValTypeD() const { return getElemTy(); } - -Attribute MmaAtomUniversalFMAType::getThrValLayoutA() const { - return FxLayout(FxShape(FxC(1), FxC(1)), FxStride(FxC(1), FxC(1))); +Attribute MmaAtomType::getThrLayout() const { + return cast(getMmaOp()).getThrLayout(); } -Attribute MmaAtomUniversalFMAType::getThrValLayoutB() const { - return FxLayout(FxShape(FxC(1), FxC(1)), FxStride(FxC(1), FxC(1))); +Attribute MmaAtomType::getShapeMNK() const { + return cast(getMmaOp()).getShapeMNK(); } -Attribute MmaAtomUniversalFMAType::getThrValLayoutC() const { - return FxLayout(FxShape(FxC(1), FxC(1)), FxStride(FxC(1), FxC(1))); +Type MmaAtomType::getValTypeA() const { return cast(getMmaOp()).getValTypeA(); } +Type MmaAtomType::getValTypeB() const { return cast(getMmaOp()).getValTypeB(); } +Type MmaAtomType::getValTypeC() const { return cast(getMmaOp()).getValTypeC(); } +Type MmaAtomType::getValTypeD() const { return cast(getMmaOp()).getValTypeD(); } +Attribute MmaAtomType::getThrValLayoutA() const { + return cast(getMmaOp()).getThrValLayoutA(); } - -Type MmaAtomUniversalFMAType::parse(AsmParser &parser) { - Type elemTyA, elemTyB, elemTyC; - if (parser.parseLess()) - return {}; - int32_t m, n, k; - if (parseMNKDimensionList(parser, m, n, k)) - return {}; - if (m != 1 || n != 1 || k != 1) { - parser.emitError(parser.getCurrentLocation()) - << "expected 1x1x1 dimensions for universal FMA, got " << m << "x" << n << "x" << k; - return {}; - } - // Parse ", (elemTy, elemTy) -> elemTy>" - if (parser.parseComma() || parser.parseLParen() || parser.parseType(elemTyA) || - parser.parseComma() || parser.parseType(elemTyB) || parser.parseRParen() || - parser.parseArrow() || parser.parseType(elemTyC) || parser.parseGreater()) - return {}; - // For universal FMA, all element types should be the same - if (elemTyA != elemTyB || elemTyB != elemTyC) { - parser.emitError(parser.getCurrentLocation()) - << "expected all element types to be the same for universal FMA"; - return {}; - } - return get(parser.getContext(), elemTyA); +Attribute MmaAtomType::getThrValLayoutB() const { + return cast(getMmaOp()).getThrValLayoutB(); } - -void MmaAtomUniversalFMAType::print(AsmPrinter &printer) const { - printer << "<"; - printMNKDimensionList(printer, 1, 1, 1); - printer << ", (" << getElemTy() << ", " << getElemTy() << ") -> " << getElemTy() << ">"; +Attribute MmaAtomType::getThrValLayoutC() const { + return cast(getMmaOp()).getThrValLayoutC(); } } // namespace mlir::fly diff --git a/lib/Dialect/Fly/IR/FlyUniversalOps.cpp b/lib/Dialect/Fly/IR/FlyUniversalOps.cpp new file mode 100644 index 00000000..aed19d5e --- /dev/null +++ b/lib/Dialect/Fly/IR/FlyUniversalOps.cpp @@ -0,0 +1,96 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright (c) 2025 FlyDSL Project Contributors + +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/DialectImplementation.h" + +#include "flydsl/Dialect/Fly/IR/FlyDialect.h" +#include "flydsl/Dialect/Fly/Utils/ThrValLayoutMacro.h.inc" + +namespace mlir::fly { + +bool CopyOpUniversalCopyType::isStatic() const { return true; } + +Value CopyOpUniversalCopyType::rebuildStaticValue(OpBuilder &builder, Location loc, + Value currentValue) const { + if (currentValue && isa(currentValue.getDefiningOp())) + return nullptr; + return MakeCopyAtomOp::create(builder, loc, CopyAtomType::get(*this, getBitSize()), getBitSize()); +} + +Attribute CopyOpUniversalCopyType::getThrLayout() const { return FxLayout(FxC(1), FxC(1)); } + +Attribute CopyOpUniversalCopyType::getThrBitLayoutSrc() const { + return FxLayout(FxShape(FxC(1), FxC(getBitSize())), FxStride(FxC(1), FxC(1))); +} +Attribute CopyOpUniversalCopyType::getThrBitLayoutDst() const { + return FxLayout(FxShape(FxC(1), FxC(getBitSize())), FxStride(FxC(1), FxC(1))); +} +Attribute CopyOpUniversalCopyType::getThrBitLayoutRef() const { + return FxLayout(FxShape(FxC(1), FxC(getBitSize())), FxStride(FxC(1), FxC(1))); +} + +bool MmaOpUniversalFMAType::isStatic() const { return true; } + +Value MmaOpUniversalFMAType::rebuildStaticValue(OpBuilder &builder, Location loc, + Value currentValue) const { + if (currentValue && isa(currentValue.getDefiningOp())) + return nullptr; + return MakeMmaAtomOp::create(builder, loc, MmaAtomType::get(*this)); +} + +Attribute MmaOpUniversalFMAType::getShapeMNK() const { + return IntTupleAttr::get(ArrayAttr::get(getContext(), {FxC(1), FxC(1), FxC(1)})); +} + +Attribute MmaOpUniversalFMAType::getThrLayout() const { return FxLayout(FxC(1), FxC(1)); } + +Type MmaOpUniversalFMAType::getValTypeA() const { return getElemTy(); } +Type MmaOpUniversalFMAType::getValTypeB() const { return getElemTy(); } +Type MmaOpUniversalFMAType::getValTypeC() const { return getElemTy(); } +Type MmaOpUniversalFMAType::getValTypeD() const { return getElemTy(); } + +Attribute MmaOpUniversalFMAType::getThrValLayoutA() const { + return FxLayout(FxShape(FxC(1), FxC(1)), FxStride(FxC(1), FxC(1))); +} +Attribute MmaOpUniversalFMAType::getThrValLayoutB() const { + return FxLayout(FxShape(FxC(1), FxC(1)), FxStride(FxC(1), FxC(1))); +} +Attribute MmaOpUniversalFMAType::getThrValLayoutC() const { + return FxLayout(FxShape(FxC(1), FxC(1)), FxStride(FxC(1), FxC(1))); +} + +Type MmaOpUniversalFMAType::parse(AsmParser &parser) { + Type elemTyA, elemTyB, elemTyC; + if (parser.parseLess()) + return {}; + int32_t m, n, k; + if (parseMNKDimensionList(parser, m, n, k)) + return {}; + if (m != 1 || n != 1 || k != 1) { + parser.emitError(parser.getCurrentLocation()) + << "expected 1x1x1 dimensions for universal FMA, got " << m << "x" << n << "x" << k; + return {}; + } + // Parse ", (elemTy, elemTy) -> elemTy>" + if (parser.parseComma() || parser.parseLParen() || parser.parseType(elemTyA) || + parser.parseComma() || parser.parseType(elemTyB) || parser.parseRParen() || + parser.parseArrow() || parser.parseType(elemTyC) || parser.parseGreater()) + return {}; + // For universal FMA, all element types should be the same + if (elemTyA != elemTyB || elemTyB != elemTyC) { + parser.emitError(parser.getCurrentLocation()) + << "expected all element types to be the same for universal FMA"; + return {}; + } + return get(parser.getContext(), elemTyA); +} + +void MmaOpUniversalFMAType::print(AsmPrinter &printer) const { + printer << "<"; + printMNKDimensionList(printer, 1, 1, 1); + printer << ", (" << getElemTy() << ", " << getElemTy() << ") -> " << getElemTy() << ">"; +} + +} // namespace mlir::fly diff --git a/lib/Dialect/Fly/Transforms/LayoutLowering.cpp b/lib/Dialect/Fly/Transforms/LayoutLowering.cpp index fc122af8..e1a717ab 100644 --- a/lib/Dialect/Fly/Transforms/LayoutLowering.cpp +++ b/lib/Dialect/Fly/Transforms/LayoutLowering.cpp @@ -1765,7 +1765,7 @@ class TiledMmaPartitionOpLowering : public OpRewritePattern Value inputIter = makeViewOp.getIter(); Value inputLayoutValue = makeViewOp.getLayout(); - auto mmaAtom = dyn_cast(tiledMmaTy.getMmaAtom()); + auto mmaAtom = dyn_cast(tiledMmaTy.getMmaAtom()); if (!mmaAtom) return failure(); @@ -1901,7 +1901,7 @@ class TiledMmaPartitionShapeOpLowering : public OpRewritePattern>(shape))) return failure(); - auto mmaAtom = dyn_cast(tiledMmaTy.getMmaAtom()); + auto mmaAtom = dyn_cast(tiledMmaTy.getMmaAtom()); if (!mmaAtom) return failure(); @@ -1937,7 +1937,7 @@ class MmaMakeFragmentOpLowering : public OpRewritePattern { if (!tiledMmaTy) return failure(); - auto mmaAtom = dyn_cast(tiledMmaTy.getMmaAtom()); + auto mmaAtom = dyn_cast(tiledMmaTy.getMmaAtom()); if (!mmaAtom) return failure(); @@ -2069,13 +2069,9 @@ class ExpandGemmOpLowering : public OpRewritePattern { auto *ctx = rewriter.getContext(); Value mmaAtomVal = op.getMmaAtom(); - MmaAtomTypeInterface mmaAtomTy; if (auto tiledMmaOp = mmaAtomVal.getDefiningOp()) { mmaAtomVal = tiledMmaOp.getMmaAtom(); } - mmaAtomTy = dyn_cast(mmaAtomVal.getType()); - if (!mmaAtomTy) - return failure(); Value d = op.getD(); Value a = op.getA(); @@ -2092,6 +2088,15 @@ class ExpandGemmOpLowering : public OpRewritePattern { int32_t bRank = bLayoutAttr.rank(); int32_t cRank = cLayoutAttr.rank(); + if (dRank == 1 && aRank == 1 && bRank == 1 && cRank == 1) { + MmaAtomCall::create(rewriter, loc, mmaAtomVal, d, a, b, c); + rewriter.eraseOp(op); + return success(); + } + + if (dRank != 3 || cRank != 3 || aRank < 2 || bRank < 2) + return failure(); + IntTupleBuilder attrBuilder(ctx); auto get_static_product = [&](IntTupleAttr shape) { return intTupleProduct(attrBuilder, shape).getLeafAsInt().getValue(); @@ -2105,15 +2110,6 @@ class ExpandGemmOpLowering : public OpRewritePattern { assert(loop_m == get_static_product(cLayoutAttr.getShape().at(1)) && "Mismatch in loop_m"); assert(loop_n == get_static_product(cLayoutAttr.getShape().at(2)) && "Mismatch in loop_n"); - if (dRank == 1 && aRank == 1 && bRank == 1 && cRank == 1) { - MmaAtomCall::create(rewriter, loc, mmaAtomVal, d, a, b, c); - rewriter.eraseOp(op); - return success(); - } - - if (dRank != 3 || cRank != 3 || aRank < 2 || bRank < 2) - return failure(); - auto getSliceCoord = [&](ArrayRef idx) { SmallVector coordElems; // Keep mode-0 unchanged for all operands. @@ -2127,10 +2123,10 @@ class ExpandGemmOpLowering : public OpRewritePattern { if (aRank == 2 && bRank == 2) { auto emitMmaCall2D = [&](int32_t m, int32_t n) { Value aSlice = SliceOp::create(rewriter, loc, a, getSliceCoord({m})); - Value bSlice = SliceOp::create(rewriter, loc, b, getSliceCoord({n})); - Value cSlice = SliceOp::create(rewriter, loc, c, getSliceCoord({m, n})); - Value dSlice = SliceOp::create(rewriter, loc, d, getSliceCoord({m, n})); - MmaAtomCall::create(rewriter, loc, mmaAtomVal, dSlice, aSlice, bSlice, cSlice); + Value bSlice = SliceOp::create(rewriter, loc, b, getSliceCoord({n})); + Value cSlice = SliceOp::create(rewriter, loc, c, getSliceCoord({m, n})); + Value dSlice = SliceOp::create(rewriter, loc, d, getSliceCoord({m, n})); + MmaAtomCall::create(rewriter, loc, mmaAtomVal, dSlice, aSlice, bSlice, cSlice); }; int32_t totalIters = loop_m * loop_n; @@ -2228,11 +2224,11 @@ class ExpandGemmOpLowering : public OpRewritePattern { bool &visited = mnVisited[m * loop_n + n]; Value cSrc = visited ? d : c; visited = true; - Value aSlice = SliceOp::create(rewriter, loc, a, getSliceCoord({m, k})); - Value bSlice = SliceOp::create(rewriter, loc, b, getSliceCoord({n, k})); - Value cSlice = SliceOp::create(rewriter, loc, cSrc, getSliceCoord({m, n})); - Value dSlice = SliceOp::create(rewriter, loc, d, getSliceCoord({m, n})); - MmaAtomCall::create(rewriter, loc, mmaAtomVal, dSlice, aSlice, bSlice, cSlice); + Value aSlice = SliceOp::create(rewriter, loc, a, getSliceCoord({m, k})); + Value bSlice = SliceOp::create(rewriter, loc, b, getSliceCoord({n, k})); + Value cSlice = SliceOp::create(rewriter, loc, cSrc, getSliceCoord({m, n})); + Value dSlice = SliceOp::create(rewriter, loc, d, getSliceCoord({m, n})); + MmaAtomCall::create(rewriter, loc, mmaAtomVal, dSlice, aSlice, bSlice, cSlice); }; Value traversalLayoutVal = op.getTraversalLayout(); @@ -2423,6 +2419,7 @@ class MemRefLoadVecOpLowering : public OpRewritePattern { VectorType chunkVecTy = VectorType::get({vecWidth}, resVecTy.getElementType()); for (int64_t i = 0; i < numChunks; ++i) { + // Compute column-major coordinate over the rest flat dims. IntTupleAttr restCoord = layoutIdx2CrdColMajor(attrBuilder, attrBuilder.materializeConstantLeaf(i), restFlatShape); @@ -2535,6 +2532,7 @@ class MemRefStoreVecOpLowering : public OpRewritePattern { vec = permuteForStore(rewriter, loc, vec, flatShape, flatRank, contigIdx, vecWidth, numChunks); for (int64_t i = 0; i < numChunks; ++i) { + // Compute column-major coordinate over the rest flat dims. IntTupleAttr restCoord = layoutIdx2CrdColMajor(attrBuilder, attrBuilder.materializeConstantLeaf(i), restFlatShape); diff --git a/lib/Dialect/FlyROCDL/CDNA3/MmaAtom.cpp b/lib/Dialect/FlyROCDL/CDNA3/MmaAtom.cpp index d9148597..00770906 100644 --- a/lib/Dialect/FlyROCDL/CDNA3/MmaAtom.cpp +++ b/lib/Dialect/FlyROCDL/CDNA3/MmaAtom.cpp @@ -32,35 +32,35 @@ namespace cdna4 {} namespace mlir::fly_rocdl { -bool MmaAtomCDNA3_MFMAType::isStatic() const { return true; } +bool MmaOpCDNA3_MFMAType::isStatic() const { return true; } -Value MmaAtomCDNA3_MFMAType::rebuildStaticValue(OpBuilder &builder, Location loc, - Value currentValue) const { +Value MmaOpCDNA3_MFMAType::rebuildStaticValue(OpBuilder &builder, Location loc, + Value currentValue) const { if (currentValue && isa(currentValue.getDefiningOp())) return nullptr; - return MakeMmaAtomOp::create(builder, loc, Type(*this)); + return MakeMmaAtomOp::create(builder, loc, MmaAtomType::get(*this)); } -Attribute MmaAtomCDNA3_MFMAType::getThrLayout() const { return FxLayout(FxC(64), FxC(1)); } +Attribute MmaOpCDNA3_MFMAType::getThrLayout() const { return FxLayout(FxC(64), FxC(1)); } -Attribute MmaAtomCDNA3_MFMAType::getShapeMNK() const { +Attribute MmaOpCDNA3_MFMAType::getShapeMNK() const { return IntTupleAttr::get(ArrayAttr::get(getContext(), {FxC(getM()), FxC(getN()), FxC(getK())})); } -Type MmaAtomCDNA3_MFMAType::getValTypeA() const { return getElemTyA(); } -Type MmaAtomCDNA3_MFMAType::getValTypeB() const { return getElemTyB(); } -Type MmaAtomCDNA3_MFMAType::getValTypeC() const { return getElemTyAcc(); } -Type MmaAtomCDNA3_MFMAType::getValTypeD() const { return getElemTyAcc(); } +Type MmaOpCDNA3_MFMAType::getValTypeA() const { return getElemTyA(); } +Type MmaOpCDNA3_MFMAType::getValTypeB() const { return getElemTyB(); } +Type MmaOpCDNA3_MFMAType::getValTypeC() const { return getElemTyAcc(); } +Type MmaOpCDNA3_MFMAType::getValTypeD() const { return getElemTyAcc(); } -Attribute MmaAtomCDNA3_MFMAType::getThrValLayoutA() const { +Attribute MmaOpCDNA3_MFMAType::getThrValLayoutA() const { return cdna3::getThrValLayoutAB(getContext(), getM(), getN(), getK(), getElemTyA(), getElemTyB(), getElemTyAcc()); } -Attribute MmaAtomCDNA3_MFMAType::getThrValLayoutB() const { +Attribute MmaOpCDNA3_MFMAType::getThrValLayoutB() const { return cdna3::getThrValLayoutAB(getContext(), getM(), getN(), getK(), getElemTyA(), getElemTyB(), getElemTyAcc()); } -Attribute MmaAtomCDNA3_MFMAType::getThrValLayoutC() const { +Attribute MmaOpCDNA3_MFMAType::getThrValLayoutC() const { int M = getM(); int N = getN(); @@ -72,9 +72,9 @@ Attribute MmaAtomCDNA3_MFMAType::getThrValLayoutC() const { FxStride(FxThr(M, ValM0), FxVal(1, ValM0 * GroupM))); } -LogicalResult MmaAtomCDNA3_MFMAType::verify(function_ref emitError, int32_t m, - int32_t n, int32_t k, Type elemTyA, Type elemTyB, - Type elemTyAcc) { +LogicalResult MmaOpCDNA3_MFMAType::verify(function_ref emitError, int32_t m, + int32_t n, int32_t k, Type elemTyA, Type elemTyB, + Type elemTyAcc) { assert(m == n && "M and N must be equal"); if (m != n) { return emitError() << "invalid MNK dimensions for CDNA3 MFMA: " << m << "x" << n << "x" << k; diff --git a/lib/Dialect/FlyROCDL/GFX1250/MmaAtom.cpp b/lib/Dialect/FlyROCDL/GFX1250/MmaAtom.cpp index 54564387..a71b110d 100644 --- a/lib/Dialect/FlyROCDL/GFX1250/MmaAtom.cpp +++ b/lib/Dialect/FlyROCDL/GFX1250/MmaAtom.cpp @@ -45,8 +45,7 @@ LayoutAttr getThrValLayoutAB(MLIRContext *ctx, int32_t K, Type elemTy) { // pos = (l%16)*1 + (l/16)*128 + val_within*16 [+ block*256] int numBlocks = valsPerLane / 8; if (numBlocks == 1) { - return FxLayout(FxShape(FxThr(16, 2), FxVal(8)), - FxStride(FxThr(1, 128), FxVal(16))); + return FxLayout(FxShape(FxThr(16, 2), FxVal(8)), FxStride(FxThr(1, 128), FxVal(16))); } return FxLayout(FxShape(FxThr(16, 2), FxVal(8, numBlocks)), FxStride(FxThr(1, 128), FxVal(16, 256))); @@ -69,89 +68,76 @@ LayoutAttr getThrValLayoutCD(MLIRContext *ctx, Type elemTyAcc) { int elemBits = elemTyAcc.getIntOrFloatBitWidth(); if (elemBits >= 32) { - return FxLayout(FxShape(FxThr(16, 2), FxVal(8)), - FxStride(FxThr(16, 8), FxVal(1))); + return FxLayout(FxShape(FxThr(16, 2), FxVal(8)), FxStride(FxThr(16, 8), FxVal(1))); } // 16-bit: 4 VGPRs × 2 sub-elements = 8 values. - return FxLayout(FxShape(FxThr(16, 2), FxVal(4, 2)), - FxStride(FxThr(16, 8), FxVal(2, 1))); + return FxLayout(FxShape(FxThr(16, 2), FxVal(4, 2)), FxStride(FxThr(16, 8), FxVal(2, 1))); } } // namespace gfx1250 namespace mlir::fly_rocdl { -bool MmaAtomGFX1250_WMMAType::isStatic() const { return true; } +bool MmaOpGFX1250_WMMAType::isStatic() const { return true; } -Value MmaAtomGFX1250_WMMAType::rebuildStaticValue(OpBuilder &builder, Location loc, - Value currentValue) const { +Value MmaOpGFX1250_WMMAType::rebuildStaticValue(OpBuilder &builder, Location loc, + Value currentValue) const { if (currentValue && isa(currentValue.getDefiningOp())) return nullptr; - return MakeMmaAtomOp::create(builder, loc, Type(*this)); + return MakeMmaAtomOp::create(builder, loc, MmaAtomType::get(*this)); } -Type MmaAtomGFX1250_WMMAType::getValTypeA() const { return getElemTyA(); } -Type MmaAtomGFX1250_WMMAType::getValTypeB() const { return getElemTyB(); } -Type MmaAtomGFX1250_WMMAType::getValTypeC() const { return getElemTyAcc(); } -Type MmaAtomGFX1250_WMMAType::getValTypeD() const { return getElemTyAcc(); } +Type MmaOpGFX1250_WMMAType::getValTypeA() const { return getElemTyA(); } +Type MmaOpGFX1250_WMMAType::getValTypeB() const { return getElemTyB(); } +Type MmaOpGFX1250_WMMAType::getValTypeC() const { return getElemTyAcc(); } +Type MmaOpGFX1250_WMMAType::getValTypeD() const { return getElemTyAcc(); } -Attribute MmaAtomGFX1250_WMMAType::getThrLayout() const { - return FxLayout(FxC(32), FxC(1)); -} +Attribute MmaOpGFX1250_WMMAType::getThrLayout() const { return FxLayout(FxC(32), FxC(1)); } -Attribute MmaAtomGFX1250_WMMAType::getShapeMNK() const { - return IntTupleAttr::get( - ArrayAttr::get(getContext(), {FxC(getM()), FxC(getN()), FxC(getK())})); +Attribute MmaOpGFX1250_WMMAType::getShapeMNK() const { + return IntTupleAttr::get(ArrayAttr::get(getContext(), {FxC(getM()), FxC(getN()), FxC(getK())})); } -Attribute MmaAtomGFX1250_WMMAType::getThrValLayoutA() const { +Attribute MmaOpGFX1250_WMMAType::getThrValLayoutA() const { return gfx1250::getThrValLayoutAB(getContext(), getK(), getElemTyA()); } -Attribute MmaAtomGFX1250_WMMAType::getThrValLayoutB() const { +Attribute MmaOpGFX1250_WMMAType::getThrValLayoutB() const { return gfx1250::getThrValLayoutAB(getContext(), getK(), getElemTyB()); } -Attribute MmaAtomGFX1250_WMMAType::getThrValLayoutC() const { +Attribute MmaOpGFX1250_WMMAType::getThrValLayoutC() const { return gfx1250::getThrValLayoutCD(getContext(), getElemTyAcc()); } -LogicalResult -MmaAtomGFX1250_WMMAType::verify(function_ref emitError, - int32_t m, int32_t n, int32_t k, Type elemTyA, - Type elemTyB, Type elemTyAcc) { +LogicalResult MmaOpGFX1250_WMMAType::verify(function_ref emitError, int32_t m, + int32_t n, int32_t k, Type elemTyA, Type elemTyB, + Type elemTyAcc) { if (m != 16 || n != 16) - return emitError() << "GFX1250 WMMA requires M=N=16, got " << m << "x" - << n; + return emitError() << "GFX1250 WMMA requires M=N=16, got " << m << "x" << n; - auto isF8 = [](Type ty) { - return isa(ty) || isa(ty); - }; + auto isF8 = [](Type ty) { return isa(ty) || isa(ty); }; bool valid = false; if (k == 4 && elemTyA.isF32() && elemTyB.isF32() && elemTyAcc.isF32()) valid = true; - if (k == 32 && elemTyA.isF16() && elemTyB.isF16() && - (elemTyAcc.isF32() || elemTyAcc.isF16())) + if (k == 32 && elemTyA.isF16() && elemTyB.isF16() && (elemTyAcc.isF32() || elemTyAcc.isF16())) valid = true; - if (k == 32 && elemTyA.isBF16() && elemTyB.isBF16() && - (elemTyAcc.isF32() || elemTyAcc.isBF16())) + if (k == 32 && elemTyA.isBF16() && elemTyB.isBF16() && (elemTyAcc.isF32() || elemTyAcc.isBF16())) valid = true; if ((k == 64 || k == 128) && isF8(elemTyA) && isF8(elemTyB) && (elemTyAcc.isF32() || elemTyAcc.isF16())) valid = true; - if (k == 64 && elemTyA.isInteger(8) && elemTyB.isInteger(8) && - elemTyAcc.isInteger(32)) + if (k == 64 && elemTyA.isInteger(8) && elemTyB.isInteger(8) && elemTyAcc.isInteger(32)) valid = true; if (!valid) { - return emitError() << "unsupported GFX1250 WMMA configuration: " << m - << "x" << n << "x" << k << " with A=" << elemTyA - << ", B=" << elemTyB << ", Acc=" << elemTyAcc; + return emitError() << "unsupported GFX1250 WMMA configuration: " << m << "x" << n << "x" << k + << " with A=" << elemTyA << ", B=" << elemTyB << ", Acc=" << elemTyAcc; } return success(); } diff --git a/python/flydsl/expr/derived.py b/python/flydsl/expr/derived.py index d9aa1c89..25538a1b 100644 --- a/python/flydsl/expr/derived.py +++ b/python/flydsl/expr/derived.py @@ -20,7 +20,6 @@ tAr = thr_mma.partition_A(sA) """ -from .._mlir import ir from .._mlir.dialects._fly_enum_gen import MmaOperand from .meta import traced_op from .primitive import * @@ -28,7 +27,6 @@ __all__ = [ # Tiled Operation - "MmaAtom", "ThrCopy", "ThrMma", "make_layout_tv", @@ -39,61 +37,6 @@ ] -class Atom: - """Base class for hardware instruction atoms (copy/MMA). - - An atom wraps a single MLIR ``ir.Value`` that represents a hardware - instruction descriptor in the Fly dialect type system. - """ - - def __init__(self, value: ir.Value): - self.value = value - self.atom_ty = self.value.type - - @classmethod - def __fly_construct__(cls, values): - return cls(values[0]) - - def __fly_values__(self): - return [self.value] - - -class MmaAtom(Atom): - """Atom describing a single MMA (matrix multiply-accumulate) instruction. - - Wraps MFMA instruction descriptors with thread-value layouts for - operands A, B, and C, plus the MNK shape of the instruction. - - Properties: - thr_layout: Thread layout of the MMA atom. - shape_mnk: The (M, N, K) shape of the MMA instruction. - tv_layout_A/B/C: Thread-value layouts for operands A, B, C. - """ - - def __str__(self): - return f"MmaAtom({self.atom_ty})" - - @property - def thr_layout(self): - return static(self.atom_ty.thr_layout) - - @property - def shape_mnk(self): - return static(self.atom_ty.shape_mnk) - - @property - def tv_layout_A(self): - return static(self.atom_ty.tv_layout_a) - - @property - def tv_layout_B(self): - return static(self.atom_ty.tv_layout_b) - - @property - def tv_layout_C(self): - return static(self.atom_ty.tv_layout_c) - - class ThrCopy(TiledCopy): """Per-thread view of a TiledCopy for partitioning source/destination tensors. diff --git a/python/flydsl/expr/primitive.py b/python/flydsl/expr/primitive.py index ff1c869b..a19f4852 100644 --- a/python/flydsl/expr/primitive.py +++ b/python/flydsl/expr/primitive.py @@ -13,16 +13,17 @@ CoordTensorType, CopyAtomType, CopyOpUniversalCopyType, + GemmTraversalOrder, IntTupleType, LayoutType, MemRefType, - MmaAtomUniversalFMAType, + MmaAtomType, MmaOperand, + MmaOpUniversalFMAType, PointerType, SwizzleType, TiledCopyType, TiledMmaType, - GemmTraversalOrder, # has_none, ) @@ -47,10 +48,11 @@ "MemRefType", "CoordTensorType", "CopyAtomType", + "MmaAtomType", "TiledCopyType", "TiledMmaType", "CopyOpUniversalCopyType", - "MmaAtomUniversalFMAType", + "MmaOpUniversalFMAType", # UniversalOps "UniversalCopy", "UniversalCopy8b", @@ -174,7 +176,7 @@ UniversalCopy64b = lambda: CopyOpUniversalCopyType.get(64) UniversalCopy128b = lambda: CopyOpUniversalCopyType.get(128) -UniversalFMA = lambda ty: MmaAtomUniversalFMAType.get(ty.ir_type) +UniversalFMA = lambda ty: MmaOpUniversalFMAType.get(ty.ir_type) def const_expr(x): @@ -639,10 +641,9 @@ def tile_to_shape(block, trg_shape, ord_shape, loc=None, ip=None): @traced_op -def make_mma_atom(atom_type, loc=None, ip=None): - from .derived import MmaAtom - - return MmaAtom(fly.make_mma_atom(atom_type, loc=loc, ip=ip)) +def make_mma_atom(mma_op_type, loc=None, ip=None): + mma_atom_ty = MmaAtomType.get(mma_op=mma_op_type) + return fly.make_mma_atom(mma_atom_ty, loc=loc, ip=ip) @traced_op @@ -660,7 +661,7 @@ def make_copy_atom(copy_op_type, elem_type, loc=None, ip=None): val_bits = elem_type else: raise TypeError(f"make_copy_atom: elem_type must be NumericType, ir.Type, or int, got {type(elem_type)}") - copy_atom_ty = CopyAtomType.get(copy_op_type, val_bits) + copy_atom_ty = CopyAtomType.get(copy_op=copy_op_type, val_bits=val_bits) return fly.make_copy_atom(copy_atom_ty, val_bits=val_bits, loc=loc, ip=ip) diff --git a/python/flydsl/expr/rocdl.py b/python/flydsl/expr/rocdl.py index bcb97e91..b4556bcf 100644 --- a/python/flydsl/expr/rocdl.py +++ b/python/flydsl/expr/rocdl.py @@ -17,8 +17,8 @@ >>> rocdl.barrier() """ -from .._mlir._mlir_libs._mlirDialectsFlyROCDL import CopyOpCDNA3BufferCopyType, MmaAtomCDNA3_MFMAType -from .._mlir._mlir_libs._mlirDialectsFlyROCDL import MmaAtomGFX1250_WMMAType +from .._mlir._mlir_libs._mlirDialectsFlyROCDL import CopyOpCDNA3BufferCopyType, MmaOpCDNA3_MFMAType +from .._mlir._mlir_libs._mlirDialectsFlyROCDL import MmaOpGFX1250_WMMAType from .._mlir.dialects.rocdl import * # noqa: F401,F403 from .._mlir.extras import types as T @@ -51,7 +51,7 @@ def MFMA(m, n, k, elem_type, elem_type_b=None, elem_type_acc=None): ty_acc = ( ty if elem_type_acc is None else (elem_type_acc.ir_type if hasattr(elem_type_acc, "ir_type") else elem_type_acc) ) - return MmaAtomCDNA3_MFMAType.get(m, n, k, ty, ty_b, ty_acc) + return MmaOpCDNA3_MFMAType.get(m, n, k, ty, ty_b, ty_acc) def WMMA(m, n, k, elem_type, elem_type_b=None, elem_type_acc=None): @@ -74,7 +74,7 @@ def WMMA(m, n, k, elem_type, elem_type_b=None, elem_type_acc=None): ty_b = ty if elem_type_b is None else (elem_type_b.ir_type if hasattr(elem_type_b, 'ir_type') else elem_type_b) ty_acc = ty if elem_type_acc is None else (elem_type_acc.ir_type if hasattr(elem_type_acc, 'ir_type') else elem_type_acc) - return MmaAtomGFX1250_WMMAType.get(m, n, k, ty, ty_b, ty_acc) + return MmaOpGFX1250_WMMAType.get(m, n, k, ty, ty_b, ty_acc) def make_buffer_tensor(memref, alignment=4, loc=None, ip=None): @@ -670,9 +670,9 @@ def lds_transpose_load(result_type, lds_memref, elem_offset, elem_bytes): "BufferCopy64b", "BufferCopy128b", # MMA atom types - "MmaAtomCDNA3_MFMAType", + "MmaOpCDNA3_MFMAType", "MFMA", - "MmaAtomGFX1250_WMMAType", + "MmaOpGFX1250_WMMAType", "WMMA", # Convenience wrappers "make_buffer_tensor", diff --git a/python/flydsl/expr/rocdl/universal.py b/python/flydsl/expr/rocdl/universal.py index cd7b79d3..02575450 100644 --- a/python/flydsl/expr/rocdl/universal.py +++ b/python/flydsl/expr/rocdl/universal.py @@ -2,11 +2,11 @@ # Copyright (c) 2025 FlyDSL Project Contributors from ..._mlir import ir -from ..._mlir._mlir_libs._mlirDialectsFlyROCDL import MmaAtomGFX1250_WMMAType +from ..._mlir._mlir_libs._mlirDialectsFlyROCDL import MmaOpGFX1250_WMMAType from ..._mlir.dialects import arith, fly from ..._mlir.dialects._fly_enum_gen import AddressSpace from ..._mlir.dialects.fly import PointerType -from ..._mlir.dialects.fly_rocdl import CopyOpCDNA3BufferCopyType, MmaAtomCDNA3_MFMAType +from ..._mlir.dialects.fly_rocdl import CopyOpCDNA3BufferCopyType, MmaOpCDNA3_MFMAType from ..._mlir.extras import types as T from ..primitive import ( get_iter, @@ -31,7 +31,7 @@ def MFMA(m, n, k, elem_ty_ab, elem_ty_acc=None): ty_acc = T.f32() else: ty_acc = elem_ty_acc.ir_type if hasattr(elem_ty_acc, "ir_type") else elem_ty_acc - return MmaAtomCDNA3_MFMAType.get(m, n, k, ty_ab, ty_ab, ty_acc) + return MmaOpCDNA3_MFMAType.get(m, n, k, ty_ab, ty_ab, ty_acc) def WMMA(m, n, k, elem_ty_ab, elem_ty_acc=None): @@ -40,7 +40,7 @@ def WMMA(m, n, k, elem_ty_ab, elem_ty_acc=None): ty_acc = ir.F32Type.get() else: ty_acc = elem_ty_acc.ir_type if hasattr(elem_ty_acc, "ir_type") else elem_ty_acc - return MmaAtomGFX1250_WMMAType.get(m, n, k, ty_ab, ty_ab, ty_acc) + return MmaOpGFX1250_WMMAType.get(m, n, k, ty_ab, ty_ab, ty_acc) def make_buffer_tensor(tensor: Tensor) -> Tensor: diff --git a/python/flydsl/expr/typing.py b/python/flydsl/expr/typing.py index 9686d805..58bacff2 100644 --- a/python/flydsl/expr/typing.py +++ b/python/flydsl/expr/typing.py @@ -612,6 +612,33 @@ def layout_ref_tv(self): return static(self.type.tv_layout_ref) +@ir.register_value_caster(MmaAtomType.static_typeid, replace=True) +class MmaAtom(BuiltinDslType): + @property + def thr_layout(self): + return static(self.type.thr_layout) + + @property + def thr_id(self): + return self.thr_layout + + @property + def shape_mnk(self): + return static(self.type.shape_mnk) + + @property + def layout_A_tv(self): + return static(self.type.tv_layout_a) + + @property + def layout_B_tv(self): + return static(self.type.tv_layout_b) + + @property + def layout_C_tv(self): + return static(self.type.tv_layout_c) + + @ir.register_value_caster(TiledCopyType.static_typeid, replace=True) class TiledCopy(BuiltinDslType): @property diff --git a/python/flydsl/expr/utils/print_typst.py b/python/flydsl/expr/utils/print_typst.py index 0edc8df6..5c50425f 100644 --- a/python/flydsl/expr/utils/print_typst.py +++ b/python/flydsl/expr/utils/print_typst.py @@ -110,6 +110,17 @@ def _typst_text_panel(lines: list[str]) -> str: return "\n".join(panel) +def _mma_atom_text_lines(mma_atom) -> list[str]: + return [ + "MMA Atom", + f" Thr Layout: {mma_atom.thr_layout}", + f" Shape MNK: {mma_atom.shape_mnk}", + f" TV Layout A: {mma_atom.layout_A_tv}", + f" TV Layout B: {mma_atom.layout_B_tv}", + f" TV Layout C: {mma_atom.layout_C_tv}", + ] + + def _tiled_mma_text_lines(mma) -> list[str]: return [ "Tiled MMA", @@ -224,6 +235,34 @@ def _typst_layout( return _typst_text_panel([f"{layout_str}"]) + "\n\n" + _typst_grid_block(M, N, cells, title="Layout") + "\n" +def _typst_mma_atom( + mma_atom, + color: Callable[[int, int], str], +) -> str: + shape_mnk = mma_atom.shape_mnk + M, N, K = shape_mnk.to_py_value() + + layout_A = mma_atom.layout_A_tv + layout_B = mma_atom.layout_B_tv + layout_C = mma_atom.layout_C_tv + + cells_C = _tv_cells(layout_C, M, N, color) + cells_A = _tv_cells(layout_A, M, K, color) + cells_B = _tv_cells_B_top(layout_B, N, K, color) + + doc = _typst_text_panel(_mma_atom_text_lines(mma_atom)) + + doc += "\n\n#grid(\n columns: (auto, auto, auto),\n rows: (auto, auto),\n gutter: 12pt,\n align: center + horizon,\n" + doc += " [],\n [\n" + doc += _typst_grid_block(K, N, cells_B, title="B (K x N)") + doc += "\n ],\n [],\n [\n" + doc += _typst_grid_block(M, K, cells_A, title="A (M x K)") + doc += "\n ],\n [\n" + doc += _typst_grid_block(M, N, cells_C, title="C (M x N)") + doc += "\n ],\n [],\n)\n" + return doc + + def _typst_mma( mma, color: Callable[[int, int], str], @@ -289,6 +328,7 @@ def print_typst( Dispatches based on the type of *arg*: * **Layout / ComposedLayout** -- index grid coloured by linear index. + * **MmaAtom** -- atom info text plus B-over-C / A-left-of-C TV grids. * **TiledMma** -- LayoutABC text plus a B-over-C / A-left-of-C view. * **TiledCopy** -- side-by-side Src, Dst thread-value grids. @@ -296,24 +336,26 @@ def print_typst( separators; the Typst header is emitted only once. Args: - arg: A static Layout, ComposedLayout, TiledMma, or TiledCopy. + arg: A static Layout, ComposedLayout, MmaAtom, TiledMma, or TiledCopy. color: Optional colour function. For Layout the signature is - ``(idx) -> hex``; for TiledMma/TiledCopy it is + ``(idx) -> hex``; for MmaAtom/TiledMma/TiledCopy it is ``(tid, vid) -> hex``. Defaults to grayscale for Layout and pastel-by-thread for MMA/Copy. file: Output filename. ``None`` -> ``sys.stdout``. """ - from ..typing import ComposedLayout, Layout, TiledCopy, TiledMma + from ..typing import ComposedLayout, Layout, MmaAtom, TiledCopy, TiledMma if isinstance(arg, (Layout, ComposedLayout)): body = _typst_layout(arg, color or _color_bw) + elif isinstance(arg, MmaAtom): + body = _typst_mma_atom(arg, color or _color_tv) elif isinstance(arg, TiledMma): body = _typst_mma(arg, color or _color_tv) elif isinstance(arg, TiledCopy): body = _typst_copy(arg, color or _color_tv) else: raise ValueError( - f"print_typst expects Layout, ComposedLayout, TiledMma, or TiledCopy, got {type(arg).__name__}" + f"print_typst expects Layout, ComposedLayout, MmaAtom, TiledMma, or TiledCopy, got {type(arg).__name__}" ) if file is None: diff --git a/tests/mlir/Conversion/mma_atom.mlir b/tests/mlir/Conversion/mma_atom.mlir index 944bf278..fec2a782 100644 --- a/tests/mlir/Conversion/mma_atom.mlir +++ b/tests/mlir/Conversion/mma_atom.mlir @@ -1,6 +1,6 @@ // SPDX-License-Identifier: Apache-2.0 // Copyright (c) 2025 FlyDSL Project Contributors -// RUN: %fly-opt %s --convert-fly-to-rocdl | FileCheck %s +// RUN: %fly-opt %s --fly-rewrite-func-signature --fly-canonicalize --fly-layout-lowering --convert-fly-to-rocdl | FileCheck %s // MMA atom call lowering tests: // fly.mma_atom_call -> rocdl.mfma intrinsic @@ -14,12 +14,24 @@ func.func @test_mma_atom_call( %a: !fly.memref, %b: !fly.memref, %c: !fly.memref) { - %atom = fly.make_mma_atom : !fly_rocdl.atom.cdna3.mfma<16x16x4, (f32, f32) -> f32> + %atom = fly.make_mma_atom : !fly.mma_atom f32>> // CHECK: %[[A_VAL:.*]] = llvm.load %[[A]] : !llvm.ptr<5> -> f32 // CHECK: %[[B_VAL:.*]] = llvm.load %[[B]] : !llvm.ptr<5> -> f32 // CHECK: %[[C_VAL:.*]] = llvm.load %[[C]] : !llvm.ptr<5> -> vector<4xf32> // CHECK: %[[RES:.*]] = rocdl.mfma.f32.16x16x4f32 %[[A_VAL]], %[[B_VAL]], %[[C_VAL]] // CHECK: llvm.store %[[RES]], %[[D]] : vector<4xf32>, !llvm.ptr<5> - fly.mma_atom_call(%atom, %d, %a, %b, %c) : (!fly_rocdl.atom.cdna3.mfma<16x16x4, (f32, f32) -> f32>, !fly.memref, !fly.memref, !fly.memref, !fly.memref) -> () + fly.mma_atom_call(%atom, %d, %a, %b, %c) : (!fly.mma_atom f32>>, !fly.memref, !fly.memref, !fly.memref, !fly.memref) -> () + return +} + +// CHECK-LABEL: @test_gemm_from_tiled_mma_arg +// CHECK: rocdl.mfma.f32.16x16x4f32 +func.func @test_gemm_from_tiled_mma_arg( + %tiled_mma: !fly.tiled_mma f32>>, <(1,4,1):(0,1,0)>>, + %d: !fly.memref, + %a: !fly.memref, + %b: !fly.memref, + %c: !fly.memref) { + fly.gemm(%tiled_mma, %d, %a, %b, %c) : (!fly.tiled_mma f32>>, <(1,4,1):(0,1,0)>>, !fly.memref, !fly.memref, !fly.memref, !fly.memref) -> () return }