Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion examples/utils/print_typst.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@

@flyc.jit
def visualize():
mma_atom = fx.make_mma_atom(fx.rocdl.MFMA(16, 16, 16, fx.Float16))

tiled_mma = fx.make_tiled_mma(
fx.make_mma_atom(fx.rocdl.MFMA(16, 16, 16, fx.Float16)),
mma_atom,
fx.make_layout((1, 2, 1), (0, 1, 0)),
fx.make_tile(16, 32, fx.make_layout((4, 4, 2), (1, 8, 4))),
)
Expand All @@ -34,6 +36,7 @@ def visualize():
tiled_mma,
)

fx.utils.print_typst(mma_atom, file=OUTPUT_TYPST)
fx.utils.print_typst(tiled_mma, file=OUTPUT_TYPST)
fx.utils.print_typst(tiled_copy, file=OUTPUT_TYPST)
fx.utils.print_typst(swizzle_layout, file=OUTPUT_TYPST)
Expand Down
2 changes: 1 addition & 1 deletion include/flydsl/Dialect/Fly/IR/FlyInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def Fly_CopyOpTypeInterface : TypeInterface<"CopyOpTypeInterface", [Fly_MayStati
];
}

def Fly_MmaAtomTypeInterface : TypeInterface<"MmaAtomTypeInterface", [Fly_MayStaticTypeInterface]> {
def Fly_MmaOpTypeInterface : TypeInterface<"MmaOpTypeInterface", [Fly_MayStaticTypeInterface]> {
let cppNamespace = "::mlir::fly";
let methods = [
InterfaceMethod<"", "::mlir::Attribute", "getThrLayout", (ins)>,
Expand Down
31 changes: 29 additions & 2 deletions include/flydsl/Dialect/Fly/IR/FlyTypeDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,33 @@ def Fly_CopyAtom : Fly_Type<"CopyAtom", "copy_atom", [
];
}

def Fly_MmaAtom : Fly_Type<"MmaAtom", "mma_atom", [
DeclareTypeInterfaceMethods<Fly_MayStaticTypeInterface>
]> {
let parameters = (ins
"Type":$mmaOp
);
let assemblyFormat = "`<` $mmaOp `>`";

let extraClassDeclaration = [{
::mlir::Attribute getThrLayout() const;
::mlir::Attribute getShapeMNK() const;
::mlir::Type getValTypeA() const;
::mlir::Type getValTypeB() const;
::mlir::Type getValTypeC() const;
::mlir::Type getValTypeD() const;
::mlir::Attribute getThrValLayoutA() const;
::mlir::Attribute getThrValLayoutB() const;
::mlir::Attribute getThrValLayoutC() const;
}];

let builders = [
TypeBuilderWithInferredContext<(ins "Type":$mmaOp), [{
return $_get(mmaOp.getContext(), mmaOp);
}]>
];
}

def Fly_CopyOpUniversalCopy : Fly_Type<"CopyOpUniversalCopy", "universal_copy", [
DeclareTypeInterfaceMethods<Fly_MayStaticTypeInterface>,
DeclareTypeInterfaceMethods<Fly_CopyOpTypeInterface>
Expand All @@ -273,9 +300,9 @@ def Fly_CopyOpUniversalCopy : Fly_Type<"CopyOpUniversalCopy", "universal_copy",
let assemblyFormat = "`<` $bitSize `>`";
}

def Fly_MmaAtomUniversalFMA : Fly_Type<"MmaAtomUniversalFMA", "universal_fma", [
def Fly_MmaOpUniversalFMA : Fly_Type<"MmaOpUniversalFMA", "universal_fma", [
DeclareTypeInterfaceMethods<Fly_MayStaticTypeInterface>,
DeclareTypeInterfaceMethods<Fly_MmaAtomTypeInterface>
DeclareTypeInterfaceMethods<Fly_MmaOpTypeInterface>
]> {
let parameters = (ins "Type":$elemTy);

Expand Down
4 changes: 2 additions & 2 deletions include/flydsl/Dialect/Fly/Utils/TiledOpUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ Layout layoutTiledCopyThrValView(LayoutBuilder<Layout> &builder, CopyAtomType co
}

template <class Layout>
Layout layoutTiledMmaThrValView(LayoutBuilder<Layout> &builder, MmaAtomTypeInterface mmaAtom,
Layout layoutTiledMmaThrValView(LayoutBuilder<Layout> &builder, MmaAtomType mmaAtom,
LayoutAttr tiledShape2D, IntTupleAttr atomShape2D,
LayoutAttr atomLayoutThrVal, TileAttr permutation2D,
Layout trgLayout) {
Expand Down Expand Up @@ -262,7 +262,7 @@ Layout layoutTiledCopyRetile(LayoutBuilder<Layout> &builder, CopyAtomType copyAt
}

template <class Layout>
Layout layoutTiledMmaThrValOperandView(LayoutBuilder<Layout> &builder, MmaAtomTypeInterface mmaAtom,
Layout layoutTiledMmaThrValOperandView(LayoutBuilder<Layout> &builder, MmaAtomType mmaAtom,
LayoutAttr atomLayoutMNK, TileAttr permutationMNK,
MmaOperand operandId, Layout trgLayout) {
auto *ctx = atomLayoutMNK.getContext();
Expand Down
4 changes: 2 additions & 2 deletions include/flydsl/Dialect/FlyROCDL/IR/Dialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ def FlyROCDL_Dialect : Dialect {
let useDefaultTypePrinterParser = 1;
}

class FlyxROCL_MmaAtom<string typeName, string typeMnemonic, list<Trait> traits = []>
class FlyxROCL_MmaOp<string typeName, string typeMnemonic, list<Trait> traits = []>
: TypeDef<FlyROCDL_Dialect, typeName, !listconcat(traits, [
DeclareTypeInterfaceMethods<Fly_MayStaticTypeInterface>,
DeclareTypeInterfaceMethods<Fly_MmaAtomTypeInterface>
DeclareTypeInterfaceMethods<Fly_MmaOpTypeInterface>
])> {
let mnemonic = typeMnemonic;
}
Expand Down
10 changes: 5 additions & 5 deletions include/flydsl/Dialect/FlyROCDL/IR/MmaAtom.td
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
include "flydsl/Dialect/FlyROCDL/IR/Dialect.td"

//===----------------------------------------------------------------------===//
// MmaAtom CDNA3
// MmaOp CDNA3
//===----------------------------------------------------------------------===//

def FlyROCDL_MmaAtomCDNA3_MFMA : FlyxROCL_MmaAtom<"MmaAtomCDNA3_MFMA", "atom.cdna3.mfma", []> {
def FlyROCDL_MmaOpCDNA3_MFMA : FlyxROCL_MmaOp<"MmaOpCDNA3_MFMA", "cdna3.mfma", []> {
let parameters = (ins
"int32_t":$m,
"int32_t":$n,
Expand All @@ -30,16 +30,16 @@ def FlyROCDL_MmaAtomCDNA3_MFMA : FlyxROCL_MmaAtom<"MmaAtomCDNA3_MFMA", "atom.cdn
}

//===----------------------------------------------------------------------===//
// MmaAtom CDNA4
// MmaOp CDNA4
//===----------------------------------------------------------------------===//



//===----------------------------------------------------------------------===//
// MmaAtom GFX1250 — WMMA wave32
// MmaOp GFX1250 — WMMA wave32
//===----------------------------------------------------------------------===//

def FlyROCDL_MmaAtomGFX1250_WMMA : FlyxROCL_MmaAtom<"MmaAtomGFX1250_WMMA", "atom.gfx1250.wmma", []> {
def FlyROCDL_MmaOpGFX1250_WMMA : FlyxROCL_MmaOp<"MmaOpGFX1250_WMMA", "gfx1250.wmma", []> {
let parameters = (ins
"int32_t":$m,
"int32_t":$n,
Expand Down
140 changes: 74 additions & 66 deletions lib/Bindings/Python/FlyExtension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -471,28 +471,6 @@ struct PyCoordTensorType : PyConcreteType<PyCoordTensorType> {
}
};

// ---------------------------------------------------------------------------
// CopyOpUniversalCopyType
// ---------------------------------------------------------------------------
struct PyCopyOpUniversalCopyType : PyConcreteType<PyCopyOpUniversalCopyType> {
FLYDSL_REGISTER_TYPE_BINDING(::mlir::fly::CopyOpUniversalCopyType, "CopyOpUniversalCopyType");

static void bindDerived(ClassTy &c) {
c.def_static(
"get",
[](int32_t bitSize, DefaultingPyMlirContext context) {
MLIRContext *ctx = unwrap(context.get()->get());
return PyCopyOpUniversalCopyType(context->getRef(),
wrap(CopyOpUniversalCopyType::get(ctx, bitSize)));
},
"bitSize"_a, nb::kw_only(), "context"_a = nb::none(),
"Create a CopyOpUniversalCopyType with bit size");

c.def_prop_ro("bit_size",
[](PyCopyOpUniversalCopyType &self) { return self.toCppType().getBitSize(); });
}
};

// ---------------------------------------------------------------------------
// CopyAtomType
// ---------------------------------------------------------------------------
Expand Down Expand Up @@ -529,43 +507,37 @@ struct PyCopyAtomType : PyConcreteType<PyCopyAtomType> {
};

// ---------------------------------------------------------------------------
// MmaAtomUniversalFMAType
// MmaAtomType
// ---------------------------------------------------------------------------
struct PyMmaAtomUniversalFMAType : PyConcreteType<PyMmaAtomUniversalFMAType> {
FLYDSL_REGISTER_TYPE_BINDING(::mlir::fly::MmaAtomUniversalFMAType, "MmaAtomUniversalFMAType");
struct PyMmaAtomType : PyConcreteType<PyMmaAtomType> {
FLYDSL_REGISTER_TYPE_BINDING(::mlir::fly::MmaAtomType, "MmaAtomType");

static void bindDerived(ClassTy &c) {
c.def_static(
"get",
[](PyType &elemTyObj, DefaultingPyMlirContext context) {
return PyMmaAtomUniversalFMAType(context->getRef(),
wrap(MmaAtomUniversalFMAType::get(unwrap(elemTyObj))));
[](PyType &mmaOp, DefaultingPyMlirContext context) {
return PyMmaAtomType(context->getRef(), wrap(MmaAtomType::get(unwrap(mmaOp))));
},
"elem_ty"_a, nb::kw_only(), "context"_a = nb::none(),
"Create a MmaAtomUniversalFMAType with element type");
"mma_op"_a, nb::kw_only(), "context"_a = nb::none(),
"Create a MmaAtomType wrapping an MmaOpTypeInterface type");

c.def_prop_ro("elem_ty", [](PyMmaAtomUniversalFMAType &self) -> MlirType {
return wrap(self.toCppType().getElemTy());
c.def_prop_ro("mma_op", [](PyMmaAtomType &self) -> MlirType {
return wrap(self.toCppType().getMmaOp());
});
c.def_prop_ro("thr_layout", [](PyMmaAtomUniversalFMAType &self) -> MlirType {
auto ty = cast<MmaAtomTypeInterface>(self.toCppType());
return wrap(LayoutType::get(cast<LayoutAttr>(ty.getThrLayout())));
c.def_prop_ro("thr_layout", [](PyMmaAtomType &self) -> MlirType {
return wrap(LayoutType::get(cast<LayoutAttr>(self.toCppType().getThrLayout())));
});
c.def_prop_ro("shape_mnk", [](PyMmaAtomUniversalFMAType &self) -> MlirType {
auto ty = cast<MmaAtomTypeInterface>(self.toCppType());
return wrap(IntTupleType::get(cast<IntTupleAttr>(ty.getShapeMNK())));
c.def_prop_ro("shape_mnk", [](PyMmaAtomType &self) -> MlirType {
return wrap(IntTupleType::get(cast<IntTupleAttr>(self.toCppType().getShapeMNK())));
});
c.def_prop_ro("tv_layout_a", [](PyMmaAtomUniversalFMAType &self) -> MlirType {
auto ty = cast<MmaAtomTypeInterface>(self.toCppType());
return wrap(LayoutType::get(cast<LayoutAttr>(ty.getThrValLayoutA())));
c.def_prop_ro("tv_layout_a", [](PyMmaAtomType &self) -> MlirType {
return wrap(LayoutType::get(cast<LayoutAttr>(self.toCppType().getThrValLayoutA())));
});
c.def_prop_ro("tv_layout_b", [](PyMmaAtomUniversalFMAType &self) -> MlirType {
auto ty = cast<MmaAtomTypeInterface>(self.toCppType());
return wrap(LayoutType::get(cast<LayoutAttr>(ty.getThrValLayoutB())));
c.def_prop_ro("tv_layout_b", [](PyMmaAtomType &self) -> MlirType {
return wrap(LayoutType::get(cast<LayoutAttr>(self.toCppType().getThrValLayoutB())));
});
c.def_prop_ro("tv_layout_c", [](PyMmaAtomUniversalFMAType &self) -> MlirType {
auto ty = cast<MmaAtomTypeInterface>(self.toCppType());
return wrap(LayoutType::get(cast<LayoutAttr>(ty.getThrValLayoutC())));
c.def_prop_ro("tv_layout_c", [](PyMmaAtomType &self) -> MlirType {
return wrap(LayoutType::get(cast<LayoutAttr>(self.toCppType().getThrValLayoutC())));
});
}
};
Expand Down Expand Up @@ -621,41 +593,82 @@ struct PyTiledMmaType : PyConcreteType<PyTiledMmaType> {
});
c.def_prop_ro("tile_size_mnk", [](PyTiledMmaType &self) -> MlirType {
auto ty = self.toCppType();
auto mmaAtom = cast<MmaAtomTypeInterface>(ty.getMmaAtom());
auto mmaAtom = cast<MmaAtomType>(ty.getMmaAtom());
auto result = tiledMmaGetTileSizeMNK(mmaAtom, ty.getAtomLayout().getAttr(),
ty.getPermutation().getAttr());
return wrap(IntTupleType::get(result));
});
c.def_prop_ro("thr_layout_vmnk", [](PyTiledMmaType &self) -> MlirType {
auto ty = self.toCppType();
auto mmaAtom = cast<MmaAtomTypeInterface>(ty.getMmaAtom());
auto mmaAtom = cast<MmaAtomType>(ty.getMmaAtom());
auto result = tiledMmaGetThrLayoutVMNK(mmaAtom, ty.getAtomLayout().getAttr());
return wrap(LayoutType::get(result));
});
c.def_prop_ro("tiled_tv_layout_a", [](PyTiledMmaType &self) -> MlirType {
auto ty = self.toCppType();
auto mmaAtom = cast<MmaAtomTypeInterface>(ty.getMmaAtom());
auto mmaAtom = cast<MmaAtomType>(ty.getMmaAtom());
auto result = tiledMmaGetTiledThrValLayout(mmaAtom, ty.getAtomLayout().getAttr(),
ty.getPermutation().getAttr(), MmaOperand::A);
return wrap(LayoutType::get(result));
});
c.def_prop_ro("tiled_tv_layout_b", [](PyTiledMmaType &self) -> MlirType {
auto ty = self.toCppType();
auto mmaAtom = cast<MmaAtomTypeInterface>(ty.getMmaAtom());
auto mmaAtom = cast<MmaAtomType>(ty.getMmaAtom());
auto result = tiledMmaGetTiledThrValLayout(mmaAtom, ty.getAtomLayout().getAttr(),
ty.getPermutation().getAttr(), MmaOperand::B);
return wrap(LayoutType::get(result));
});
c.def_prop_ro("tiled_tv_layout_c", [](PyTiledMmaType &self) -> MlirType {
auto ty = self.toCppType();
auto mmaAtom = cast<MmaAtomTypeInterface>(ty.getMmaAtom());
auto mmaAtom = cast<MmaAtomType>(ty.getMmaAtom());
auto result = tiledMmaGetTiledThrValLayout(mmaAtom, ty.getAtomLayout().getAttr(),
ty.getPermutation().getAttr(), MmaOperand::C);
return wrap(LayoutType::get(result));
});
}
};

// ---------------------------------------------------------------------------
// CopyOpUniversalCopyType
// ---------------------------------------------------------------------------
struct PyCopyOpUniversalCopyType : PyConcreteType<PyCopyOpUniversalCopyType> {
FLYDSL_REGISTER_TYPE_BINDING(::mlir::fly::CopyOpUniversalCopyType, "CopyOpUniversalCopyType");

static void bindDerived(ClassTy &c) {
c.def_static(
"get",
[](int32_t bitSize, DefaultingPyMlirContext context) {
MLIRContext *ctx = unwrap(context.get()->get());
return PyCopyOpUniversalCopyType(context->getRef(),
wrap(CopyOpUniversalCopyType::get(ctx, bitSize)));
},
"bitSize"_a, nb::kw_only(), "context"_a = nb::none(),
"Create a CopyOpUniversalCopyType with bit size");
}
};

// ---------------------------------------------------------------------------
// MmaOpUniversalFMAType
// ---------------------------------------------------------------------------
struct PyMmaOpUniversalFMAType : PyConcreteType<PyMmaOpUniversalFMAType> {
FLYDSL_REGISTER_TYPE_BINDING(::mlir::fly::MmaOpUniversalFMAType, "MmaOpUniversalFMAType");

static void bindDerived(ClassTy &c) {
c.def_static(
"get",
[](PyType &elemTyObj, DefaultingPyMlirContext context) {
return PyMmaOpUniversalFMAType(context->getRef(),
wrap(MmaOpUniversalFMAType::get(unwrap(elemTyObj))));
},
"elem_ty"_a, nb::kw_only(), "context"_a = nb::none(),
"Create a MmaOpUniversalFMAType with element type");

c.def_prop_ro("elem_ty", [](PyMmaOpUniversalFMAType &self) -> MlirType {
return wrap(self.toCppType().getElemTy());
});
}
};

} // namespace fly
} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
} // namespace python
Expand Down Expand Up @@ -729,11 +742,12 @@ NB_MODULE(_mlirDialectsFly, m) {
::mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::fly::PyPointerType::bind(m);
::mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::fly::PyMemRefType::bind(m);
::mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::fly::PyCoordTensorType::bind(m);
::mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::fly::PyCopyOpUniversalCopyType::bind(m);
::mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::fly::PyCopyAtomType::bind(m);
::mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::fly::PyMmaAtomUniversalFMAType::bind(m);
::mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::fly::PyMmaAtomType::bind(m);
::mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::fly::PyTiledCopyType::bind(m);
::mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::fly::PyTiledMmaType::bind(m);
::mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::fly::PyCopyOpUniversalCopyType::bind(m);
::mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::fly::PyMmaOpUniversalFMAType::bind(m);

m.def(
"set_llvm_option_bool",
Expand All @@ -743,12 +757,10 @@ NB_MODULE(_mlirDialectsFly, m) {
if (rc == 1)
throw std::runtime_error("Unknown LLVM option: " + name);
if (rc == 2)
throw std::runtime_error("LLVM option '" + name +
"' is not a bool option");
throw std::runtime_error("LLVM option '" + name + "' is not a bool option");
return oldValue;
},
"name"_a, "value"_a,
"Set an LLVM bool cl::opt at runtime; returns the previous value.");
"name"_a, "value"_a, "Set an LLVM bool cl::opt at runtime; returns the previous value.");

m.def(
"set_llvm_option_int",
Expand All @@ -758,12 +770,10 @@ NB_MODULE(_mlirDialectsFly, m) {
if (rc == 1)
throw std::runtime_error("Unknown LLVM option: " + name);
if (rc == 2)
throw std::runtime_error("LLVM option '" + name +
"' is not an int option");
throw std::runtime_error("LLVM option '" + name + "' is not an int option");
return oldValue;
},
"name"_a, "value"_a,
"Set an LLVM int cl::opt at runtime; returns the previous value.");
"name"_a, "value"_a, "Set an LLVM int cl::opt at runtime; returns the previous value.");

m.def(
"set_llvm_option_str",
Expand All @@ -773,12 +783,10 @@ NB_MODULE(_mlirDialectsFly, m) {
if (rc == 1)
throw std::runtime_error("Unknown LLVM option: " + name);
if (rc == 2)
throw std::runtime_error("LLVM option '" + name +
"' is not a string option");
throw std::runtime_error("LLVM option '" + name + "' is not a string option");
std::string result(oldValue ? oldValue : "");
flydslFreeLLVMOptionStr(oldValue);
return result;
},
"name"_a, "value"_a,
"Set an LLVM string cl::opt at runtime; returns the previous value.");
"name"_a, "value"_a, "Set an LLVM string cl::opt at runtime; returns the previous value.");
}
Loading
Loading