Skip to content

[Refactor][DSL] Make MMA hierarchy symmetric with Copy #344

Open
sjfeng1999 wants to merge 6 commits intomainfrom
pr/refactor-mma-atom
Open

[Refactor][DSL] Make MMA hierarchy symmetric with Copy #344
sjfeng1999 wants to merge 6 commits intomainfrom
pr/refactor-mma-atom

Conversation

@sjfeng1999
Copy link
Copy Markdown
Collaborator

Motivation

Technical Details

Test Plan

Test Result

Submission Checklist

Split concrete MMA ops from `MmaAtomType` and route lowering/bindings through the new wrapper, matching the existing `CopyOp`/`CopyAtom` layering.
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR refactors the Fly DSL’s MMA “atom” hierarchy to mirror the existing Copy atom design by introducing an explicit MmaOp* layer (implementing a renamed MmaOpTypeInterface) and wrapping it in a new !fly.mma_atom<...> type, with corresponding updates across lowering, Python bindings, and visualization utilities.

Changes:

  • Introduce !fly.mma_atom<...> and rename/reshape MMA types from MmaAtom* to MmaOp* (CDNA3 MFMA, GFX1250 WMMA, universal FMA), updating C++/TableGen, lowering, and Python APIs.
  • Update Python DSL surface APIs (make_mma_atom, rocdl.MFMA/WMMA, typing/value casters) and add Typst visualization support for MmaAtom.
  • Extend/adjust MLIR conversion tests to cover GEMM from a tiled MMA argument (and update existing MMA atom call test typing).

Reviewed changes

Copilot reviewed 25 out of 25 changed files in this pull request and generated 4 comments.

Show a summary per file
File Description
tests/mlir/Conversion/mma_atom.mlir Updates MMA atom test to use !fly.mma_atom<...> and adds a GEMM-from-tiled-mma-arg test.
python/flydsl/expr/utils/print_typst.py Adds Typst rendering for MmaAtom and updates dispatch/docs.
python/flydsl/expr/typing.py Adds MmaAtom value caster exposing atom layouts/shape for Python-side introspection.
python/flydsl/expr/rocdl/universal.py Renames ROCDL MMA type imports/constructors to MmaOp*.
python/flydsl/expr/rocdl.py Renames ROCDL MMA type imports/constructors and updates public exports.
python/flydsl/expr/primitive.py Adjusts make_mma_atom to build a MmaAtomType wrapper and renames universal FMA type usage.
python/flydsl/expr/derived.py Removes the old Python-side MmaAtom wrapper class in favor of the typing.py caster.
lib/Dialect/FlyROCDL/GFX1250/MmaAtom.cpp Renames type to MmaOpGFX1250_WMMAType and wraps rebuild via MmaAtomType::get(...).
lib/Dialect/FlyROCDL/CDNA3/MmaAtom.cpp Renames type to MmaOpCDNA3_MFMAType and wraps rebuild via MmaAtomType::get(...).
lib/Dialect/Fly/Transforms/LayoutLowering.cpp Switches tiled MMA lowering to require MmaAtomType (wrapper) instead of the old interface.
lib/Dialect/Fly/IR/FlyUniversalOps.cpp Moves universal copy/FMA type implementations into a dedicated compilation unit.
lib/Dialect/Fly/IR/FlyTypeDefs.cpp Implements MmaAtomType delegation to the underlying MmaOpTypeInterface.
lib/Dialect/Fly/IR/FlyOps.cpp Updates inference error messages and checks to align with MmaAtomType wrapper.
lib/Dialect/Fly/CMakeLists.txt Adds IR/FlyUniversalOps.cpp to the Fly dialect library.
lib/Conversion/FlyToROCDL/FlyToROCDL.cpp Updates mma_atom_call lowering to dispatch via MmaAtomType.getMmaOp() (MmaOp* types).
lib/Bindings/Python/TiledOpTraits.h / .cpp Updates helper signatures to take MmaAtomType rather than the removed interface.
lib/Bindings/Python/FlyROCDLExtension.cpp Renames Python-exposed ROCDL MMA types to MmaOp*.
lib/Bindings/Python/FlyExtension.cpp Adds Python binding for MmaAtomType and binds MmaOpUniversalFMAType.
include/flydsl/Dialect/FlyROCDL/IR/MmaAtom.td Renames FlyROCDL MMA types to MmaOp* and updates mnemonics.
include/flydsl/Dialect/FlyROCDL/IR/Dialect.td Renames the FlyROCDL MMA typedef base class to FlyxROCL_MmaOp using Fly_MmaOpTypeInterface.
include/flydsl/Dialect/Fly/Utils/TiledOpUtils.h Updates tiled MMA layout utilities to accept MmaAtomType wrapper.
include/flydsl/Dialect/Fly/IR/FlyTypeDefs.td Introduces Fly_MmaAtom wrapper type and renames universal FMA type to Fly_MmaOpUniversalFMA.
include/flydsl/Dialect/Fly/IR/FlyInterfaces.td Renames Fly_MmaAtomTypeInterface to Fly_MmaOpTypeInterface.
examples/utils/print_typst.py Demonstrates Typst printing for a standalone mma_atom in addition to tiled MMA/copy.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

%b: !fly.memref<f32, register, 1:1>,
%c: !fly.memref<f32, register, 4:1>) {
%atom = fly.make_mma_atom : !fly_rocdl.atom.cdna3.mfma<16x16x4, (f32, f32) -> f32>
%atom = fly.make_mma_atom : !fly.mma_atom<!fly_rocdl.atom.cdna3.mfma<16x16x4, (f32, f32) -> f32>>
Copy link

Copilot AI Apr 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FlyROCDL_MmaOpCDNA3_MFMA's TableGen mnemonic was changed to cdna3.mfma (see include/flydsl/Dialect/FlyROCDL/IR/MmaAtom.td), but this test still uses !fly_rocdl.atom.cdna3.mfma<...>. This will fail to parse/roundtrip once the new mnemonic is in effect. Update the IR here to the new type spelling.

Copilot uses AI. Check for mistakes.
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants