[Refactor][DSL] Make MMA hierarchy symmetric with Copy #344
[Refactor][DSL] Make MMA hierarchy symmetric with Copy #344sjfeng1999 wants to merge 6 commits intomainfrom
Conversation
Split concrete MMA ops from `MmaAtomType` and route lowering/bindings through the new wrapper, matching the existing `CopyOp`/`CopyAtom` layering.
There was a problem hiding this comment.
Pull request overview
This PR refactors the Fly DSL’s MMA “atom” hierarchy to mirror the existing Copy atom design by introducing an explicit MmaOp* layer (implementing a renamed MmaOpTypeInterface) and wrapping it in a new !fly.mma_atom<...> type, with corresponding updates across lowering, Python bindings, and visualization utilities.
Changes:
- Introduce
!fly.mma_atom<...>and rename/reshape MMA types fromMmaAtom*toMmaOp*(CDNA3 MFMA, GFX1250 WMMA, universal FMA), updating C++/TableGen, lowering, and Python APIs. - Update Python DSL surface APIs (
make_mma_atom,rocdl.MFMA/WMMA, typing/value casters) and add Typst visualization support forMmaAtom. - Extend/adjust MLIR conversion tests to cover GEMM from a tiled MMA argument (and update existing MMA atom call test typing).
Reviewed changes
Copilot reviewed 25 out of 25 changed files in this pull request and generated 4 comments.
Show a summary per file
| File | Description |
|---|---|
| tests/mlir/Conversion/mma_atom.mlir | Updates MMA atom test to use !fly.mma_atom<...> and adds a GEMM-from-tiled-mma-arg test. |
| python/flydsl/expr/utils/print_typst.py | Adds Typst rendering for MmaAtom and updates dispatch/docs. |
| python/flydsl/expr/typing.py | Adds MmaAtom value caster exposing atom layouts/shape for Python-side introspection. |
| python/flydsl/expr/rocdl/universal.py | Renames ROCDL MMA type imports/constructors to MmaOp*. |
| python/flydsl/expr/rocdl.py | Renames ROCDL MMA type imports/constructors and updates public exports. |
| python/flydsl/expr/primitive.py | Adjusts make_mma_atom to build a MmaAtomType wrapper and renames universal FMA type usage. |
| python/flydsl/expr/derived.py | Removes the old Python-side MmaAtom wrapper class in favor of the typing.py caster. |
| lib/Dialect/FlyROCDL/GFX1250/MmaAtom.cpp | Renames type to MmaOpGFX1250_WMMAType and wraps rebuild via MmaAtomType::get(...). |
| lib/Dialect/FlyROCDL/CDNA3/MmaAtom.cpp | Renames type to MmaOpCDNA3_MFMAType and wraps rebuild via MmaAtomType::get(...). |
| lib/Dialect/Fly/Transforms/LayoutLowering.cpp | Switches tiled MMA lowering to require MmaAtomType (wrapper) instead of the old interface. |
| lib/Dialect/Fly/IR/FlyUniversalOps.cpp | Moves universal copy/FMA type implementations into a dedicated compilation unit. |
| lib/Dialect/Fly/IR/FlyTypeDefs.cpp | Implements MmaAtomType delegation to the underlying MmaOpTypeInterface. |
| lib/Dialect/Fly/IR/FlyOps.cpp | Updates inference error messages and checks to align with MmaAtomType wrapper. |
| lib/Dialect/Fly/CMakeLists.txt | Adds IR/FlyUniversalOps.cpp to the Fly dialect library. |
| lib/Conversion/FlyToROCDL/FlyToROCDL.cpp | Updates mma_atom_call lowering to dispatch via MmaAtomType.getMmaOp() (MmaOp* types). |
| lib/Bindings/Python/TiledOpTraits.h / .cpp | Updates helper signatures to take MmaAtomType rather than the removed interface. |
| lib/Bindings/Python/FlyROCDLExtension.cpp | Renames Python-exposed ROCDL MMA types to MmaOp*. |
| lib/Bindings/Python/FlyExtension.cpp | Adds Python binding for MmaAtomType and binds MmaOpUniversalFMAType. |
| include/flydsl/Dialect/FlyROCDL/IR/MmaAtom.td | Renames FlyROCDL MMA types to MmaOp* and updates mnemonics. |
| include/flydsl/Dialect/FlyROCDL/IR/Dialect.td | Renames the FlyROCDL MMA typedef base class to FlyxROCL_MmaOp using Fly_MmaOpTypeInterface. |
| include/flydsl/Dialect/Fly/Utils/TiledOpUtils.h | Updates tiled MMA layout utilities to accept MmaAtomType wrapper. |
| include/flydsl/Dialect/Fly/IR/FlyTypeDefs.td | Introduces Fly_MmaAtom wrapper type and renames universal FMA type to Fly_MmaOpUniversalFMA. |
| include/flydsl/Dialect/Fly/IR/FlyInterfaces.td | Renames Fly_MmaAtomTypeInterface to Fly_MmaOpTypeInterface. |
| examples/utils/print_typst.py | Demonstrates Typst printing for a standalone mma_atom in addition to tiled MMA/copy. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
tests/mlir/Conversion/mma_atom.mlir
Outdated
| %b: !fly.memref<f32, register, 1:1>, | ||
| %c: !fly.memref<f32, register, 4:1>) { | ||
| %atom = fly.make_mma_atom : !fly_rocdl.atom.cdna3.mfma<16x16x4, (f32, f32) -> f32> | ||
| %atom = fly.make_mma_atom : !fly.mma_atom<!fly_rocdl.atom.cdna3.mfma<16x16x4, (f32, f32) -> f32>> |
There was a problem hiding this comment.
FlyROCDL_MmaOpCDNA3_MFMA's TableGen mnemonic was changed to cdna3.mfma (see include/flydsl/Dialect/FlyROCDL/IR/MmaAtom.td), but this test still uses !fly_rocdl.atom.cdna3.mfma<...>. This will fail to parse/roundtrip once the new mnemonic is in effect. Update the IR here to the new type spelling.
Motivation
Technical Details
Test Plan
Test Result
Submission Checklist