From 6fa4f504d61e48f7cea454ce0c1f6169907d5aaa Mon Sep 17 00:00:00 2001 From: Tori Baker Date: Mon, 30 Sep 2024 09:47:14 +0200 Subject: [PATCH 01/18] [BACKEND] Update LLVM version to https://github.com/llvm/llvm-project/commit/29b92d07746fac26cd64c914bc9c5c3833974f6d (#4828) --- cmake/llvm-hash.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmake/llvm-hash.txt b/cmake/llvm-hash.txt index 36344442bd3a..547d6a6cd659 100644 --- a/cmake/llvm-hash.txt +++ b/cmake/llvm-hash.txt @@ -1 +1 @@ -df0864e761107b07e38f5503e0cbee0cebb4c5e8 +29b92d07746fac26cd64c914bc9c5c3833974f6d From 1ef30e9f794cbdb9322483a088832746977d9c66 Mon Sep 17 00:00:00 2001 From: Tori Baker Date: Mon, 30 Sep 2024 10:52:06 +0200 Subject: [PATCH 02/18] OpenXLA-specific changes --- BUILD | 908 ++++++++++++++++++ .../TritonGPUToLLVM/ElementwiseOpToLLVM.cpp | 3 +- lib/Dialect/TritonGPU/IR/Dialect.cpp | 5 + .../TritonGPU/Transforms/AccelerateMatmul.cpp | 24 + .../Transforms/OptimizeDotOperands.cpp | 17 +- lib/Dialect/TritonGPU/Transforms/Prefetch.cpp | 17 +- python/BUILD | 77 ++ python/test/regression/BUILD | 26 + python/test/regression/conftest.py | 12 + python/test/unit/BUILD | 180 ++++ python/test/unit/language/test_core.py | 21 + python/triton/_C/include | 2 +- python/triton/backends/__init__.py | 7 +- test/BUILD | 63 ++ test/TritonGPU/accelerate-matmul.mlir | 19 +- test/TritonGPU/canonicalize.mlir | 16 + test/TritonGPU/prefetch.mlir | 17 + third_party/amd/BUILD | 250 +++++ third_party/f2reduce/BUILD | 31 + third_party/nvidia/BUILD | 306 ++++++ third_party/nvidia/backend/BUILD | 30 + third_party/nvidia/backend/driver.c | 12 + .../lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp | 34 +- third_party/nvidia/triton_nvidia.cc | 2 +- third_party/proton/proton/_C/include | 2 +- unittest/BUILD | 144 +++ 26 files changed, 2212 insertions(+), 13 deletions(-) create mode 100644 BUILD create mode 100644 python/BUILD create mode 100644 python/test/regression/BUILD create mode 100644 python/test/regression/conftest.py create mode 100644 python/test/unit/BUILD create mode 100644 test/BUILD create mode 100644 third_party/amd/BUILD create mode 100644 third_party/f2reduce/BUILD create mode 100644 third_party/nvidia/BUILD create mode 100644 third_party/nvidia/backend/BUILD create mode 100644 unittest/BUILD diff --git a/BUILD b/BUILD new file mode 100644 index 000000000000..6381b59d31fc --- /dev/null +++ b/BUILD @@ -0,0 +1,908 @@ +# This package imports OpenAI's Triton (https://github.com/openai/triton). +# +# There are two versions of Triton in google3 at the moment. The older version +# can be found at //third_party/py/triton. This is the MLIR-based version close +# to head. We expect to transition users to this version in the following +# weeks. +# +# There is no SLA associated with this package and it may get broken by LLVM +# imports at any time. + +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") +# copybara:uncomment load("//tools/build_defs/license:license.bzl", "license") + +package( + # copybara:uncomment_begin + # default_applicable_licenses = [":license"], + # default_compatible_with = ["//buildenv/target:non_prod"], + # default_visibility = [ + # # Add your project here if you need to depend on Triton's C++ sources. + # # Add a point of contact we can reach out to when needed in the comment. + # # + # # If you need to use the Python fronted, add your project to + # # google3/third_party/py/triton/BUILD instead. 
+ # # + # # By adding your project here, you agree to the Triton SLA: go/triton-google3-sla + # "//third_party/py/jax:__subpackages__", # cjfj@ + # "//third_party/tensorflow/compiler/xla:__subpackages__", # bchetioui@ + # "//platforms/xla/experimental/gpu:__subpackages__", # csigg@ + # # Triton-internal visibility + # "//:__subpackages__", + # ], + # copybara:uncomment_end_and_comment_begin + default_visibility = ["//visibility:public"], + # copybara:comment_end + # TODO(csigg): fix and remove + features = [ + "-parse_headers", + "-use_header_modules", + ], +) + +# copybara:uncomment_begin +# license(name = "license") +# +# licenses(["notice"]) +# +# exports_files(["LICENSE"]) +# copybara:uncomment_end + +config_setting( + name = "compiler_is_msvc", + flag_values = { + # copybara:comment_begin + "@bazel_tools" + + # copybara:comment_end + "//tools/cpp:compiler": "msvc-cl", + }, +) + +# TODO(csigg): fix, enable error upstream, remove. +_no_unused_variable = select({ + ":compiler_is_msvc": [], + "//conditions:default": ["-Wno-unused-variable"], +}) + +td_library( + name = "td_files", + srcs = glob(["include/triton/**/*.td"]), + includes = ["include"], + deps = [ + "@llvm-project//mlir:ArithOpsTdFiles", + "@llvm-project//mlir:CastInterfacesTdFiles", + "@llvm-project//mlir:ControlFlowInterfacesTdFiles", + "@llvm-project//mlir:DestinationStyleOpInterfaceTdFiles", + "@llvm-project//mlir:FunctionInterfacesTdFiles", + "@llvm-project//mlir:InferTypeOpInterfaceTdFiles", + "@llvm-project//mlir:LLVMOpsTdFiles", + "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:PassBaseTdFiles", + "@llvm-project//mlir:SideEffectInterfacesTdFiles", + "@llvm-project//mlir:ViewLikeInterfaceTdFiles", + ], +) + +gentbl_cc_library( + name = "triton_attr_inc_gen", + tbl_outs = [ + ( + ["--gen-attrdef-decls"], + "include/triton/Dialect/Triton/IR/TritonAttrDefs.h.inc", + ), + ( + ["--gen-attrdef-defs"], + "include/triton/Dialect/Triton/IR/TritonAttrDefs.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/Triton/IR/TritonAttrDefs.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_dialect_inc_gen", + tbl_outs = [ + ( + ["--gen-dialect-decls"], + "include/triton/Dialect/Triton/IR/Dialect.h.inc", + ), + ( + ["--gen-dialect-defs"], + "include/triton/Dialect/Triton/IR/Dialect.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/Triton/IR/TritonDialect.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_interfaces_inc_gen", + tbl_outs = [ + ( + ["--gen-attr-interface-decls"], + "include/triton/Dialect/Triton/IR/AttrInterfaces.h.inc", + ), + ( + ["--gen-attr-interface-defs"], + "include/triton/Dialect/Triton/IR/AttrInterfaces.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/Triton/IR/TritonInterfaces.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_ops_inc_gen", + tbl_outs = [ + ( + ["--gen-enum-decls"], + "include/triton/Dialect/Triton/IR/OpsEnums.h.inc", + ), + ( + ["--gen-enum-defs"], + "include/triton/Dialect/Triton/IR/OpsEnums.cpp.inc", + ), + ( + ["--gen-op-decls"], + "include/triton/Dialect/Triton/IR/Ops.h.inc", + ), + ( + ["--gen-op-defs"], + "include/triton/Dialect/Triton/IR/Ops.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/Triton/IR/TritonOps.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_types_inc_gen", + tbl_outs = [ + ( + 
["--gen-typedef-decls"], + "include/triton/Dialect/Triton/IR/Types.h.inc", + ), + ( + ["--gen-typedef-defs"], + "include/triton/Dialect/Triton/IR/Types.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/Triton/IR/TritonTypes.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_transforms_inc_gen", + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=Triton", + ], + "include/triton/Dialect/Triton/Transforms/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/Triton/Transforms/Passes.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_combine_inc_gen", + # The generated file is #included without relative path. + strip_include_prefix = "lib/Dialect/Triton/Transforms", + tbl_outs = [ + ( + ["--gen-rewriters"], + "lib/Dialect/Triton/Transforms/TritonCombine.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "lib/Dialect/Triton/Transforms/Combine.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_gpu_attr_inc_gen", + tbl_outs = [ + ( + ["--gen-attrdef-decls"], + "include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.h.inc", + ), + ( + ["--gen-attrdef-defs"], + "include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.cpp.inc", + ), + ( + ["--gen-enum-decls"], + "include/triton/Dialect/TritonGPU/IR/OpsEnums.h.inc", + ), + ( + ["--gen-enum-defs"], + "include/triton/Dialect/TritonGPU/IR/OpsEnums.cpp.inc", + ), + ( + ["--gen-attr-interface-decls"], + "include/triton/Dialect/TritonGPU/IR/TritonGPUAttrInterfaces.h.inc", + ), + ( + ["--gen-attr-interface-defs"], + "include/triton/Dialect/TritonGPU/IR/TritonGPUAttrInterfaces.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_gpu_dialect_inc_gen", + tbl_outs = [ + ( + ["--gen-dialect-decls"], + "include/triton/Dialect/TritonGPU/IR/Dialect.h.inc", + ), + ( + ["--gen-dialect-defs"], + "include/triton/Dialect/TritonGPU/IR/Dialect.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_gpu_ops_inc_gen", + tbl_outs = [ + ( + ["--gen-op-decls"], + "include/triton/Dialect/TritonGPU/IR/Ops.h.inc", + ), + ( + ["--gen-op-defs"], + "include/triton/Dialect/TritonGPU/IR/Ops.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_gpu_types_inc_gen", + tbl_outs = [ + ( + ["--gen-typedef-decls"], + "include/triton/Dialect/TritonGPU/IR/Types.h.inc", + ), + ( + ["--gen-typedef-defs"], + "include/triton/Dialect/TritonGPU/IR/Types.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonGPU/IR/TritonGPUTypes.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_gpu_transforms_inc_gen", + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=TritonGPU", + ], + "include/triton/Dialect/TritonGPU/Transforms/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonGPU/Transforms/Passes.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_nvidia_gpu_attr_inc_gen", + tbl_outs = [ + ( + ["--gen-attrdef-decls"], + 
"include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.h.inc", + ), + ( + ["--gen-attrdef-defs"], + "include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.cpp.inc", + ), + ( + ["--gen-enum-decls"], + "include/triton/Dialect/TritonNvidiaGPU/IR/OpsEnums.h.inc", + ), + ( + ["--gen-enum-defs"], + "include/triton/Dialect/TritonNvidiaGPU/IR/OpsEnums.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_nvidia_gpu_dialect_inc_gen", + tbl_outs = [ + ( + ["--gen-dialect-decls"], + "include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h.inc", + ), + ( + ["--gen-dialect-defs"], + "include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUDialect.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_nvidia_gpu_ops_inc_gen", + tbl_outs = [ + ( + ["--gen-op-decls"], + "include/triton/Dialect/TritonNvidiaGPU/IR/Ops.h.inc", + ), + ( + ["--gen-op-defs"], + "include/triton/Dialect/TritonNvidiaGPU/IR/Ops.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_nvidia_gpu_types_inc_gen", + tbl_outs = [ + ( + ["--gen-typedef-decls"], + "include/triton/Dialect/TritonNvidiaGPU/IR/Types.h.inc", + ), + ( + ["--gen-typedef-defs"], + "include/triton/Dialect/TritonNvidiaGPU/IR/Types.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUTypes.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_nvidia_gpu_transforms_inc_gen", + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=TritonNvidiaGPU", + ], + "include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_conversion_triton_to_triton_gpu_passes_inc_gen", + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=TritonToTritonGPU", + ], + "include/triton/Conversion/TritonToTritonGPU/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Conversion/TritonToTritonGPU/Passes.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_target_llvmir_passes_inc_gen", + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=TritonLLVMIR", + ], + "include/triton/Target/LLVMIR/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Target/LLVMIR/Passes.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_conversion_triton_gpu_to_llvm_pass_inc_gen", + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=TritonGPUToLLVM", + ], + "include/triton/Conversion/TritonGPUToLLVM/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Conversion/TritonGPUToLLVM/Passes.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "triton_type_interfaces_inc_gen", + tbl_outs = [ + ( + ["--gen-type-interface-decls"], + "include/triton/Dialect/Triton/IR/TritonTypeInterfaces.h.inc", + ), + ( + ["--gen-type-interface-defs"], + 
"include/triton/Dialect/Triton/IR/TritonTypeInterfaces.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/triton/Dialect/Triton/IR/TritonTypeInterfaces.td", + deps = ["td_files"], +) + +cc_library( + name = "TritonAnalysis", + srcs = [ + "lib/Analysis/Alias.cpp", + "lib/Analysis/Allocation.cpp", + "lib/Analysis/Membar.cpp", + # Part of TritonDialects compilation unit to avoid circular dependencies. + # "lib/Analysis/Utility.cpp", + # "lib/Analysis/AxisInfo.cpp", + ], + hdrs = [ + "include/triton/Analysis/Alias.h", + "include/triton/Analysis/Allocation.h", + "include/triton/Analysis/Membar.h", + # Part of TritonDialects compilation unit to avoid circular dependencies. + # "include/triton/Analysis/AxisInfo.h", + # "include/triton/Analysis/Utility.h", + "include/triton/Conversion/MLIRTypes.h", + "include/triton/Conversion/TritonGPUToLLVM/AsmFormat.h", + "include/triton/Conversion/TritonGPUToLLVM/Utility.h", + "include/triton/Dialect/TritonGPU/Transforms/Utility.h", + ], + copts = _no_unused_variable, + deps = [ + ":TritonDialects", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:ControlFlowInterfaces", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMCommonConversion", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:Transforms", + ], +) + +cc_library( + name = "TritonDialects", + srcs = glob([ + "lib/Dialect/Triton/IR/*.cpp", + "lib/Dialect/TritonGPU/IR/*.cpp", + "lib/Dialect/TritonNvidiaGPU/IR/*.cpp", + "lib/Tools/*.cpp", + ]) + [ + "include/triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h", # Avoid circular dependency. + "lib/Analysis/AxisInfo.cpp", # Avoid circular dependency. + "lib/Analysis/Utility.cpp", # Avoid circular dependency. + "lib/Dialect/TritonGPU/Transforms/Utility.cpp", # Avoid circular dependency. + ], + hdrs = glob([ + "include/triton/Dialect/Triton/IR/*.h", + "include/triton/Dialect/TritonGPU/IR/*.h", + "include/triton/Dialect/TritonNvidiaGPU/IR/*.h", + "include/triton/Tools/*.h", + ]) + [ + "include/triton/Analysis/AxisInfo.h", # Avoid circular dependency. + "include/triton/Analysis/Utility.h", # Avoid circular dependency. + "include/triton/Dialect/TritonGPU/Transforms/Utility.h", # Avoid circular dependency. 
+ ], + copts = select({ + ":compiler_is_msvc": [], + "//conditions:default": [ + "-Wno-unused-variable", + "-Wno-logical-op-parentheses", + ], + }), + includes = ["include"], + deps = [ + ":triton_dialect_inc_gen", + ":triton_gpu_attr_inc_gen", + ":triton_gpu_dialect_inc_gen", + ":triton_gpu_ops_inc_gen", + ":triton_gpu_types_inc_gen", + ":triton_interfaces_inc_gen", + ":triton_nvidia_gpu_attr_inc_gen", + ":triton_nvidia_gpu_dialect_inc_gen", + ":triton_nvidia_gpu_ops_inc_gen", + ":triton_nvidia_gpu_types_inc_gen", + ":triton_ops_inc_gen", + ":triton_types_inc_gen", + ":triton_type_interfaces_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:ControlFlowInterfaces", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:FunctionInterfaces", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:InliningUtils", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:MathDialect", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@triton//third_party/nvidia:NVGPUDialect", + # The following is added to make Utility compile + ":TritonTools", + "@llvm-project//mlir:LLVMCommonConversion", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + "@triton//third_party/f2reduce", + ], +) + +cc_library( + name = "TritonTransforms", + srcs = glob(["lib/Dialect/Triton/Transforms/*.cpp"]), + hdrs = glob(["include/triton/Dialect/Triton/Transforms/*.h"]), + copts = _no_unused_variable, + deps = [ + ":TritonDialects", + ":triton_combine_inc_gen", + ":triton_transforms_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:SCFUtils", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + ], + alwayslink = True, # TritonDialect uses getCanonicalizationPatterns(). 
+) + +cc_library( + name = "TritonGPUTransforms", + srcs = glob( + [ + "lib/Dialect/TritonGPU/Transforms/*.cpp", + "lib/Dialect/TritonGPU/Transforms/*.h", + "lib/Dialect/TritonGPU/Transforms/Pipeliner/*.cpp", + "lib/Dialect/TritonGPU/Transforms/Pipeliner/*.h", + ], + exclude = ["lib/Dialect/TritonGPU/Transforms/Utility.cpp"], + ), + hdrs = glob( + [ + "include/triton/Dialect/TritonGPU/Transforms/*.h", + ], + exclude = ["include/triton/Dialect/TritonGPU/Transforms/Utility.h"], + ) + [ + "include/triton/Tools/Sys/GetEnv.hpp", + ], + copts = select({ + ":compiler_is_msvc": [], + "//conditions:default": [ + "-Wno-reorder-ctor", + "-Wno-return-type", + "-Wno-unused-variable", + ], + }), + deps = [ + ":TritonAnalysis", + ":TritonDialects", + ":TritonGPUToLLVM", + ":triton_gpu_transforms_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:InferTypeOpInterface", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:SCFTransforms", + "@llvm-project//mlir:SCFUtils", + "@llvm-project//mlir:SideEffectInterfaces", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + ], +) + +cc_library( + name = "TritonGPUToLLVM", + srcs = glob([ + "lib/Conversion/TritonGPUToLLVM/*.h", + "lib/Conversion/TritonGPUToLLVM/**/*.cpp", + ]), + hdrs = glob([ + "include/triton/Tools/Sys/*.hpp", + "include/triton/Conversion/TritonGPUToLLVM/*.h", + ]), + copts = select({ + "//conditions:default": [ + "-Wno-unused-variable", + ], + }), + includes = ["include"], + deps = [ + ":TritonAnalysis", + ":TritonDialects", + ":triton_conversion_triton_gpu_to_llvm_pass_inc_gen", + ":triton_gpu_attr_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:DataLayoutInterfaces", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMCommonConversion", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:NVVMDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + ], +) + +cc_library( + name = "TritonNvidiaGPUTransforms", + srcs = glob([ + "lib/Dialect/TritonNvidiaGPU/Transforms/*.cpp", + ]), + hdrs = glob([ + "include/triton/Dialect/TritonNvidiaGPU/Transforms/*.h", + ]), + copts = select({ + ":compiler_is_msvc": [], + "//conditions:default": [ + "-Wno-ctad-maybe-unsupported", + "-Wno-logical-op-parentheses", + "-Wno-non-virtual-dtor", + "-Wno-return-type", + "-Wno-unused-variable", + ], + }), + includes = ["include"], + deps = [ + ":TritonDialects", + ":TritonGPUTransforms", + ":triton_nvidia_gpu_transforms_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + ], +) + +cc_library( + name = "TritonToTritonGPU", + srcs = glob([ + "lib/Conversion/TritonToTritonGPU/*.h", + "lib/Conversion/TritonToTritonGPU/*.cpp", + ]), + hdrs = glob(["include/triton/Conversion/TritonToTritonGPU/*.h"]), + includes = ["include"], + deps = [ + ":TritonDialects", + ":TritonGPUTransforms", + ":triton_conversion_triton_to_triton_gpu_passes_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:ControlFlowDialect", + 
"@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:IndexDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + ], +) + +cc_library( + name = "TritonLLVMIR", + srcs = glob([ + "lib/Target/LLVMIR/*.cpp", + "lib/Target/LLVMIR/*.h", + ]), + hdrs = glob(["include/triton/Target/LLVMIR/*.h"]), + copts = _no_unused_variable, + deps = [ + ":TritonTransforms", + ":triton_target_llvmir_passes_inc_gen", + "@llvm-project//llvm:Analysis", + "@llvm-project//llvm:BinaryFormat", + "@llvm-project//llvm:Core", + "@llvm-project//llvm:IPO", + "@llvm-project//llvm:IRReader", + "@llvm-project//llvm:InstCombine", + "@llvm-project//llvm:Linker", + "@llvm-project//llvm:MC", + "@llvm-project//llvm:Passes", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:Target", + "@llvm-project//mlir:ArithToLLVM", + "@llvm-project//mlir:BuiltinToLLVMIRTranslation", + "@llvm-project//mlir:ConversionPasses", + "@llvm-project//mlir:ExecutionEngine", + "@llvm-project//mlir:ExecutionEngineUtils", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:IndexToLLVM", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:LLVMIRTransforms", + "@llvm-project//mlir:LLVMToLLVMIRTranslation", + "@llvm-project//mlir:NVVMToLLVMIRTranslation", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:ROCDLToLLVMIRTranslation", + "@llvm-project//mlir:SCFToControlFlow", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:ToLLVMIRTranslation", + "@llvm-project//mlir:Transforms", + # copybara:uncomment "//third_party/py/triton/google:find_cuda", + ], +) + +cc_library( + name = "TritonPTX", + srcs = glob([ + "lib/Target/PTX/*.cpp", + ]), + hdrs = glob(["include/triton/Target/PTX/*.h"]), + deps = ["@llvm-project//llvm:Support"], +) + +cc_library( + name = "TritonHSACO", + srcs = glob([ + "lib/Target/HSACO/*.cpp", + ]), + hdrs = glob(["include/triton/Target/HSACO/*.h"]), + deps = [ + ":TritonLLVMIR", + ":TritonTools", + "@llvm-project//llvm:Core", + "@llvm-project//llvm:ExecutionEngine", + "@llvm-project//llvm:MC", + "@llvm-project//llvm:Scalar", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:Target", + "@llvm-project//llvm:TransformUtils", + "@llvm-project//mlir:ExecutionEngine", + "@llvm-project//mlir:ExecutionEngineUtils", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:LLVMToLLVMIRTranslation", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:ToLLVMIRTranslation", + ], +) + +cc_library( + name = "TritonTools", + hdrs = ["include/triton/Tools/Sys/GetEnv.hpp"], +) + +cc_library( + name = "AllPassesAndDialects", + srcs = [ + "include/triton/Conversion/TritonToTritonGPU/Passes.h", + "include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h", + ], + hdrs = ["bin/RegisterTritonDialects.h"], + includes = ["."], # because it includes third_party/nvidia/include/Dialect/NVGPU/IR/Dialect.h + deps = [ + ":TritonDialects", + ":TritonGPUToLLVM", + ":TritonGPUTransforms", + ":TritonLLVMIR", + ":TritonNvidiaGPUTransforms", + ":TritonToTritonGPU", + ":TritonTransforms", + ":triton_conversion_triton_to_triton_gpu_passes_inc_gen", + ":triton_nvidia_gpu_transforms_inc_gen", + "@llvm-project//mlir:AllPassesAndDialects", + "@triton//test:TritonTestAnalysis", + "@triton//third_party/amd:TritonAMDGPU", + "@triton//third_party/amd:TritonAMDGPUToLLVM", + "@triton//third_party/amd:TritonAMDGPUTransforms", + "@triton//third_party/nvidia:NVGPUDialect", + 
"@triton//third_party/nvidia:NVGPUToLLVM", + "@triton//third_party/nvidia:TritonNVIDIAGPUToLLVM", + ], +) + +cc_binary( + name = "triton-opt", + srcs = [ + "bin/triton-opt.cpp", + ], + deps = [ + ":AllPassesAndDialects", + "@llvm-project//mlir:MlirOptLib", + ], +) + +cc_binary( + name = "triton-llvm-opt", + srcs = [ + "bin/triton-llvm-opt.cpp", + "lib/Target/LLVMIR/LLVMPasses.h", + ], + deps = [ + ":TritonLLVMIR", + "@llvm-project//llvm:CodeGen", + "@llvm-project//llvm:Core", + "@llvm-project//llvm:IRReader", + "@llvm-project//llvm:Option", + "@llvm-project//llvm:Passes", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:TargetParser", + ], +) + +# See go/triton-debug for usage. +cc_binary( + name = "triton-reduce", + srcs = ["bin/triton-reduce.cpp"], + deps = [ + ":AllPassesAndDialects", + "@llvm-project//mlir:MlirReduceLib", + "@triton//third_party/amd:TritonAMDGPU", + "@triton//third_party/amd:TritonAMDGPUDialectToLLVM", + ], +) + +cc_binary( + name = "triton-tensor-layout", + srcs = ["bin/triton-tensor-layout.cpp"], + deps = [ + ":AllPassesAndDialects", + ":TritonDialects", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:AsmParser", + "@llvm-project//mlir:IR", + ], +) + +filegroup( + name = "metadata-file", + srcs = ["METADATA"], +) diff --git a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp index 787dee35fb25..2d21a1bfeca2 100644 --- a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -40,7 +40,8 @@ SmallVector reorderValues(const SmallVector &values, Type inType, auto ouEltTy = ouTensorTy.getElementType(); if (inBitWidth == ouBitWidth) return values; - if (inBitWidth == 16 && ouBitWidth == 32) { + if ((inBitWidth == 16 && ouBitWidth == 32) || + (inBitWidth == 32 && ouBitWidth == 16)) { SmallVector ret; for (unsigned i = 0; i < values.size(); i += 8) { ret.push_back(values[i]); diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 48f31bdf2a9d..6bc4ca6f9eae 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -2720,6 +2720,11 @@ struct CanonicalizeConvertFromAlloc auto convert = op.getSrc().getDefiningOp(); if (!convert) return failure(); + // LocalAllocOp lowering doesn't support going from DotOperandEncoding + // to SharedEncoding, so we want to keep this layout conversion. + if (mlir::isa( + convert.getSrc().getType().getEncoding())) + return failure(); rewriter.replaceOpWithNewOp( op, op->getResult(0).getType(), convert.getSrc()); return mlir::success(); diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp index d9bbd51bd9a1..7776a93305ff 100644 --- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp +++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp @@ -153,6 +153,21 @@ static Value getSharedMemoryMMAOperand(Value v, mlir::PatternRewriter &rewriter, auto newType = MemDescType::get(argType.getShape(), argType.getElementType(), newLayout, SharedMemorySpace); rewriter.setInsertionPointAfterValue(arg); + + // LocalAllocOp lowering doesn't support going from DotOperandEncoding + // to SharedEncoding. + if (auto dotOpEnc = mlir::dyn_cast( + argType.getEncoding())) { + // Create a layout conversion from DotOperandEncoding to BlockedEncoding + // then pass it to the LocalAllocOp. 
+ auto newArgType = RankedTensorType::get( + argType.getShape(), argType.getElementType(), dotOpEnc.getParent()); + auto dotOperandToBlockedCvt = + rewriter.create(arg.getLoc(), newArgType, arg); + return rewriter.create(arg.getLoc(), newType, + dotOperandToBlockedCvt); + } + return rewriter.create(arg.getLoc(), newType, arg); } @@ -162,6 +177,15 @@ class BlockedToMMA : public mlir::OpRewritePattern { mutable llvm::DenseMap dotOpInstNs; static bool bwdFilter(Operation *op) { + // Dot operand layout assignment to Predicates are not currently supported + // during lowering from TritonGPU to LLVM in Triton for MMA cases. This + // condition limits visibility of the original bit-width so that predicate + // are not considered, hence, kwidth can never be = 32. + if (isa(op)) { + Type srcType = getElementTypeOrSelf(op->getOperand(0)); + if (srcType.isInteger(1)) + return false; + } return op->getNumOperands() == 1 && (isa(op) || isPureUnaryInlineAsm(op) || diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index 6d8279795209..e6e0ec8d7cef 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -111,7 +111,8 @@ class HoistLayoutConversion : public OpRewritePattern { PatternRewriter &rewriter) const override { // Only consider conversions to dot operand. auto cvtTy = cast(cvt.getType()); - if (!isa(cvtTy.getEncoding())) + auto dotOpEnc = dyn_cast(cvtTy.getEncoding()); + if (!dotOpEnc) return failure(); auto src = cvt.getSrc().getDefiningOp(); @@ -126,6 +127,12 @@ class HoistLayoutConversion : public OpRewritePattern { [](Type ty) { return isa(ty); })) return failure(); + // Quick handling to fix loading issues when computing the original + // bitwidth is unable to realize that there is a mixed-precision dot + // (hence kWidth = 1) but wants to hoist through the type conversion. + if (isa(src) && dotOpEnc.getKWidth() == 1) + return failure(); + // Only consider custom conversions or arith ops. // TODO(jlebar): Is this too restrictive? if (!isa(src) && !isPureUnaryInlineAsm(src) && @@ -138,6 +145,14 @@ class HoistLayoutConversion : public OpRewritePattern { if (isa(src)) return failure(); + // Don't hoist through u1 -> fp casts as they aren't supported in + // ElementwiseOpToLLVM::reorderValues(). + if (isa(src)) { + Type srcType = getElementTypeOrSelf(src->getOperand(0)); + if (srcType.isInteger(1)) + return failure(); + } + // Check that the conversion is transitively dependent on a load, and all // operations between the load and the conversion are layout preserving. // diff --git a/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp b/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp index 2cbc00142b42..db71b3b82061 100644 --- a/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp @@ -140,8 +140,14 @@ Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue, type.getMemorySpace()), v, offsetsVal); + // We need to assign kwidth to zero in the case where the parent layout is + // Blocked, otherwise the verifier emits a failure. The parent layout is + // Blocked only when Tensor Cores are disabled. + int kwidth = dyn_cast(dotEncoding) + ? 
0 + : prefetchWidth / 8; auto dotOperandEnc = triton::gpu::DotOperandEncodingAttr::get( - builder.getContext(), opIdx, dotEncoding, prefetchWidth / 8); + builder.getContext(), opIdx, dotEncoding, kwidth); Value prefetchSlice = builder.create( v.getLoc(), RankedTensorType::get(shape, elementType, dotOperandEnc), newSmem); @@ -190,6 +196,15 @@ LogicalResult Prefetcher::initialize() { break; if (!op->getResult(0).hasOneUse()) break; + // Similar to issues faced in HoistLayoutConversion pattern in + // OptimizeDotOperands.cpp, we can't propagate through type casts from + // predicates as they aren't supported in Triton when encoded with dot_op + // layout. + if (isa(op)) { + Type srcType = getElementTypeOrSelf(op->getOperand(0)); + if (srcType.isInteger(1)) + break; + } rets.push_back(op->getOperand(0)); if (auto cvt = dyn_cast(op)) { foundConvertFromShared = true; diff --git a/python/BUILD b/python/BUILD new file mode 100644 index 000000000000..334dd4aec41a --- /dev/null +++ b/python/BUILD @@ -0,0 +1,77 @@ +# NOTE: Do not depend on any targets from this directory, +# but use //third_party/py/triton instead. + +load("@pybind11_bazel//:build_defs.bzl", "pybind_extension") + +package( + default_applicable_licenses = ["//:license"], + default_visibility = [ + "//third_party/py/triton:__pkg__", + "@triton//python:__subpackages__", + ], +) + +cc_library( + name = "passes", + hdrs = ["src/passes.h"], + includes = ["src"], + visibility = ["@triton//third_party:__subpackages__"], +) + +pybind_extension( + name = "libtriton", + srcs = [ + "src/interpreter.cc", + "src/ir.cc", + "src/llvm.cc", + "src/main.cc", + "src/passes.cc", + ], + copts = ["-DTRITON_BACKENDS_TUPLE=(nvidia)"], + deps = [ + ":passes", + "@llvm-project//llvm:Core", + "@llvm-project//llvm:IPO", + "@llvm-project//llvm:IRReader", + "@llvm-project//llvm:InstCombine", + "@llvm-project//llvm:Linker", + "@llvm-project//llvm:MC", + "@llvm-project//llvm:Passes", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:Target", + "@llvm-project//mlir:BuiltinToLLVMIRTranslation", + "@llvm-project//mlir:BytecodeWriter", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:ConversionPasses", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:IndexDialect", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:LLVMIRTransforms", + "@llvm-project//mlir:LLVMToLLVMIRTranslation", + "@llvm-project//mlir:NVVMToLLVMIRTranslation", + "@llvm-project//mlir:Parser", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:ToLLVMIRTranslation", + "@llvm-project//mlir:Transforms", + "//:TritonAnalysis", + "//:TritonDialects", + "//:TritonGPUToLLVM", + "//:TritonGPUTransforms", + "//:TritonHSACO", + "//:TritonLLVMIR", + "//:TritonNvidiaGPUTransforms", + "//:TritonPTX", + "//:TritonToTritonGPU", + "//:TritonTools", + "//:TritonTransforms", + "@triton//third_party/nvidia:triton_nvidia", + ], +) + +filegroup( + name = "files", + srcs = glob( + include = ["triton/**/*.py"], + ), +) diff --git a/python/test/regression/BUILD b/python/test/regression/BUILD new file mode 100644 index 000000000000..a88f4eeae1f8 --- /dev/null +++ b/python/test/regression/BUILD @@ -0,0 +1,26 @@ +load("//third_party/py/pytest:pytest_defs.bzl", "pytest_multi_tests") + +package( + default_applicable_licenses = ["//:license"], +) + +pytest_multi_tests( + name = "tests", + size = "large", + srcs = ["conftest.py"], + shard_count = 10, + tags = [ + "config-cuda-only", + "requires-gpu-sm80", + ], + tests = glob( + include = ["test_*.py"], + exclude = 
[ + "test_performance.py", #TODO(b/321005767): fix failing test + ], + ), + deps = [ + "//third_party/py/torch:pytorch", + "//third_party/py/triton", + ], +) diff --git a/python/test/regression/conftest.py b/python/test/regression/conftest.py new file mode 100644 index 000000000000..7a02d322b49f --- /dev/null +++ b/python/test/regression/conftest.py @@ -0,0 +1,12 @@ +# content of conftest.py + +import pytest + + +def pytest_addoption(parser): + parser.addoption("--device", action="store", default='cuda') + + +@pytest.fixture +def device(request): + return request.config.getoption("--device") diff --git a/python/test/unit/BUILD b/python/test/unit/BUILD new file mode 100644 index 000000000000..f75527bab1f7 --- /dev/null +++ b/python/test/unit/BUILD @@ -0,0 +1,180 @@ +load("//third_party/py/pytest:pytest_defs.bzl", "pytest_multi_tests", "pytest_test") + +package( + default_applicable_licenses = ["//:license"], +) + +_requires_gpu_sm80 = [ + "config-cuda-only", + "requires-gpu-sm80", +] + +_requires_config_cuda = select( + {"@local_config_cuda//cuda:using_clang_allow_exec": []}, + no_match_error = "Requires --config=cuda", +) + +EXCLUDE_TESTS = [ + "language/test_reproducer.py", # this is not an actual test, but a tool for running reproducers + "language/test_subprocess.py", # TODO(b/320224484): fix failing test + "runtime/test_launch.py", # TODO(b/320226169): fix failing tests + "tools/test_aot.py", # TODO(b/320224484): fix failing test + "tools/test_disasm.py", # TODO(b/320224484): fix failing test + "hopper/test_persistent_warp_specialized_gemm.py", # TODO (b/342348738): fix failing test + "runtime/test_cublas.py", # TODO(b/346755023): fix failing test +] + +# Runs all python tests on H100 +pytest_multi_tests( + name = "hopper", + size = "large", + srcs = [ + "conftest.py", + "language/conftest.py", + "language/test_core.py", + ], + name_suffix = "_h100", + shard_count = 10, + tags = [ + "config-cuda-only", + "requires-gpu-sm90", + ], + target_compatible_with = _requires_config_cuda, + tests = glob( + include = ["**/test_*.py"], + exclude = EXCLUDE_TESTS + [ + "language/test_core.py", + "language/test_pipeliner.py", # TODO(b/362458006): fix failing test + "hopper/test_experimental_tma.py", # TODO(b/362458006): fix failing test + ], + ), + deps = [ + "//third_party/py/torch:pytorch", + "//third_party/py/triton", + ], +) + +# Shard test_core more, as it is otherwise very slow to run. +pytest_test( + name = "hopper/language/test_core_h100", + size = "large", + srcs = [ + "conftest.py", + "language/conftest.py", + ], + shard_count = 40, + tags = [ + "config-cuda-only", + "requires-gpu-sm90", + ], + target_compatible_with = _requires_config_cuda, + tests = ["language/test_core.py"], + deps = [ + "//third_party/py/torch:pytorch", + "//third_party/py/triton", + ], +) + +pytest_multi_tests( + name = "language", + size = "large", + srcs = [ + "conftest.py", + "language/conftest.py", + "language/test_core.py", + ], + shard_count = 10, + tags = _requires_gpu_sm80, + target_compatible_with = _requires_config_cuda, + tests = glob( + include = ["language/**/test_*.py"], + exclude = EXCLUDE_TESTS + ["language/test_core.py"], + ), + deps = [ + "//third_party/py/torch:pytorch", + "//third_party/py/triton", + ], +) + +# Shard test_core more, as it is otherwise very slow to run. 
+pytest_test( + name = "language/test_core", + size = "large", + srcs = [ + "conftest.py", + "language/conftest.py", + ], + shard_count = 40, + tags = _requires_gpu_sm80, + target_compatible_with = _requires_config_cuda, + tests = ["language/test_core.py"], + deps = [ + "//third_party/py/torch:pytorch", + "//third_party/py/triton", + ], +) + +pytest_multi_tests( + name = "instrumentation", + size = "large", + srcs = ["conftest.py"], + shard_count = 10, + tags = _requires_gpu_sm80, + target_compatible_with = _requires_config_cuda, + tests = glob( + include = ["instrumentation/**/test_*.py"], + exclude = EXCLUDE_TESTS, + ), + deps = [ + "//third_party/py/torch:pytorch", + "//third_party/py/triton", + ], +) + +pytest_multi_tests( + name = "runtime", + srcs = ["conftest.py"], + tags = _requires_gpu_sm80, + target_compatible_with = _requires_config_cuda, + tests = glob( + include = ["runtime/**/test_*.py"], + exclude = EXCLUDE_TESTS, + ), + deps = [ + "//third_party/py/torch:pytorch", + "//third_party/py/triton", + ], +) + +pytest_multi_tests( + name = "tools", + size = "large", + shard_count = 10, + tags = _requires_gpu_sm80, + target_compatible_with = _requires_config_cuda, + tests = glob( + include = ["tools/**/test_*.py"], + exclude = EXCLUDE_TESTS, + ), + deps = [ + "//third_party/py/torch:pytorch", + "//third_party/py/triton", + ], +) + +pytest_multi_tests( + name = "unit", + size = "large", + srcs = ["conftest.py"], + shard_count = 10, + tags = _requires_gpu_sm80, + target_compatible_with = _requires_config_cuda, + tests = glob( + include = ["test_*.py"], + exclude = EXCLUDE_TESTS, + ), + deps = [ + "//third_party/py/torch:pytorch", + "//third_party/py/triton", + ], +) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 039f7ac1ac4f..3d1cbc5a82f0 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -2139,6 +2139,8 @@ def kernel(X, Z, BLOCK: tl.constexpr): reduce_bool = [(op, 'bool', shape, axis, False) for op in ['xor_sum'] for shape in reduce2d_shapes for axis in [0, 1]] +@pytest.mark.skipif(torch.cuda.get_device_capability()[0] >= 9, + reason='Reduction test produces wrong results on H100, b/342347027') @pytest.mark.interpreter @pytest.mark.parametrize( "op, dtype_str, shape, axis, keep_dims", reduce_configs1 + reduce_configs2 + reduce_configs3 + invalid_config + @@ -3642,6 +3644,25 @@ def _kernel(out): kernel[(1, )](out) assert torch.all(out == out_ref) +@pytest.mark.interpreter +def test_dot_on_broadcast(device): + @triton.jit + def _kernel(a, b, out): + a_offsets = tl.arange(0, 64)[:, None] * 32 + tl.arange(0, 32)[None, :] + lhs = tl.load(a + a_offsets, mask=a_offsets < 32 * 64) + rhs = tl.load(b) + rhs_bc = tl.broadcast_to(rhs, [32, 32]) + c = tl.dot(lhs, rhs_bc) + out_ptr = out + tl.arange(0, 64)[:, None] * 32 + tl.arange(0, 32)[None, :] + tl.store(out_ptr, c) + + a = torch.ones((64, 32), dtype=getattr(torch, 'float32'), device=device) + b = torch.tensor([1.0], dtype=getattr(torch, 'float32'), device=device) + out_ref = torch.matmul(a, torch.broadcast_to(b, (32, 32))) + out = torch.zeros((64, 32), dtype=getattr(torch, 'float32'), device=device) + _kernel[(1, )](a, b, out, num_stages=1, num_warps=4) + assert torch.all(out == out_ref) + # --------------- # test arange diff --git a/python/triton/_C/include b/python/triton/_C/include index b85a409837d1..8a5dba6c4b56 120000 --- a/python/triton/_C/include +++ b/python/triton/_C/include @@ -1 +1 @@ -../../../include/ \ No newline at end of file 
+../../../include \ No newline at end of file diff --git a/python/triton/backends/__init__.py b/python/triton/backends/__init__.py index 92ba144ba97b..f9bab523bf6c 100644 --- a/python/triton/backends/__init__.py +++ b/python/triton/backends/__init__.py @@ -46,5 +46,8 @@ def _discover_backends(): _find_concrete_subclasses(driver, DriverBase)) return backends - -backends = _discover_backends() +from triton.backends.nvidia.driver import CudaDriver +from triton.backends.nvidia.compiler import CUDABackend +backends = { + "nvidia": Backend(CUDABackend, CudaDriver) +} diff --git a/test/BUILD b/test/BUILD new file mode 100644 index 000000000000..0379d89208e9 --- /dev/null +++ b/test/BUILD @@ -0,0 +1,63 @@ +# copybara:uncomment_begin +# load("//third_party/llvm/build_defs:lit.bzl", "glob_lit_tests") +# load("//tools/build_defs/build_test:build_test.bzl", "build_test") +# +# package( +# default_applicable_licenses = ["//:license"], +# default_compatible_with = ["//buildenv/target:non_prod"], +# default_visibility = ["//:__subpackages__"], +# ) +# +# glob_lit_tests( +# name = "all_tests", +# data = [ +# "@llvm-project//llvm:FileCheck", +# "//:triton-llvm-opt", +# "//:triton-opt", +# "//:triton-tensor-layout", +# ], +# driver = "@llvm-project//mlir:run_lit.sh", +# exclude = [ +# "Conversion/amd/dedup-by-constancy.mlir", # AMD-specific, broken +# "TritonGPU/dot-operands.mlir", # TODO: b/283035396 - broken by cl536931041.patch +# "TritonGPU/optimize_epilogue.mlir", # TODO: b/346283526 - AMD-specific, triggering UBSAN +# ], +# test_file_exts = [ +# "mlir", +# "ll", +# ], +# ) +# +# build_test( +# name = "build_test", +# allow_empty_target = False, +# targets = [ +# "//:TritonAnalysis", +# "//:TritonDialects", +# "//:TritonGPUToLLVM", +# "//:TritonGPUTransforms", +# "//:TritonLLVMIR", +# "//:TritonPTX", +# "//:TritonToTritonGPU", +# "//:TritonTools", +# "//:TritonTransforms", +# "//:triton-opt", +# ], +# ) +# copybara:uncomment_end + +cc_library( + name = "TritonTestAnalysis", + srcs = glob(["lib/Analysis/*.cpp"]), + deps = [ + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFToControlFlow", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + "//:TritonAnalysis", + "//:TritonDialects", + "@triton//third_party/nvidia:TritonNVIDIAGPUToLLVM", + ], +) diff --git a/test/TritonGPU/accelerate-matmul.mlir b/test/TritonGPU/accelerate-matmul.mlir index 728fd8eadfd9..62a2d469996a 100644 --- a/test/TritonGPU/accelerate-matmul.mlir +++ b/test/TritonGPU/accelerate-matmul.mlir @@ -143,7 +143,6 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : } } - // ----- // Verify that we use mmav2 when the k dim is too small for mmav3. 
@@ -159,3 +158,21 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : tt.return %result : tensor<128x128xf32, #blocked> } } + +// ----- + +// CHECK-DAG: #[[$BLOCKED:.*]] = #triton_gpu.blocked +// CHECK-DAG: #mma = #triton_gpu.nvidia_mma<{versionMajor = 3 +#blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func @local_alloc_dot_operand(%in0: tensor<64x256xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> {tt.divisibility = 16 : i32}, %in1: f32, %in2: tensor<64x32xf32, #blocked>) -> (tensor<64x32xf32, #blocked>) { + // CHECK-LABEL: local_alloc_dot_operand + %splat_in1 = tt.splat %in1 : f32 -> tensor<256x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> + // CHECK: %[[LHS_LOCAL_ALLOC:.*]] = triton_gpu.local_alloc + // CHECK: %[[RHS_CVT:.*]] = triton_gpu.convert_layout {{.*}} #triton_gpu.dot_op<{{.*}}> -> {{.*}} #[[$BLOCKED]] + // CHECK: %[[RHS_LOCAL_ALLOC:.*]] = triton_gpu.local_alloc %[[RHS_CVT]] + // CHECK: triton_nvidia_gpu.warp_group_dot %[[LHS_LOCAL_ALLOC]], %[[RHS_LOCAL_ALLOC]] + %res = tt.dot %in0, %splat_in1, %in2, inputPrecision = tf32 : tensor<64x256xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<256x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<64x32xf32, #blocked> + tt.return %res : tensor<64x32xf32, #blocked> + } +} diff --git a/test/TritonGPU/canonicalize.mlir b/test/TritonGPU/canonicalize.mlir index ecee359cb19a..f015f9651065 100644 --- a/test/TritonGPU/canonicalize.mlir +++ b/test/TritonGPU/canonicalize.mlir @@ -133,3 +133,19 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : tt.return %2 : !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory> } } // end module + +// ----- + +// CHECK: #[[$BLOCKED:.*]] = #triton_gpu.blocked +#blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#shared1 = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func @cvt_from_dot_op_into_local_allow_not_canonicalized(%in: tensor<256x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) -> !tt.memdesc<256x32xf32, #shared1> { + // CHECK-LABEL: cvt_from_dot_op_into_local_allow_not_canonicalized + %cvt_in = triton_gpu.convert_layout %in : tensor<256x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<256x32xf32, #blocked> + %alloc = triton_gpu.local_alloc %cvt_in : (tensor<256x32xf32, #blocked>) -> !tt.memdesc<256x32xf32, #shared1> + // CHECK: %[[ALLOC:.*]] = triton_gpu.local_alloc {{.*}} (tensor<{{.*}}, #[[$BLOCKED]]{{.*}}>) -> + tt.return %alloc : !tt.memdesc<256x32xf32, #shared1> + } +} // end module + diff --git a/test/TritonGPU/prefetch.mlir b/test/TritonGPU/prefetch.mlir index 9fbc540b92a6..f178eb24050a 100644 --- a/test/TritonGPU/prefetch.mlir +++ b/test/TritonGPU/prefetch.mlir @@ -245,3 +245,20 @@ tt.func @matmul_loop_mixed_amd(%lb : index, %ub : index, %step : index, %A : !tt } // end module // ----- + +// CHECK: tt.func @matmul_loop_on_blocked_layout +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], 
warpsPerCTA = [8, 1], order = [1, 0]}> +#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func @matmul_loop_on_blocked_layout(%arg_lhs: !tt.memdesc<16x512xf32, #shared, mutable>, %arg_rhs: !tt.memdesc<512x32xf32, #shared, mutable>, %arg_init: tensor<16x32xf32, #blocked>, %itr_val : i32) -> (tensor<16x32xf32, #blocked>) { + %loop:3 = scf.for %itr = %itr_val to %itr_val step %itr_val iter_args(%init = %arg_init, %lhs = %arg_lhs, %rhs = %arg_rhs) -> (tensor<16x32xf32, #blocked>, !tt.memdesc<16x512xf32, #shared, mutable>, !tt.memdesc<512x32xf32, #shared, mutable>) : i32 { + %lhs_ll = triton_gpu.local_load %lhs : !tt.memdesc<16x512xf32, #shared, mutable> -> tensor<16x512xf32, #blocked> + %lhs_ll_cvt = triton_gpu.convert_layout %lhs_ll : tensor<16x512xf32, #blocked> -> tensor<16x512xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> + %rhs_ll = triton_gpu.local_load %rhs : !tt.memdesc<512x32xf32, #shared, mutable> -> tensor<512x32xf32, #blocked> + %rhs_ll_cvt = triton_gpu.convert_layout %rhs_ll : tensor<512x32xf32, #blocked> -> tensor<512x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> + %res = tt.dot %lhs_ll_cvt, %rhs_ll_cvt, %init : tensor<16x512xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<512x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<16x32xf32, #blocked> + scf.yield %res, %lhs, %rhs : tensor<16x32xf32, #blocked>, !tt.memdesc<16x512xf32, #shared, mutable>, !tt.memdesc<512x32xf32, #shared, mutable> + } + tt.return %loop#0 : tensor<16x32xf32, #blocked> + } +} // end module diff --git a/third_party/amd/BUILD b/third_party/amd/BUILD new file mode 100644 index 000000000000..bbdf7408f85e --- /dev/null +++ b/third_party/amd/BUILD @@ -0,0 +1,250 @@ +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") + +package( + # copybara:uncomment_begin + # default_applicable_licenses = ["//:license"], + # default_compatible_with = ["//buildenv/target:non_prod"], + # default_visibility = [ + # "//third_party/tensorflow/compiler/xla/service/gpu/fusions/triton:__subpackages__", + # "//:__subpackages__", + # ], + # copybara:uncomment_end_and_comment_begin + default_visibility = ["//visibility:public"], + # copybara:comment_end +) + +# TODO(csigg): fix, enable error upstream, remove. 
+_no_unused_variable = select({ + "//:compiler_is_msvc": [], + "//conditions:default": ["-Wno-unused-variable"], +}) + +cc_library( + name = "TritonAMDGPUTransforms", + srcs = glob([ + "lib/TritonAMDGPUTransforms/**/*.h", + "lib/TritonAMDGPUTransforms/**/*.cpp", + ]) + ["include/TritonAMDGPUToLLVM/TargetUtils.h"], + hdrs = glob([ + "include/TritonAMDGPUTransforms/**/*.h", + ]), + copts = _no_unused_variable, + includes = [ + "include", + "lib/TritonAMDGPUTransforms", + ], + deps = [ + ":triton_conversion_amdgpu_transforms_passes_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:TargetParser", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:ConvertToLLVM", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:InferTypeOpInterface", + "@llvm-project//mlir:LLVMCommonConversion", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:ROCDLDialect", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:SideEffectInterfaces", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + "//:TritonAnalysis", + "//:TritonDialects", + "//:TritonGPUToLLVM", + "//:TritonGPUTransforms", + ], +) + +cc_library( + name = "TritonAMDGPU", + srcs = glob([ + "lib/Dialect/TritonAMDGPU/**/*.h", + "lib/Dialect/TritonAMDGPU/**/*.cpp", + ]), + hdrs = glob([ + "include/Dialect/TritonAMDGPU/**/*.h", + ]), + includes = [ + "..", + "include", + ], + deps = [ + ":triton_amdgpu_attr_def_inc_gen", + ":triton_amdgpu_dialect_inc_gen", + ":triton_amdgpu_ops_inc_gen", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:TensorDialect", + ], +) + +cc_library( + name = "TritonAMDGPUToLLVM", + srcs = glob([ + "lib/TritonAMDGPUToLLVM/**/*.h", + "lib/TritonAMDGPUToLLVM/**/*.cpp", + ]), + hdrs = glob([ + "include/TritonAMDGPUToLLVM/**/*.h", + ]), + copts = _no_unused_variable, + includes = [ + "include", + "lib/TritonAMDGPUToLLVM", + ], + deps = [ + ":TritonAMDGPU", + ":TritonAMDGPUDialectToLLVM", + ":TritonAMDGPUTransforms", + ":triton_conversion_amdgpu_to_llvm_passes_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:TargetParser", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:ArithToLLVM", + "@llvm-project//mlir:ControlFlowToLLVM", + "@llvm-project//mlir:ConvertToLLVM", + "@llvm-project//mlir:GPUToNVVMTransforms", + "@llvm-project//mlir:GPUToROCDLTransforms", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:IndexDialect", + "@llvm-project//mlir:LLVMCommonConversion", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:MathToLLVM", + "@llvm-project//mlir:NVVMDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:ROCDLDialect", + "@llvm-project//mlir:SCFToControlFlow", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + "//:TritonAnalysis", + "//:TritonDialects", + "//:TritonGPUToLLVM", + ], +) + +cc_library( + name = "TritonAMDGPUDialectToLLVM", + srcs = glob([ + "lib/TritonAMDGPUDialectToLLVM/**/*.h", + "lib/TritonAMDGPUDialectToLLVM/**/*.cpp", + ]), + includes = [ + "include", + ], + deps = [ + "//:TritonGPUToLLVM", + ], +) + +td_library( + name = "td_files", + srcs = glob(["include/**/*.td"]), + includes = ["include"], + deps = ["//:td_files"], +) + +gentbl_cc_library( + name = "triton_amdgpu_ops_inc_gen", + tbl_outs = [ + ( 
+ [ + "--gen-llvmir-conversions", + ], + "include/Dialect/TritonAMDGPU/IR/OpsConversions.inc", + ), + ( + [ + "--gen-op-decls", + ], + "include/Dialect/TritonAMDGPU/IR/Ops.h.inc", + ), + ( + [ + "--gen-op-defs", + ], + "include/Dialect/TritonAMDGPU/IR/Ops.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td", + deps = [":td_files"], +) + +gentbl_cc_library( + name = "triton_amdgpu_dialect_inc_gen", + tbl_outs = [ + ( + [ + "--gen-dialect-decls", + "--dialect=amdgpu", + ], + "include/Dialect/TritonAMDGPU/IR/Dialect.h.inc", + ), + ( + [ + "--gen-dialect-defs", + "--dialect=amdgpu", + ], + "include/Dialect/TritonAMDGPU/IR/Dialect.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/Dialect/TritonAMDGPU/IR/TritonAMDGPUDialect.td", + deps = [":td_files"], +) + +gentbl_cc_library( + name = "triton_amdgpu_attr_def_inc_gen", + tbl_outs = [ + ( + ["--gen-attrdef-decls"], + "include/Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.h.inc", + ), + ( + ["--gen-attrdef-defs"], + "include/Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.td", + deps = [":td_files"], +) + +gentbl_cc_library( + name = "triton_conversion_amdgpu_to_llvm_passes_inc_gen", + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=TritonAMDGPUToLLVM", + ], + "include/TritonAMDGPUToLLVM/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/TritonAMDGPUToLLVM/Passes.td", + deps = [":td_files"], +) + +gentbl_cc_library( + name = "triton_conversion_amdgpu_transforms_passes_inc_gen", + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=TritonAMDGPU", + ], + "include/TritonAMDGPUTransforms/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/TritonAMDGPUTransforms/Passes.td", + deps = [":td_files"], +) diff --git a/third_party/f2reduce/BUILD b/third_party/f2reduce/BUILD new file mode 100644 index 000000000000..93829539e1b9 --- /dev/null +++ b/third_party/f2reduce/BUILD @@ -0,0 +1,31 @@ +# copybara:uncomment load("//tools/build_defs/license:license.bzl", "license") + +package( + # copybara:uncomment_begin + # default_applicable_licenses = ["//:license"], + # default_compatible_with = ["//buildenv/target:non_prod"], + # default_visibility = [ + # "//:__subpackages__", + # ], + # copybara:uncomment_end_and_comment_begin + default_visibility = ["//visibility:public"], + # copybara:comment_end +) + +# copybara:uncomment_begin +# license( +# name = "license", +# license_text = "LICENCE.txt", +# ) +# +# licenses(["notice"]) +# +# exports_files(["LICENCE.txt"]) +# copybara:uncomment_end + +cc_library( + name = "f2reduce", + srcs = ["f2reduce.cpp"], + hdrs = ["f2reduce.h"], + # copybara:uncomment strip_include_prefix = "/third_party/triton", +) diff --git a/third_party/nvidia/BUILD b/third_party/nvidia/BUILD new file mode 100644 index 000000000000..f062b61a9ee6 --- /dev/null +++ b/third_party/nvidia/BUILD @@ -0,0 +1,306 @@ +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") +load("@pybind11_bazel//:build_defs.bzl", "pybind_library") + +package( + # copybara:uncomment_begin + # default_applicable_licenses = ["//:license"], + # default_compatible_with = ["//buildenv/target:non_prod"], + # default_visibility = [ + # "//third_party/tensorflow/compiler/xla/service/gpu:__subpackages__", + # "//:__subpackages__", + # ], + # 
copybara:uncomment_end_and_comment_begin + default_visibility = ["//visibility:public"], + # copybara:comment_end +) + +pybind_library( + name = "cublas_headers", + hdrs = glob([ + "include/*.h", + ]), + deps = ["@local_config_cuda//cuda:cuda_headers"], +) + +pybind_library( + name = "triton_nvidia", + srcs = [ + "triton_nvidia.cc", + ], + compatible_with = [], + # copybara:uncomment_begin + # visibility = [ + # "@triton//python:__subpackages__", + # ], + # copybara:uncomment_end + deps = [ + ":NVGPUDialect", + ":NVGPUToLLVM", + ":TritonNVIDIAGPUToLLVM", + ":cublas_headers", + "@llvm-project//llvm:Core", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:NVVMToLLVMIRTranslation", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Transforms", + "//:TritonDialects", + "//:TritonGPUToLLVM", + "//:TritonNvidiaGPUTransforms", + "@triton//python:passes", + ], +) + +cc_library( + name = "NVGPUToLLVM", + srcs = glob([ + "lib/NVGPUToLLVM/*.cpp", + ]), + hdrs = glob([ + "include/NVGPUToLLVM/*.h", + ]), + # copybara:uncomment_begin + # compatible_with = ["//buildenv/target:non_prod"], + # copybara:uncomment_end + copts = select({ + "//conditions:default": [ + "-Wno-unused-variable", + ], + }), + includes = [ + "..", + "include", + ], + deps = [ + ":NVGPUDialect", + ":TritonNVIDIAGPUToLLVM", + ":triton_conversion_nvgpu_to_llvm_passes_inc_gen", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:NVVMDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + "//:TritonDialects", + ], +) + +cc_library( + name = "TritonNVIDIAGPUToLLVM", + srcs = glob([ + "lib/TritonNVIDIAGPUToLLVM/*.h", + "lib/TritonNVIDIAGPUToLLVM/**/*.cpp", + ]), + hdrs = glob([ + "include/TritonNVIDIAGPUToLLVM/*.h", + ]) + [ + "lib/TritonNVIDIAGPUToLLVM/Utility.h", + ], + # copybara:uncomment_begin + # compatible_with = ["//buildenv/target:non_prod"], + # copybara:uncomment_end + copts = select({ + "//conditions:default": [ + "-Wno-reorder-ctor", + "-Wno-unused-variable", + ], + }), + includes = [ + "..", + "include", + "lib/TritonNVIDIAGPUToLLVM", + ], + deps = [ + ":NVGPUDialect", + ":triton_conversion_triton_nvidia_gpu_to_llvm_passes_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:ArithToLLVM", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:ControlFlowToLLVM", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:GPUToNVVMTransforms", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:IndexDialect", + "@llvm-project//mlir:LLVMCommonConversion", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:MathToLLVM", + "@llvm-project//mlir:NVVMDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFToControlFlow", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + "//:TritonAnalysis", + "//:TritonDialects", + "//:TritonGPUToLLVM", + "//:triton_gpu_attr_inc_gen", + ], +) + +gentbl_cc_library( + name = "triton_conversion_nvgpu_to_llvm_passes_inc_gen", + # copybara:uncomment_begin + # compatible_with = ["//buildenv/target:non_prod"], + # copybara:uncomment_end + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=NVGPUToLLVM", + ], + "include/NVGPUToLLVM/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/NVGPUToLLVM/Passes.td", + deps = ["//:td_files"], +) + +gentbl_cc_library( + name = 
"triton_conversion_triton_nvidia_gpu_to_llvm_passes_inc_gen", + # copybara:uncomment_begin + # compatible_with = ["//buildenv/target:non_prod"], + # copybara:uncomment_end + tbl_outs = [ + ( + [ + "--gen-pass-decls", + "--name=TritonNVIDIAGPUToLLVM", + ], + "include/TritonNVIDIAGPUToLLVM/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/TritonNVIDIAGPUToLLVM/Passes.td", + deps = ["//:td_files"], +) + +td_library( + name = "td_files", + srcs = glob(["include/Dialect/NVGPU/IR/*.td"]), + includes = ["include"], + deps = [ + "@llvm-project//mlir:ArithOpsTdFiles", + "@llvm-project//mlir:CastInterfacesTdFiles", + "@llvm-project//mlir:ControlFlowInterfacesTdFiles", + "@llvm-project//mlir:DestinationStyleOpInterfaceTdFiles", + "@llvm-project//mlir:FunctionInterfacesTdFiles", + "@llvm-project//mlir:InferTypeOpInterfaceTdFiles", + "@llvm-project//mlir:LLVMOpsTdFiles", + "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:PassBaseTdFiles", + "@llvm-project//mlir:SideEffectInterfacesTdFiles", + "@llvm-project//mlir:ViewLikeInterfaceTdFiles", + ], +) + +gentbl_cc_library( + name = "nvgpu_ops_inc_gen", + tbl_outs = [ + ( + ["--gen-llvmir-conversions"], + "include/Dialect/NVGPU/IR/OpsConversions.inc", + ), + ( + ["--gen-op-decls"], + "include/Dialect/NVGPU/IR/Ops.h.inc", + ), + ( + ["--gen-op-defs"], + "include/Dialect/NVGPU/IR/Ops.cpp.inc", + ), + ( + ["--gen-enum-decls"], + "include/Dialect/NVGPU/IR/OpsEnums.h.inc", + ), + ( + ["--gen-enum-defs"], + "include/Dialect/NVGPU/IR/OpsEnums.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/Dialect/NVGPU/IR/NVGPUOps.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "nvgpu_attr_inc_gen", + tbl_outs = [ + ( + ["--gen-attrdef-decls"], + "include/Dialect/NVGPU/IR/NVGPUAttrDefs.h.inc", + ), + ( + ["--gen-attrdef-defs"], + "include/Dialect/NVGPU/IR/NVGPUAttrDefs.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/Dialect/NVGPU/IR/NVGPUAttrDefs.td", + deps = ["td_files"], +) + +gentbl_cc_library( + name = "nvgpu_dialect_inc_gen", + tbl_outs = [ + ( + ["--gen-dialect-decls"], + "include/Dialect/NVGPU/IR/Dialect.h.inc", + ), + ( + ["--gen-dialect-defs"], + "include/Dialect/NVGPU/IR/Dialect.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/Dialect/NVGPU/IR/NVGPUDialect.td", + deps = ["td_files"], +) + +cc_library( + name = "NVGPUDialect", + srcs = glob([ + "lib/Dialect/NVGPU/IR/*.cpp", + ]), + hdrs = glob([ + "include/Dialect/NVGPU/IR/*.h", + ]), + copts = select({ + "//:compiler_is_msvc": [], + "//conditions:default": [ + "-Wno-unused-variable", + "-Wno-logical-op-parentheses", + ], + }), + includes = [ + "..", # because nvidia/include/Dialect/NVGPU/IR/Dialect.h.inc + "../..", # because third_party/nvidia/include/Dialect/NVGPU/IR/Dialect.h + "include", + ], + deps = [ + ":nvgpu_attr_inc_gen", + ":nvgpu_dialect_inc_gen", + ":nvgpu_ops_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:ControlFlowInterfaces", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:FunctionInterfaces", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:InliningUtils", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:MathDialect", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + # 
The following is added to make Utility compile + "//:TritonTools", + "@llvm-project//mlir:LLVMCommonConversion", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + ], +) diff --git a/third_party/nvidia/backend/BUILD b/third_party/nvidia/backend/BUILD new file mode 100644 index 000000000000..a5b34aa5c29b --- /dev/null +++ b/third_party/nvidia/backend/BUILD @@ -0,0 +1,30 @@ +load("@pybind11_bazel//:build_defs.bzl", "pybind_extension") + +package( + default_applicable_licenses = ["//:license"], + default_visibility = [ + "//third_party/py/triton:__subpackages__", + ], +) + +pybind_extension( + name = "cuda_utils", + srcs = ["cuda_utils.cc"], + visibility = [ + "//learning/deepmind/jax/triton/ops:__subpackages__", + "//third_party/py/triton:__subpackages__", + ], + deps = [ + "//platforms/gpus/cuda/dynamic_libcuda", + "@local_config_cuda//cuda:cuda_headers", + "@local_config_cuda//cuda:cuda_runtime", + "@llvm-project//llvm:Support", + ], +) + +filegroup( + name = "files", + srcs = glob( + include = ["**/*.py"], + ), +) diff --git a/third_party/nvidia/backend/driver.c b/third_party/nvidia/backend/driver.c index bb0d86888120..19c732c354d1 100644 --- a/third_party/nvidia/backend/driver.c +++ b/third_party/nvidia/backend/driver.c @@ -154,6 +154,7 @@ static PyObject *loadBinary(PyObject *self, PyObject *args) { typedef CUresult (*cuOccupancyMaxActiveClusters_t)( int *numClusters, CUfunction func, const CUlaunchConfig *config); +#if CUDA_VERSION >= 12000 typedef CUresult (*cuTensorMapEncodeTiled_t)( CUtensorMap *tensorMap, CUtensorMapDataType tensorDataType, cuuint32_t tensorRank, void *globalAddress, const cuuint64_t *globalDim, @@ -161,6 +162,7 @@ typedef CUresult (*cuTensorMapEncodeTiled_t)( const cuuint32_t *elementStrides, CUtensorMapInterleave interleave, CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill); +#endif #define defineGetFunctionHandle(name, symbolName) \ static symbolName##_t name() { \ @@ -187,8 +189,10 @@ typedef CUresult (*cuTensorMapEncodeTiled_t)( defineGetFunctionHandle(getCuOccupancyMaxActiveClustersHandle, cuOccupancyMaxActiveClusters); +#if CUDA_VERSION >= 12000 defineGetFunctionHandle(getCuTensorMapEncodeTiledHandle, cuTensorMapEncodeTiled); +#endif static PyObject *occupancyMaxActiveClusters(PyObject *self, PyObject *args) { int clusterDimX = -1, clusterDimY = -1, clusterDimZ = -1, @@ -281,6 +285,9 @@ static PyObject *setPrintfFifoSize(PyObject *self, PyObject *args) { // Simple helper to experiment creating TMA descriptors on the host. // This is a useful to test TMA operations independently. static PyObject *fill1DTMADescriptor(PyObject *self, PyObject *args) { +#if CUDA_VERSION < 12000 + return NULL; +#else unsigned long long global_address; uint64_t dim; uint32_t tensorDim; @@ -321,11 +328,15 @@ static PyObject *fill1DTMADescriptor(PyObject *self, PyObject *args) { CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE)); Py_INCREF(Py_None); return Py_None; +#endif } // Simple helper to experiment creating TMA descriptors on the host. // This is a useful to test TMA operations independently. 
static PyObject *fill2DTMADescriptor(PyObject *self, PyObject *args) { +#if CUDA_VERSION < 12000 + return NULL; +#else unsigned long long global_address; uint64_t dims[2]; uint32_t tensorDims[2]; @@ -384,6 +395,7 @@ static PyObject *fill2DTMADescriptor(PyObject *self, PyObject *args) { CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE)); Py_INCREF(Py_None); return Py_None; +#endif } static PyMethodDef ModuleMethods[] = { diff --git a/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp b/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp index 8de0efefca84..637071275e39 100644 --- a/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp +++ b/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp @@ -291,10 +291,36 @@ class WGMMAWaitGroupOpPattern : public OpRewritePattern { Constraints getOutputConstraints(ttn::WGMMAWaitGroupOp op) const { auto outputStructType = cast(op.getType()); - uint32_t numOutputRegs = outputStructType.getBody().size(); - std::string output = - outputStructType.getBody().front().isF32() ? "=f" : "=r"; - return Constraints(numOutputRegs, output); + std::vector outputConstraints; + outputConstraints.reserve(outputStructType.getBody().size()); + for (mlir::Type type : outputStructType.getBody()) { + if (type.isF32()) { + outputConstraints.push_back("=f"); + continue; + } else if (type.isF64()) { + outputConstraints.push_back("=d"); + continue; + } + unsigned bitwidth = isa(type) ? + 64 : type.getIntOrFloatBitWidth(); + switch (bitwidth) { + case 1: + outputConstraints.push_back("=b"); + break; + case 16: + outputConstraints.push_back("=h"); + break; + case 32: + outputConstraints.push_back("=r"); + break; + case 64: + outputConstraints.push_back("=l"); + break; + default: + assert(false && "unsupported bitwidth"); + } + } + return outputConstraints; } OperandsAndConstraints diff --git a/third_party/nvidia/triton_nvidia.cc b/third_party/nvidia/triton_nvidia.cc index 1269dcda00aa..3cccc5fb6a1c 100644 --- a/third_party/nvidia/triton_nvidia.cc +++ b/third_party/nvidia/triton_nvidia.cc @@ -1,4 +1,4 @@ -#include "Dialect/NVGPU/IR/Dialect.h" +#include "Dialect/NVGPU/IR/Dialect.h" #include "NVGPUToLLVM/NVGPUToLLVMPass.h" #include "TritonNVIDIAGPUToLLVM/Passes.h" #include "cublas_instance.h" diff --git a/third_party/proton/proton/_C/include b/third_party/proton/proton/_C/include index fe4f4a1aa9bd..4400934bdf78 120000 --- a/third_party/proton/proton/_C/include +++ b/third_party/proton/proton/_C/include @@ -1 +1 @@ -../../csrc/include/ \ No newline at end of file +../../csrc/include \ No newline at end of file diff --git a/unittest/BUILD b/unittest/BUILD new file mode 100644 index 000000000000..4cbadcfa4655 --- /dev/null +++ b/unittest/BUILD @@ -0,0 +1,144 @@ +load("//tools/build_defs/build_test:build_test.bzl", "build_test") + +package( + default_applicable_licenses = ["//:license"], + default_compatible_with = ["//buildenv/target:non_prod"], + default_visibility = ["//:__subpackages__"], +) + +cc_test( + name = "AnalysisTest", + srcs = glob(["Analysis/*.cpp"]), + deps = [ + "//testing/base/public:gunit_main", + "@llvm-project//llvm:Support", + "//:TritonDialects", + ], +) + +cc_test( + name = "DialectTestCatchAll", + srcs = glob( + [ + "Dialect/**/*.cpp", + ], + exclude = [ + "Dialect/TritonGPU/DialectTest.cpp", + "Dialect/TritonGPU/LinearLayoutConversionsTest.cpp", + "Dialect/TritonGPU/SwizzleTest.cpp", + ], + ), + copts = select({ + "//:compiler_is_msvc": [], + "//conditions:default": [ + "-Wno-unused-variable", + ], + }), + deps = [ + "//testing/base/public:gunit_main", + 
"@llvm-project//llvm:Support", + "@llvm-project//mlir:AsmParser", + "@llvm-project//mlir:IR", + "//:TritonDialects", + ], +) + +cc_test( + name = "DialectTest", + srcs = [ + "Dialect/TritonGPU/DialectTest.cpp", + ], + copts = select({ + "//:compiler_is_msvc": [], + "//conditions:default": [ + "-Wno-unused-variable", + ], + }), + deps = [ + "//testing/base/public:gunit_main", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:AsmParser", + "@llvm-project//mlir:IR", + "//:TritonDialects", + ], +) + +cc_test( + name = "LinearLayoutConversionsTest", + srcs = [ + "Dialect/TritonGPU/LinearLayoutConversionsTest.cpp", + ], + copts = select({ + "//:compiler_is_msvc": [], + "//conditions:default": [ + "-Wno-unused-variable", + ], + }), + deps = [ + "//testing/base/public:gunit_main", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:AsmParser", + "@llvm-project//mlir:IR", + "//:TritonDialects", + ], +) + +cc_test( + name = "SwizzleTest", + srcs = [ + "Dialect/TritonGPU/SwizzleTest.cpp", + ], + copts = select({ + "//:compiler_is_msvc": [], + "//conditions:default": [ + "-Wno-unused-variable", + ], + }), + deps = [ + "//testing/base/public:gunit_main", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:AsmParser", + "@llvm-project//mlir:IR", + "//:TritonDialects", + ], +) + +cc_test( + name = "ConversionTest", + srcs = glob( + [ + "Conversion/**/*.cpp", + "Conversion/**/*.h", + ], + exclude = [ + "Conversion/TritonGPUToLLVM/EmitIndicesTest.cpp", + "Conversion/TritonGPUToLLVM/DumpLayout.cpp", + "Conversion/TritonGPUToLLVM/DumpLayout.h", + ], + ), + copts = select({ + "//:compiler_is_msvc": [], + "//conditions:default": [ + "-Wno-unused-variable", + ], + }), + deps = [ + "//testing/base/public:gunit_main", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:IR", + "//:TritonDialects", + "//:TritonNvidiaGPUTransforms", + "@triton//third_party/nvidia:TritonNVIDIAGPUToLLVM", + ], +) + +build_test( + name = "build_test", + allow_empty_target = False, + targets = [ + ":ConversionTest", + ":AnalysisTest", + ":DialectTest", + ], +) From 83cd6319efe942256fd03922cff9e3d433a61312 Mon Sep 17 00:00:00 2001 From: Gary Geng Date: Mon, 23 Sep 2024 18:23:06 +0000 Subject: [PATCH 03/18] Add preliminary logic to hoist elt-wise ops for MMAv3 --- .../TritonGPUToLLVM/ElementwiseOpToLLVM.cpp | 12 +- .../Transforms/OptimizeDotOperands.cpp | 238 +++++++++++++++--- .../Pipeliner/MatmulLoopPipeline.cpp | 8 + test/TritonGPU/dot-operands.mlir | 24 ++ .../ConvertLayoutOpToLLVM.cpp | 34 ++- .../DecomposeUnsupportedConversions.cpp | 8 +- 6 files changed, 287 insertions(+), 37 deletions(-) diff --git a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp index 2d21a1bfeca2..8a2176c55b12 100644 --- a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -87,8 +87,12 @@ SmallVector unpackI32(const SmallVector &inValues, Type srcTy, if (!tensorTy) return inValues; auto encoding = dyn_cast(tensorTy.getEncoding()); - if (!(encoding && isa(encoding.getParent()))) + if (!encoding) + return inValues; + auto parentEnc = dyn_cast(encoding.getParent()); + if (!parentEnc || parentEnc.isHopper()) return inValues; + SmallVector outValues; for (auto v : inValues) { // cast i32 to appropriate eltType vector and extract elements @@ -109,8 +113,12 @@ SmallVector packI32(const SmallVector &inValues, Type srcTy, if (!tensorTy) return inValues; auto encoding = 
dyn_cast(tensorTy.getEncoding()); - if (!(encoding && isa(encoding.getParent()))) + if (!encoding) + return inValues; + auto parentEnc = dyn_cast(encoding.getParent()); + if (!parentEnc || parentEnc.isHopper()) return inValues; + SmallVector outValues; auto eltType = typeConverter->convertType(tensorTy.getElementType()); int vecWidth = 32 / eltType.getIntOrFloatBitWidth(); diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index e6e0ec8d7cef..93621fe008a0 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -4,6 +4,7 @@ #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/Passes.h" #include "triton/Analysis/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/Passes.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" @@ -15,6 +16,66 @@ namespace gpu { namespace { +// Helpers + +// Returns whether we can hoist DotOp Encoding through `op`. +// Roughly, whether op is elementwise and thus threads don't need +// to exchange elements. But some ops are not current supported even though +// they meet that criterion. +bool canHoistDotOpEncV2(Operation* op, DotOperandEncodingAttr& dotOpEnc) { + // Only consider custom conversions or arith ops. + // TODO(jlebar): Is this too restrictive? + if (!isa(op) && !isPureUnaryInlineAsm(op) && + op->getDialect()->getTypeID() != TypeID::get()) + return false; + + // Quick handling to fix loading issues when computing the original + // bitwidth is unable to realize that there is a mixed-precision dot + // (hence kWidth = 1) but wants to hoist through the type conversion. + if (isa(op) && dotOpEnc.getKWidth() == 1) + return false; + + // Currently, these instructions are not supported during lowering of + // shared -> dot_operand layout. Not all types and type conversions are + // supported. + if (isa(op)) + return false; + + // Don't hoist through u1 -> fp casts as they aren't supported in + // ElementwiseOpToLLVM::reorderValues(). + if (isa(op)) { + Type opType = getElementTypeOrSelf(op->getOperand(0)); + if (opType.isInteger(1)) + return false; + } + + return true; +} + +bool canHoistDotOpEncV3(Operation* op) { + // Only consider custom conversions or arith ops. + // TODO(jlebar): Is this too restrictive? + if (!isa(op) && !isPureUnaryInlineAsm(op) && + op->getDialect()->getTypeID() != TypeID::get()) + return false; + + // Currently, these instructions are not supported during lowering of + // shared -> dot_operand layout. Not all types and type conversions are + // supported. + if (isa(op)) + return false; + + // Don't hoist through u1 -> fp casts as they aren't supported in + // ElementwiseOpToLLVM::reorderValues(). + if (isa(op)) { + Type opType = getElementTypeOrSelf(op->getOperand(0)); + if (opType.isInteger(1)) + return false; + } + + return true; +} + // Given // convert(trans(src)) #dot_operand -> // convert(local_load(trans(alloc(src)))) @@ -127,32 +188,9 @@ class HoistLayoutConversion : public OpRewritePattern { [](Type ty) { return isa(ty); })) return failure(); - // Quick handling to fix loading issues when computing the original - // bitwidth is unable to realize that there is a mixed-precision dot - // (hence kWidth = 1) but wants to hoist through the type conversion. 
- if (isa(src) && dotOpEnc.getKWidth() == 1) - return failure(); - - // Only consider custom conversions or arith ops. - // TODO(jlebar): Is this too restrictive? - if (!isa(src) && !isPureUnaryInlineAsm(src) && - src->getDialect()->getTypeID() != TypeID::get()) + if (!canHoistDotOpEncV2(src, dotOpEnc)) return failure(); - // Currently, these instructions are not supported during lowering of - // shared -> dot_operand layout. Not all types and type conversions are - // supported. - if (isa(src)) - return failure(); - - // Don't hoist through u1 -> fp casts as they aren't supported in - // ElementwiseOpToLLVM::reorderValues(). - if (isa(src)) { - Type srcType = getElementTypeOrSelf(src->getOperand(0)); - if (srcType.isInteger(1)) - return failure(); - } - // Check that the conversion is transitively dependent on a load, and all // operations between the load and the conversion are layout preserving. // @@ -180,12 +218,7 @@ class HoistLayoutConversion : public OpRewritePattern { if (isa(currOp)) { foundLoad = true; } else if (foundLoad) { - // Bail out if there exists an op after Load that is not FpToFp, - // Bitcast, or Arith. - if (!isa(currOp) && - !isPureUnaryInlineAsm(currOp) && - currOp->getDialect()->getTypeID() != - TypeID::get()) + if (!canHoistDotOpEncV2(currOp, dotOpEnc)) return failure(); } } @@ -315,6 +348,147 @@ struct MMAV3UseRegOperand } }; +// TODO(ggengnv) more tests (multiple elt-wise ops) and document +struct MMAV3HoistLayoutConversion + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(triton::nvidia_gpu::WarpGroupDotOp dotOp, + PatternRewriter &rewriter) const override { + auto alloc = dotOp.getOperand(0).getDefiningOp(); + if (!alloc || !alloc.getSrc()) + return failure(); + + auto getEncoding = [](Value v) { + return cast(v.getType()).getEncoding(); + }; + + if (!isa(getEncoding(dotOp.getOperand(0)))) + return failure(); + + // Performs checks for early stop + NvidiaMmaEncodingAttr dstEnc; + { + auto srcEnc = dyn_cast(getEncoding(alloc.getSrc())); + dstEnc = + dyn_cast(getEncoding(dotOp.getResult())); + // Want: A's Encoding to be Blocked and D's encoding to be NvidiaMmA v3 + if (!srcEnc || !dstEnc || dstEnc.getVersionMajor() != 3) + return failure(); + + auto src = alloc.getSrc().getDefiningOp(); + + // Value passed to alloc must have Tensor arguments and single Tensor result + if (!src || src->getNumOperands() == 0 || src->getNumResults() != 1) + return failure(); + if (!all_of(src->getOperandTypes(), + [](Type ty) { return isa(ty); })) + return failure(); + auto srcTy = dyn_cast(src->getResult(0).getType()); + if (!srcTy) + return failure(); + + if (!canHoistDotOpEncV3(src)) + return failure(); + } + + SetVector slice; + BackwardSliceOptions opt; + opt.omitBlockArguments = true; + opt.filter = [&](Operation *op) { + return (op->getParentRegion() == alloc->getParentRegion()) && !isa(op) + && (op->getNumOperands() != 0); // Ensures all ops in slice have operands + }; + + getBackwardSlice(alloc.getOperation(), &slice, opt); + + auto isBlockedRankedTensor = [&](auto val) { + return isa(getEncoding(val)) && isa(val.getType()); + }; + + SmallVector frontierOps; + for (Operation *currOp : slice) { + if (!canHoistDotOpEncV3(currOp)) + return failure(); + + // We previously ensured that all ops in slice have at least one operand + bool isFrontier = false; + for (auto operand : currOp->getOperands()) { + auto op = operand.getDefiningOp(); + if (!slice.contains(op)) { + // TODO that this is overly restrictive. 
Can add support for ConstantOp and LocalLoad + if (!isa(op)) + return failure(); + + isFrontier = true; + } + } + + if (isFrontier) { + if (!isa(currOp->getOperand(0).getDefiningOp())) + return failure(); + + auto res = currOp->getResult(0); + if (!isBlockedRankedTensor(res)) + return failure(); + + if (!llvm::all_of(currOp->getOperands(), isBlockedRankedTensor)) + return failure(); + + frontierOps.push_back(currOp); + } + } + + // Nothing to hoist through + if (frontierOps.empty()) + return failure(); + + auto dotOperandEnc = DotOperandEncodingAttr::get( + dotOp.getContext(), /*opIdx=*/0, dstEnc, /*kWidth=*/0); + + // For each frontierOp: + // load; frontierOp; ...; warp_group_dot + // -> load; local_alloc; local_load; convert_layout; frontierOp; ...; warp_group_dot + for (Operation *frontierOp : frontierOps) { + auto frontierTy = dyn_cast(frontierOp->getResult(0).getType()); + + SmallVector newOperands; + for (auto operand : frontierOp->getOperands()) { + // We checked earlier that all operands are ranked tensors. + auto operandTy = cast(operand.getType()); + auto operandEltTy = operandTy.getElementType(); + + auto oldAllocTy = alloc.getType(); + // TODO(ggengnv) previous encoding (oldAllocTy.getEncoding()) was for shared operand. + // Is it still appropriate for loading into registers? + auto newAllocTy = MemDescType::get(operandTy.getShape(), operandEltTy, + oldAllocTy.getEncoding(), oldAllocTy.getMemorySpace()); + auto localAlloc = rewriter.create(alloc.getLoc(), newAllocTy, operand); + auto localLoad = rewriter.create(alloc.getLoc(), operandTy, localAlloc); + + Type cvtTy = RankedTensorType::get( + operandTy.getShape(), operandTy.getElementType(), dotOperandEnc); + auto cvt = rewriter.create(alloc.getLoc(), cvtTy, localLoad); + + newOperands.push_back(cvt); + } + + auto newFrontier = rewriter.clone(*frontierOp); + for (int i = 0; i < newOperands.size(); i++) + newFrontier->setOperand(i, newOperands[i]); + newFrontier->getResult(0).setType(RankedTensorType::get( + frontierTy.getShape(), frontierTy.getElementType(), dotOperandEnc)); + + rewriter.replaceOp(frontierOp, newFrontier); + } + + // replace LHS operand with its parent (in dotOpEnc) + rewriter.modifyOpInPlace(dotOp, [&]() { dotOp.setOperand(0, alloc.getSrc()); }); + + return success(); + } +}; + } // namespace #define GEN_PASS_DEF_TRITONGPUOPTIMIZEDOTOPERANDS @@ -337,10 +511,12 @@ class TritonGPUOptimizeDotOperandsPass mlir::RewritePatternSet patterns(context); patterns.add(context); - if (this->hoistLayoutConversion.getValue()) + if (this->hoistLayoutConversion.getValue()) { patterns.add(context); + } patterns.add(context); patterns.add(context); + patterns.add(context); ConvertLayoutOp::getCanonicalizationPatterns(patterns, context); if (failed(applyPatternsAndFoldGreedily(m, std::move(patterns)))) signalPassFailure(); diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp index dc5f395c6753..9c9fc8983a8d 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp @@ -380,6 +380,14 @@ static bool loadIsMMAv3(Operation *loadOp) { if (!sharedEnc.getHasLeadingOffset()) return false; + // In case LHS is in registers, don't pipeline for now + auto op = *alloc->getUsers().begin(); + if (auto localLoad = dyn_cast(op)) { + auto resTy = cast(localLoad->getResultTypes()[0]); + if (!resTy || isa(resTy.getEncoding())) + return false; + } + // MMA V3 case. 
auto newOrder = sharedEnc.getOrder(); auto ty = cast(loadOp->getResultTypes()[0]); diff --git a/test/TritonGPU/dot-operands.mlir b/test/TritonGPU/dot-operands.mlir index 82fc1ddf7b65..5fc02aa5e3d9 100644 --- a/test/TritonGPU/dot-operands.mlir +++ b/test/TritonGPU/dot-operands.mlir @@ -211,3 +211,27 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : tt.return %td : tensor<128x128xf32, #mma> } } + +// ----- + + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +// CHECK: tt.func @mma_v3_reg_push_elementwise +// CHECK: %[[A_LOADED:.*]] = tt.load %{{.*}} : tensor<128x64x!tt.ptr, #blocked> +// CHECK: %[[A_MEMDESC:.*]] = triton_gpu.local_alloc %[[A_LOADED]] : (tensor<128x64xbf16, #blocked>) -> !tt.memdesc<128x64xbf16, #shared> +// CHECK: %[[A_REG:.*]] = triton_gpu.local_load %[[A_MEMDESC]] : !tt.memdesc<128x64xbf16, #shared> -> tensor<128x64xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> +// CHECK: %[[A_CASTED:.*]] = tt.fp_to_fp %[[A_REG]] : tensor<128x64xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> +// CHECK: %[[R:.*]] = triton_nvidia_gpu.warp_group_dot %[[A_CASTED]], %{{.*}}, %{{.*}} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * !tt.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma> + tt.func @mma_v3_reg_push_elementwise(%pa: tensor<128x64x!tt.ptr, #blocked>, %dotb: !tt.memdesc<64x64xf16, #shared>, %dotc: tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma>{ + %a_bf16 = tt.load %pa : tensor<128x64x!tt.ptr, #blocked> + %a = tt.fp_to_fp %a_bf16 : tensor<128x64xbf16, #blocked> -> tensor<128x64xf16, #blocked> + %dota = triton_gpu.local_alloc %a: (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared1> + %r = triton_nvidia_gpu.warp_group_dot %dota, %dotb, %dotc : !tt.memdesc<128x64xf16, #shared1> * !tt.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma> + tt.return %r : tensor<128x64xf32, #mma> + } +} + diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 8fb44ce644ba..3d9ab5f5fb7c 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -40,7 +40,27 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, DotOperandEncodingAttr bEncoding, const SharedMemoryObject &smemObj, const LLVMTypeConverter *typeConverter, Value thread); +} // namespace SharedToDotOperandMMAv2 + +namespace SharedToDotOperandMMAv3 { +Value convertLayout(ConversionPatternRewriter &rewriter, + Location loc, Value tensor, + DotOperandEncodingAttr bEncoding, + const SharedMemoryObject &smemObj, + const LLVMTypeConverter *typeConverter, Value thread) { + SmallVector elems; + // TODO(ggengnv) fix + for (int i = 0; i < 16; i++) { + elems.push_back(int_val(8, 0)); + } + Type elemTy 
= elems[0].getType(); + MLIRContext *ctx = elemTy.getContext(); + Type structTy = LLVM::LLVMStructType::getLiteral( + ctx, SmallVector(elems.size(), elemTy)); + auto result = packLLElements(loc, typeConverter, elems, rewriter, structTy); + return result; } +} // namespace SharedToDotOperandMMAv3 namespace { @@ -88,11 +108,21 @@ struct LocalLoadOpConversion auto smemObj = getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), llvmElemTy, rewriter); Value res; - if (!isOuter && mmaLayout.isAmpere()) { // tensor core v2 + + if (isOuter) { + assert(false && "MMA Layout does not support outer product"); + return res; + } + + if (mmaLayout.isHopper()) { // tensor core v3 + assert(dotOperandLayout.getOpIdx() == 0); + res = SharedToDotOperandMMAv3::convertLayout(rewriter, loc, src, + dotOperandLayout, smemObj, typeConverter, getThreadId(rewriter, loc)); + } else if (mmaLayout.isAmpere()) { // tensor core v2 res = SharedToDotOperandMMAv2::convertLayout( dotOperandLayout.getOpIdx(), rewriter, loc, src, dotOperandLayout, smemObj, typeConverter, getThreadId(rewriter, loc)); - } else if (!isOuter && mmaLayout.isVolta() && isMMA) { // tensor core v1 + } else if (mmaLayout.isVolta() && isMMA) { // tensor core v1 bool isMMAv1Row = mmaLayout.getMMAv1IsRow(dotOperandLayout.getOpIdx()); auto srcSharedLayout = cast(src.getType().getEncoding()); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp index cf0ddc248dd1..3c52fe77e74a 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp @@ -36,10 +36,14 @@ class DecomposeLocalLoadToDotOperand op.getType().getEncoding()); MemDescType srcType = op.getSrc().getType(); auto sharedEncoding = cast(srcType.getEncoding()); - if (!dstDotOp || !sharedEncoding.getHasLeadingOffset()) + if (!dstDotOp) return failure(); + + auto parentEnc = cast(dstDotOp.getParent()) ; + if (!parentEnc || parentEnc.getVersionMajor() == 3 || !sharedEncoding.getHasLeadingOffset()) + return failure(); + RankedTensorType type = op.getType(); - auto parentEnc = dstDotOp.getParent(); int numWarps = triton::gpu::getNumWarpsPerCTA(parentEnc); int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp( op->getParentOfType()); From 30a670ca3687b124a3a113cd71ebbcaa07c48c69 Mon Sep 17 00:00:00 2001 From: Gary Geng Date: Fri, 30 Aug 2024 22:44:26 +0000 Subject: [PATCH 04/18] Lower shared > v3 dotOp & improve hoisting logic --- .../Dialect/TritonGPU/IR/TritonGPUAttrDefs.td | 13 +-- .../TritonGPUToLLVM/ElementwiseOpToLLVM.cpp | 2 +- lib/Dialect/TritonGPU/IR/Dialect.cpp | 22 +++-- .../Transforms/OptimizeDotOperands.cpp | 53 +++++++----- .../Pipeliner/MatmulLoopPipeline.cpp | 2 +- .../ConvertLayoutOpToLLVM.cpp | 25 +----- .../SharedToDotOperandMMAv2.cpp | 83 +++++++++++++++---- .../DecomposeUnsupportedConversions.cpp | 8 +- .../DotOpToLLVM/WGMMA.cpp | 48 ++++++++--- 9 files changed, 163 insertions(+), 93 deletions(-) diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td index 3912191f4f3e..c03ed737aa2d 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -361,8 +361,8 @@ compared to 1*64 when the hasLeadingOffset is false. 
return get(context, vec, perPhase, maxPhase, order, CTALayout); } - // ---- begin Ampere ---- - if (mmaEnc.isAmpere()) { + // ---- begin Ampere & Hopper ---- + if (mmaEnc.isAmpere() || mmaEnc.isHopper()) { int perPhase = 128 / (shapePerCTA[order[0]] * 4 / dotOpEnc.getKWidth()); perPhase = std::max(perPhase, 1); std::vector matShape = {8, 8, 4 * dotOpEnc.getKWidth()}; @@ -397,13 +397,6 @@ compared to 1*64 when the hasLeadingOffset is false. llvm_unreachable("invalid operand index"); } - // ---- begin version 3 ---- - if (mmaEnc.isHopper()) { - llvm_unreachable("SharedEncodingAttr builder when the MMAEncodingAttr" - " is Hopper has not been implemented yet"); - return $_get(context, 1, 1, 1, order, CTALayout, true); - } - // ---- not implemented ---- llvm_unreachable("unsupported swizzling for provided MMA version"); }]>, @@ -1332,7 +1325,7 @@ elements along the K dim, or they use all elements of the tensor along the K dim "Attribute":$parent, "Type":$eltTy), [{ NvidiaMmaEncodingAttr parentAttr = mlir::dyn_cast(parent); - if (!parentAttr || !parentAttr.isAmpere()) + if (!parentAttr || (!parentAttr.isAmpere() && !parentAttr.isHopper())) return $_get(context, opIdx, parent, 0); unsigned bitwidth = eltTy.getIntOrFloatBitWidth(); unsigned MMAv2kWidth = 32 / bitwidth; diff --git a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp index 8a2176c55b12..86600a0f602e 100644 --- a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -33,7 +33,7 @@ SmallVector reorderValues(const SmallVector &values, Type inType, // If the parent of the dot operand is in block encoding, we don't need to // reorder elements auto parentEncoding = dyn_cast(ouEncoding.getParent()); - if (!parentEncoding) + if (!parentEncoding || parentEncoding.isHopper()) return values; size_t inBitWidth = inTensorTy.getElementType().getIntOrFloatBitWidth(); size_t ouBitWidth = ouTensorTy.getElementType().getIntOrFloatBitWidth(); diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 6bc4ca6f9eae..54ee7be91150 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -1022,13 +1022,17 @@ LogicalResult DotOperandEncodingAttr::verify( return emitError() << "triton_gpu.dot_op parent paramenter cannot be null"; } if (auto parentAttr = mlir::dyn_cast(parent)) { - if (kWidth != 0 && !parentAttr.isAmpere()) + if (kWidth != 0 && !(parentAttr.isAmpere() || parentAttr.isHopper())) return emitError() << "triton_gpu.dot_op kWidth parameter can only be " - "non-zero for Ampere MMA parent"; - if (kWidth == 0 && parentAttr.isAmpere()) + "non-zero for Ampere or Hopper MMA parent"; + if (kWidth == 0 && (parentAttr.isAmpere() || parentAttr.isHopper())) return emitError() << "triton_gpu.dot_op kWidth parameter is mandatory for " - "Ampere MMA parent"; + "Ampere or Hopper MMA parent"; + if (opIdx != 0 && parentAttr.isHopper()) + return emitError() + << "triton_gpu.dot_op opIdx parameter must be 0 for " + "Hopper MMA parent"; return success(); } @@ -1960,6 +1964,7 @@ int NvidiaMmaEncodingAttr::getMMAv1Vec(int opIdx) const { SmallVector NvidiaMmaEncodingAttr::getMMAv2Rep(ArrayRef shape, int bitwidth, int opIdx) const { + assert(isAmpere() || isHopper()); auto rank = shape.size(); auto warpsPerCTA = getWarpsPerCTA(); SmallVector shapePerWarp = {1, 16, 8, 4 * 64 / bitwidth}; @@ -1967,7 +1972,6 @@ SmallVector NvidiaMmaEncodingAttr::getMMAv2Rep(ArrayRef shape, 
rank == 3 ? std::max(1, shape[0] / (shapePerWarp[0] * warpsPerCTA[0])) : 1; - assert(isAmpere()); if (opIdx == 0) return {numRepBatch, @@ -1982,6 +1986,7 @@ SmallVector NvidiaMmaEncodingAttr::getMMAv2Rep(ArrayRef shape, warpsPerCTA[rank - 1]))}; } } + unsigned NvidiaMmaEncodingAttr::getTotalElemsPerThreadForOperands( ArrayRef shape, Type eltTy, int kWidth, int opIdx) const { auto shapePerCTA = getShapePerCTA(*this, shape); @@ -1989,7 +1994,12 @@ unsigned NvidiaMmaEncodingAttr::getTotalElemsPerThreadForOperands( int warpsPerCTAN = getWarpsPerCTA()[1]; // H100 if (isHopper()) { - return getTotalElemsPerThread(shape, eltTy); + assert(opIdx == 0); + auto instrMNK = getInstrShape(); + auto wpt = getWarpsPerCTA(); + int repM = ceil(shapePerCTA[0], instrMNK[0] * wpt[0]); + int repK = ceil(shapePerCTA[1], instrMNK[2]); + return 4 * kWidth * repM * repK; } // A100 if (isAmpere()) { diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index 93621fe008a0..54a3a23c2686 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -62,7 +62,7 @@ bool canHoistDotOpEncV3(Operation* op) { // Currently, these instructions are not supported during lowering of // shared -> dot_operand layout. Not all types and type conversions are // supported. - if (isa(op)) + if (isa(op)) return false; // Don't hoist through u1 -> fp casts as they aren't supported in @@ -368,6 +368,7 @@ struct MMAV3HoistLayoutConversion // Performs checks for early stop NvidiaMmaEncodingAttr dstEnc; + Type inputEltTy; { auto srcEnc = dyn_cast(getEncoding(alloc.getSrc())); dstEnc = @@ -387,6 +388,7 @@ struct MMAV3HoistLayoutConversion auto srcTy = dyn_cast(src->getResult(0).getType()); if (!srcTy) return failure(); + inputEltTy = srcTy.getElementType(); if (!canHoistDotOpEncV3(src)) return failure(); @@ -396,7 +398,7 @@ struct MMAV3HoistLayoutConversion BackwardSliceOptions opt; opt.omitBlockArguments = true; opt.filter = [&](Operation *op) { - return (op->getParentRegion() == alloc->getParentRegion()) && !isa(op) + return (op->getParentRegion() == alloc->getParentRegion()) && !isa(op) && (op->getNumOperands() != 0); // Ensures all ops in slice have operands }; @@ -416,8 +418,7 @@ struct MMAV3HoistLayoutConversion for (auto operand : currOp->getOperands()) { auto op = operand.getDefiningOp(); if (!slice.contains(op)) { - // TODO that this is overly restrictive. Can add support for ConstantOp and LocalLoad - if (!isa(op)) + if (!isa(op)) return failure(); isFrontier = true; @@ -425,9 +426,6 @@ struct MMAV3HoistLayoutConversion } if (isFrontier) { - if (!isa(currOp->getOperand(0).getDefiningOp())) - return failure(); - auto res = currOp->getResult(0); if (!isBlockedRankedTensor(res)) return failure(); @@ -443,12 +441,16 @@ struct MMAV3HoistLayoutConversion if (frontierOps.empty()) return failure(); + // convert A operand auto dotOperandEnc = DotOperandEncodingAttr::get( - dotOp.getContext(), /*opIdx=*/0, dstEnc, /*kWidth=*/0); + dotOp.getContext(), /*opIdx=*/0, dstEnc, inputEltTy); // For each frontierOp: - // load; frontierOp; ...; warp_group_dot - // -> load; local_alloc; local_load; convert_layout; frontierOp; ...; warp_group_dot + // load; frontierOp; [hoistableOps...]; local_alloc; warp_group_dot + // -> load; local_alloc; local_load; convert_layout; frontierOp; [hoistableOps...]; warp_group_dot + // or... 
+ // constant; frontierOp; [hoistableOps...]; warp_group_dot + // -> constant; convert_layout; frontierOp; [hoistableOps...]; warp_group_dot for (Operation *frontierOp : frontierOps) { auto frontierTy = dyn_cast(frontierOp->getResult(0).getType()); @@ -458,17 +460,30 @@ struct MMAV3HoistLayoutConversion auto operandTy = cast(operand.getType()); auto operandEltTy = operandTy.getElementType(); - auto oldAllocTy = alloc.getType(); - // TODO(ggengnv) previous encoding (oldAllocTy.getEncoding()) was for shared operand. - // Is it still appropriate for loading into registers? - auto newAllocTy = MemDescType::get(operandTy.getShape(), operandEltTy, - oldAllocTy.getEncoding(), oldAllocTy.getMemorySpace()); - auto localAlloc = rewriter.create(alloc.getLoc(), newAllocTy, operand); - auto localLoad = rewriter.create(alloc.getLoc(), operandTy, localAlloc); + ConvertLayoutOp cvt; Type cvtTy = RankedTensorType::get( operandTy.getShape(), operandTy.getElementType(), dotOperandEnc); - auto cvt = rewriter.create(alloc.getLoc(), cvtTy, localLoad); + + if (isa(operand.getDefiningOp())) { + auto oldAllocTy = alloc.getType(); + auto oldAllocEnc = cast(oldAllocTy.getEncoding()); + + auto newAllocEnc = SharedEncodingAttr::get( + oldAllocEnc.getContext(), dotOperandEnc, operandTy.getShape(), + getOrder(operandTy.getEncoding()), + getCTALayout(operandTy.getEncoding()), + operandTy.getElementType().getIntOrFloatBitWidth(), /*needTrans=*/false); + + auto newAllocTy = MemDescType::get(operandTy.getShape(), operandEltTy, + newAllocEnc, oldAllocTy.getMemorySpace()); + auto localAlloc = rewriter.create(alloc.getLoc(), newAllocTy, operand); + auto localLoad = rewriter.create(alloc.getLoc(), operandTy, localAlloc); + cvt = rewriter.create(alloc.getLoc(), cvtTy, localLoad); + } else { + assert(isa(operand.getDefiningOp())); + cvt = rewriter.create(alloc.getLoc(), cvtTy, operand); + } newOperands.push_back(cvt); } @@ -510,13 +525,13 @@ class TritonGPUOptimizeDotOperandsPass auto ret = pm.run(m); mlir::RewritePatternSet patterns(context); + patterns.add(context); patterns.add(context); if (this->hoistLayoutConversion.getValue()) { patterns.add(context); } patterns.add(context); patterns.add(context); - patterns.add(context); ConvertLayoutOp::getCanonicalizationPatterns(patterns, context); if (failed(applyPatternsAndFoldGreedily(m, std::move(patterns)))) signalPassFailure(); diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp index 9c9fc8983a8d..e920de798289 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp @@ -380,7 +380,7 @@ static bool loadIsMMAv3(Operation *loadOp) { if (!sharedEnc.getHasLeadingOffset()) return false; - // In case LHS is in registers, don't pipeline for now + // In case LHS is in registers, don't pipeline for now TODO(ggengnv) is this necessary? 
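  // Sketch of the check below: the alloc's first user is inspected; if it is a
  // local_load whose result presumably carries a dot-operand encoding (i.e. the
  // LHS was hoisted into registers), MMAv3 pipelining is skipped for now.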
auto op = *alloc->getUsers().begin(); if (auto localLoad = dyn_cast(op)) { auto resTy = cast(localLoad->getResultTypes()[0]); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 3d9ab5f5fb7c..1f046e10a1f2 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -42,26 +42,6 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, const LLVMTypeConverter *typeConverter, Value thread); } // namespace SharedToDotOperandMMAv2 -namespace SharedToDotOperandMMAv3 { -Value convertLayout(ConversionPatternRewriter &rewriter, - Location loc, Value tensor, - DotOperandEncodingAttr bEncoding, - const SharedMemoryObject &smemObj, - const LLVMTypeConverter *typeConverter, Value thread) { - SmallVector elems; - // TODO(ggengnv) fix - for (int i = 0; i < 16; i++) { - elems.push_back(int_val(8, 0)); - } - Type elemTy = elems[0].getType(); - MLIRContext *ctx = elemTy.getContext(); - Type structTy = LLVM::LLVMStructType::getLiteral( - ctx, SmallVector(elems.size(), elemTy)); - auto result = packLLElements(loc, typeConverter, elems, rewriter, structTy); - return result; -} -} // namespace SharedToDotOperandMMAv3 - namespace { using namespace mlir; @@ -116,8 +96,9 @@ struct LocalLoadOpConversion if (mmaLayout.isHopper()) { // tensor core v3 assert(dotOperandLayout.getOpIdx() == 0); - res = SharedToDotOperandMMAv3::convertLayout(rewriter, loc, src, - dotOperandLayout, smemObj, typeConverter, getThreadId(rewriter, loc)); + res = SharedToDotOperandMMAv2::convertLayout( + 0, rewriter, loc, src, dotOperandLayout, + smemObj, typeConverter, getThreadId(rewriter, loc)); } else if (mmaLayout.isAmpere()) { // tensor core v2 res = SharedToDotOperandMMAv2::convertLayout( dotOperandLayout.getOpIdx(), rewriter, loc, src, dotOperandLayout, diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp index bf033bdd5322..ea9efec56717 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp @@ -25,6 +25,7 @@ class MMA16816SmemLoader { ArrayRef tileShape, ArrayRef instrShape, ArrayRef matShape, SmallVector multiDimWarpId, int perPhase, int maxPhase, int elemBytes, + int mmaElemBytes, bool isHopper, ConversionPatternRewriter &rewriter, const LLVMTypeConverter *typeConverter, const Location &loc); @@ -67,6 +68,9 @@ class MMA16816SmemLoader { int perPhase; int maxPhase; int elemBytes; + int mmaElemBytes; + bool isHopper; + bool isHopperWidthChange; ConversionPatternRewriter &rewriter; const Location &loc; MLIRContext *ctx{}; @@ -203,10 +207,10 @@ MMA16816SmemLoader::computeLdmatrixMatOffs(Value lane, Value cSwizzleOffset) { // vecWidth // <-------> // *#t0 ... *#t0 t1 ... t1 t2 ... t2 t3 ... t3 || *t0 ... *t0 t1 ... t1 t2 ... t2 t3 ... t3 /|\ -// t4 ... t4 t5 ... t5 t6 ... t6 t7 ... t7 || t4 ... t4 t5 ... t5 t6 ... t6 t7 ... t7 | -// t8 ... t8 t9 ... t9 t10 .. t10 t11 .. t11 || t8 ... t8 t9 ... t9 t10 .. t10 t11 .. t11 | quad height -// ... | -// t28 .. t28 t29 .. t29 t30 .. t30 t31 .. t31 || t28 .. t28 t29 .. t29 t30 .. t30 t31 .. t31 \|/ +// t4 ... t4 t5 ... t5 t6 ... t6 t7 ... t7 || t4 ... t4 t5 ... t5 t6 ... 
t6 t7 ... t7 | +// t8 ... t8 t9 ... t9 t10 .. t10 t11 .. t11 || t8 ... t8 t9 ... t9 t10 .. t10 t11 .. t11 | quad height +// ... | +// t28 ... t28 t29 .. t29 t30 .. t30 t31 .. t31 || t28 .. t28 t29 .. t29 t30 .. t30 t31 .. t31 \|/ // --------------------------------------------- || -------------------------------------------- // *#t0 ... *#t0 t1 ... t1 t2 ... t2 t3 ... t3 || t0 ... t0 t1 ... t1 t2 ... t2 t3 ... t3 // t4 ... t4 t5 ... t5 t6 ... t6 t7 ... t7 || t4 ... t4 t5 ... t5 t6 ... t6 t7 ... t7 @@ -364,6 +368,7 @@ MMA16816SmemLoader::loadX4(int batch, int mat0, int mat1, ArrayRef ptrs, extract_val(elemTy, resV4, 2), extract_val(elemTy, resV4, 3)}; } else { // base pointers + // ptrs[k][...] holds `vec` pointers each for (quadK == k) std::array, 2> ptrs; for (int i = 0; i < vecWidth; i++) ptrs[0][i] = getPtr(ptrIdx + i); @@ -383,11 +388,13 @@ MMA16816SmemLoader::loadX4(int batch, int mat0, int mat1, ArrayRef ptrs, i0 = add(i0, mul(i32_val(batch * warpsPerCTA[0]), smemBatchOffset)); i1 = add(i1, mul(i32_val(batch * warpsPerCTA[0]), smemBatchOffset)); } + // ii[m] holds the offset for (quadM == m) std::array ii = {i0, i1}; // load 4 32-bit values from shared memory // (equivalent to ldmatrix.x4) SmallVector> vptrs(4, SmallVector(vecWidth)); + // i iterates the 2x2 quads, m-first for (int i = 0; i < 4; ++i) for (int j = 0; j < vecWidth; ++j) { vptrs[i][j] = gep(ptr_ty(ctx, 3), shemTy, ptrs[i / 2][j], ii[i % 2]); @@ -402,7 +409,9 @@ MMA16816SmemLoader::loadX4(int batch, int mat0, int mat1, ArrayRef ptrs, int canonWidth = (8 * elemBytes * inc) / canonBits; Type canonInt = int_ty(canonBits); std::array retElems; - retElems.fill(undef(vec_ty(canonInt, 32 / canonBits))); + // don't pack to 32b for Hopper + int vecSize = isHopper ? 1 : 32 / canonBits; + retElems.fill(undef(vec_ty(canonInt, vecSize))); for (int r = 0; r < 2; ++r) { for (int em = 0; em < 2 * vecWidth; em += inc) { int e = em % vecWidth; @@ -421,8 +430,11 @@ MMA16816SmemLoader::loadX4(int batch, int mat0, int mat1, ArrayRef ptrs, } if (isActualTrans) std::swap(retElems[1], retElems[2]); - return {bitcast(retElems[0], i32_ty), bitcast(retElems[1], i32_ty), - bitcast(retElems[2], i32_ty), bitcast(retElems[3], i32_ty)}; + + auto iTy = isHopper ? 
int_ty(8 * elemBytes * inc) : i32_ty; + + return {bitcast(retElems[0], iTy), bitcast(retElems[1], iTy), + bitcast(retElems[2], iTy), bitcast(retElems[3], iTy)}; } } @@ -432,7 +444,8 @@ MMA16816SmemLoader::MMA16816SmemLoader( ArrayRef smemStrides, ArrayRef tileShape, ArrayRef instrShape, ArrayRef matShape, SmallVector multiDimWarpId, int perPhase, int maxPhase, - int elemBytes, ConversionPatternRewriter &rewriter, + int elemBytes, int mmaElemBytes, bool isHopper, + ConversionPatternRewriter &rewriter, const LLVMTypeConverter *typeConverter, const Location &loc) : nPerWarp(nPerWarp), order(order.begin(), order.end()), warpsPerCTA(warpsPerCTA.begin(), warpsPerCTA.end()), kOrder(kOrder), @@ -441,17 +454,25 @@ MMA16816SmemLoader::MMA16816SmemLoader( matShape(matShape.begin(), matShape.end()), multiDimWarpId(multiDimWarpId.begin(), multiDimWarpId.end()), perPhase(perPhase), maxPhase(maxPhase), elemBytes(elemBytes), + mmaElemBytes(mmaElemBytes), isHopper(isHopper), rewriter(rewriter), loc(loc), ctx(rewriter.getContext()) { + isHopperWidthChange = isHopper && (mmaElemBytes != elemBytes); + contiguousMatShape = matShape[order[0]]; stridedMatShape = matShape[order[1]]; stridedSmemOffset = smemStrides[order[1]]; smemBatchOffset = smemStrides[order[2]]; - vecWidth = 4 / elemBytes; + if (isHopperWidthChange) { + vecWidth = 4 / mmaElemBytes; + } else { + vecWidth = 4 / elemBytes; + } // rule: k must be the fast-changing axis. needTrans = kOrder != order[0]; nonKOrder = (kOrder == 2) ? 1 : 2; canUseLdmatrix = elemBytes == 2 || (!needTrans); canUseLdmatrix = canUseLdmatrix && (kWidth == vecWidth); + canUseLdmatrix = canUseLdmatrix && !isHopperWidthChange; if (canUseLdmatrix) { // Each CTA, the warps is arranged as [1xwarpsPerTile] if not transposed, @@ -504,10 +525,27 @@ Type getSharedMemTy(Type argType) { llvm::report_fatal_error("mma16816 data type not supported"); } +std::vector unpackInt(const std::vector &inValues, Type elTy, + ConversionPatternRewriter &rewriter, Location loc, + const LLVMTypeConverter *typeConverter) { + const int inBitWidth = inValues[0].getType().getIntOrFloatBitWidth(); + std::vector outValues; + for (auto v : inValues) { + // cast i32 to appropriate eltType vector and extract elements + auto eltType = typeConverter->convertType(elTy); + auto vecType = vec_ty(eltType, inBitWidth / eltType.getIntOrFloatBitWidth()); + auto vec = bitcast(v, vecType); + for (int i = 0; i < inBitWidth / eltType.getIntOrFloatBitWidth(); i++) { + outValues.push_back(extract_element(vec, i32_val(i))); + } + } + return outValues; +} + Value composeValuesToDotOperandLayoutStruct( const ValueTable &vals, int batch, int n0, int n1, const LLVMTypeConverter *typeConverter, Location loc, - ConversionPatternRewriter &rewriter) { + ConversionPatternRewriter &rewriter, Type elTy, bool unpack) { std::vector elems; for (int b = 0; b < batch; ++b) for (int m = 0; m < n0; ++m) @@ -519,6 +557,10 @@ Value composeValuesToDotOperandLayoutStruct( } assert(!elems.empty()); + if (unpack) { + elems = unpackInt(elems, elTy, rewriter, loc, typeConverter); + } + Type elemTy = elems[0].getType(); MLIRContext *ctx = elemTy.getContext(); Type structTy = LLVM::LLVMStructType::getLiteral( @@ -544,18 +586,20 @@ getLoadMatrixFn(MemDescType descTy, const SharedMemoryObject &smemObj, const int maxPhase = sharedLayout.getMaxPhase(); const int vecPhase = sharedLayout.getVec(); const int elemBytes = descTy.getElementTypeBitWidth() / 8; + const int mmaElemBytes = 4 / kWidth; + const bool isHopper = mmaLayout.getVersionMajor() == 3; auto 
order = sharedLayout.getOrder(); int nPerWarp = std::max(shapePerCTA[2] / mmaLayout.getWarpsPerCTA()[2], 8); - // (a, b) is the coordinate. auto load = [=, &rewriter, &vals](int batch, int a, int b) { MMA16816SmemLoader loader( nPerWarp, warpsPerTile, sharedLayout.getOrder(), mmaLayout.getWarpsPerCTA(), kOrder, kWidth, smemObj.strides, shapePerCTA /*tileShape*/, instrShape, matShape, multiDimWarpId, - perPhase, maxPhase, elemBytes, rewriter, typeConverter, loc); + perPhase, maxPhase, elemBytes, mmaElemBytes, + isHopper, rewriter, typeConverter, loc); // Offset of a slice within the original tensor in shared memory Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]); SmallVector offs = loader.computeOffsets(lane, cSwizzleOffset); @@ -573,6 +617,7 @@ getLoadMatrixFn(MemDescType descTy, const SharedMemoryObject &smemObj, auto [ha0, ha1, ha2, ha3] = loader.loadX4( batch, (kOrder == 2) ? a : b /*mat0*/, (kOrder == 2) ? b : a /*mat1*/, ptrs, matTy, getSharedMemTy(eltTy)); + if (!isA) std::swap(ha1, ha2); // the following is incorrect @@ -595,16 +640,18 @@ Value loadArg(ConversionPatternRewriter &rewriter, Location loc, MemDescType descTy, DotOperandEncodingAttr encoding, const SharedMemoryObject &smemObj, const LLVMTypeConverter *typeConverter, Value thread, bool isA) { + auto mmaLayout = mlir::cast(encoding.getParent()); + bool isHopper = mmaLayout.getVersionMajor() == 3; auto shapePerCTA = getShapePerCTA(descTy); int bitwidth = descTy.getElementTypeBitWidth(); - auto mmaLayout = mlir::cast(encoding.getParent()); + int mmaBitwidth = isHopper ? 32 / encoding.getKWidth() : bitwidth; ValueTable vals; - int mmaInstrM = 16, mmaInstrN = 8, mmaInstrK = 4 * 64 / bitwidth; - int matShapeM = 8, matShapeN = 8, matShapeK = 2 * 64 / bitwidth; + int mmaInstrM = 16, mmaInstrN = 8, mmaInstrK = 4 * 64 / mmaBitwidth; + int matShapeM = 8, matShapeN = 8, matShapeK = 2 * 64 / mmaBitwidth; auto numRep = - mmaLayout.getMMAv2Rep(shapePerCTA, bitwidth, encoding.getOpIdx()); + mmaLayout.getMMAv2Rep(shapePerCTA, mmaBitwidth, encoding.getOpIdx()); int kWidth = encoding.getKWidth(); auto warpsPerCTA = mmaLayout.getWarpsPerCTA(); @@ -616,7 +663,6 @@ Value loadArg(ConversionPatternRewriter &rewriter, Location loc, delinearize(rewriter, loc, warp, warpsPerCTA, order); Value warpB = urem(multiDimWarpId[0], i32_val(shapePerCTA[0])); int warpsPerTile; - auto rank = shapePerCTA.size(); Value warpM = urem(multiDimWarpId[1], i32_val(shapePerCTA[1] / 16)); Value warpN = urem(multiDimWarpId[2], i32_val(shapePerCTA[2] / 8)); if (isA) @@ -652,7 +698,8 @@ Value loadArg(ConversionPatternRewriter &rewriter, Location loc, // Format the values to LLVM::Struct to passing to mma codegen. 
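  // On Hopper (isHopper), the loaded values are additionally unpacked from the
  // packed integers back into individual elements (see unpackInt above), so
  // hoisted elementwise ops and the WGMMA register-operand path consume plain
  // elements; on Ampere the values remain packed into i32s as before.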
return composeValuesToDotOperandLayoutStruct( - vals, numRepBatch, numRepOuter, numRepK, typeConverter, loc, rewriter); + vals, numRepBatch, numRepOuter, numRepK, typeConverter, loc, rewriter, + descTy.getElementType(), /*unpack=*/isHopper); } template diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp index 3c52fe77e74a..cf0ddc248dd1 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp @@ -36,14 +36,10 @@ class DecomposeLocalLoadToDotOperand op.getType().getEncoding()); MemDescType srcType = op.getSrc().getType(); auto sharedEncoding = cast(srcType.getEncoding()); - if (!dstDotOp) + if (!dstDotOp || !sharedEncoding.getHasLeadingOffset()) return failure(); - - auto parentEnc = cast(dstDotOp.getParent()) ; - if (!parentEnc || parentEnc.getVersionMajor() == 3 || !sharedEncoding.getHasLeadingOffset()) - return failure(); - RankedTensorType type = op.getType(); + auto parentEnc = dstDotOp.getParent(); int numWarps = triton::gpu::getNumWarpsPerCTA(parentEnc); int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp( op->getParentOfType()); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp index 1bb55373e046..4dbcfba29526 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp @@ -264,6 +264,30 @@ DotOpMmaV3SmemLoader loadB(const LLVMTypeConverter *typeConverter, // Return a vector of Value of the accumulator start at startIndex and pack the // values into 32bits in case the accumulator is fp16. +// +// `elements` contains all loaded register values for operand A. +// This consists of operand A for possibly multiple wgmma instructions. +// For each wgmma, each warp in a warp group feeds a single "warp matrix" +// Each warp matrix consists of 2x2 "quads". +// Each thread holds several elements in each quad. Right before a wgmma, +// the sum of bitwidth of +// the elements in each quad should add up to 32. +// +// These values are stored unrolled in `elements`. +// The ordering of dimensions is as follows: +// batch (only 1 batch for Hopper currently) +// matM (m-index of the "warp matrix") +// matK (k-index of the "warp matrix") +// quadM (m-index of the "quad" in the core matrix) +// quadK (k-index of the "quad" in the core matrix) +// vecIdx (index of the element in the quad; this is always along the k-dim) +// +// This ordering is decided when a tensor in DotOpEnc is lowered into llvm. +// For WGMMA this happens in both SharedToDotOperand and MMAToDotOperand. +// Thus, both lowerings must obey this above ordering for the below code to be correct. +// +// Additionally, note that WGMMA expects quadK ordered before quadM (i.e. +// iterate along m-dim first); see loadI and mmaI. llvm::SmallVector loadReg(ConversionPatternRewriter &rewriter, Location loc, const SmallVector &elements, @@ -281,20 +305,24 @@ llvm::SmallVector loadReg(ConversionPatternRewriter &rewriter, } Type elementType = elements[0].getType(); int numElemsPer32Bits = 32 / elementType.getIntOrFloatBitWidth(); + assert(numElements == 4 * numElemsPer32Bits); // For FP16 and BF16 we need to pack accumulator into 32-bit integers. 
- int num32BitValues = numElements / numElemsPer32Bits; - llvm::SmallVector mmaOut(num32BitValues); + llvm::SmallVector mmaOut(4); Type packTy = vec_ty(elementType, numElemsPer32Bits); - for (int i = 0; i < num32BitValues; ++i) { - Value pack = rewriter.create(loc, packTy); - for (int j = 0; j < numElemsPer32Bits; ++j) { - Value element = elements[startIndex + i * numElemsPer32Bits + j]; - pack = insert_element(packTy, pack, element, i32_val(j)); + for (int quadK = 0; quadK < 2; quadK++) + for (int quadM = 0; quadM < 2; quadM++) { + int loadI = quadM * 2 + quadK; + int mmaI = quadK * 2 + quadM; + Value pack = rewriter.create(loc, packTy); + for (int j = 0; j < numElemsPer32Bits; ++j) { + Value element = elements[startIndex + loadI * numElemsPer32Bits + j]; + pack = insert_element(packTy, pack, element, i32_val(j)); + } + pack = bitcast(pack, rewriter.getIntegerType(32)); + mmaOut[mmaI] = pack; } - pack = bitcast(pack, rewriter.getIntegerType(32)); - mmaOut[i] = pack; - } + return mmaOut; } From f4fe44bef0ef9883502604fb48326f652c2dcc62 Mon Sep 17 00:00:00 2001 From: Gary Geng Date: Fri, 6 Sep 2024 22:11:37 +0000 Subject: [PATCH 05/18] Fix test regressions --- .../Transforms/OptimizeDotOperands.cpp | 13 ++++++-- test/Conversion/tritongpu_to_llvm_hopper.mlir | 10 +++--- test/TritonGPU/dot-operands.mlir | 16 ++++----- test/TritonGPU/invalid-attributes.mlir | 14 +++++--- test/TritonGPU/loop-pipeline-hopper.mlir | 12 +++---- .../pipeline-hopper-remove-wait.mlir | 4 +-- .../SharedToDotOperandMMAv2.cpp | 33 ++++++++++++------- .../DotOpToLLVM/WGMMA.cpp | 31 ++++++++--------- 8 files changed, 77 insertions(+), 56 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index 54a3a23c2686..ee80ace7ec27 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -65,11 +65,17 @@ bool canHoistDotOpEncV3(Operation* op) { if (isa(op)) return false; + // Downcasting not currently supported; it will likely require minor + // adjustments in sharedToDotOperandMMv2 + auto oprType = getElementTypeOrSelf(op->getOperand(0)); + auto resType = getElementTypeOrSelf(op->getResult(0)); + if (oprType.getIntOrFloatBitWidth() > resType.getIntOrFloatBitWidth()) + return false; + // Don't hoist through u1 -> fp casts as they aren't supported in // ElementwiseOpToLLVM::reorderValues(). 
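  // For example (hypothetical ops, consistent with the checks in this
  // function): an i8 -> f16 arith.sitofp widens and may be hoisted; an
  // f32 -> f16 arith.truncf narrows and is rejected by the bitwidth check
  // above; an i1 -> f16 cast is rejected here.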
if (isa(op)) { - Type opType = getElementTypeOrSelf(op->getOperand(0)); - if (opType.isInteger(1)) + if (oprType.isInteger(1)) return false; } @@ -334,8 +340,9 @@ struct MMAV3UseRegOperand dstEnc.getVersionMajor() != 3) return failure(); auto srcTy = cast(alloc.getSrc().getType()); + auto kWidth = 32 / srcTy.getElementTypeBitWidth(); auto dotOperandEnc = DotOperandEncodingAttr::get( - dotOp.getContext(), /*opIdx=*/0, srcEnc, /*kWidth=*/0); + dotOp.getContext(), /*opIdx=*/0, srcEnc, kWidth); auto newTy = RankedTensorType::get(srcTy.getShape(), srcTy.getElementType(), dotOperandEnc); if (!isMmaToDotShortcut(srcTy, newTy)) diff --git a/test/Conversion/tritongpu_to_llvm_hopper.mlir b/test/Conversion/tritongpu_to_llvm_hopper.mlir index d44529966274..12bf0d8f4d43 100644 --- a/test/Conversion/tritongpu_to_llvm_hopper.mlir +++ b/test/Conversion/tritongpu_to_llvm_hopper.mlir @@ -97,9 +97,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: nvgpu.wgmma_wait_group %{{.*}} {pendings = 0 : i32} : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> tt.func @dot_reg_operand_A(%a: tensor<128x64xf16, #mma>, %b: !tt.memdesc<64x64xf16, #shared>) { %cst = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> - %opA = triton_gpu.convert_layout %a : tensor<128x64xf16, #mma> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> + %opA = triton_gpu.convert_layout %a : tensor<128x64xf16, #mma> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> %m = triton_nvidia_gpu.warp_group_dot %opA, %b, %cst { inputPrecision = 0 : i32 }: - tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * !tt.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma> + tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !tt.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma> tt.return } } @@ -114,10 +114,10 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // Generate a wgmma where the first operand is a struct. 
// CHECK: nvgpu.wgmma {{.*}} : (!llvm.struct<(i32, i32, i32, i32)>, i64, i1) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> // CHECK: nvgpu.wgmma_wait_group %{{.*}} {pendings = 0 : i32} - tt.func @dot_reg_operand_A_fp8(%a: tensor<128x128xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>>, %b: !tt.memdesc<128x256xf8E5M2, #shared>) { + tt.func @dot_reg_operand_A_fp8(%a: tensor<128x128xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>, %b: !tt.memdesc<128x256xf8E5M2, #shared>) { %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma1> %m = triton_nvidia_gpu.warp_group_dot %a, %b, %cst { maxNumImpreciseAcc = 1073741824 : i32, inputPrecision = 0 : i32 } : - tensor<128x128xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * !tt.memdesc<128x256xf8E5M2, #shared> -> tensor<128x256xf32, #mma1> + tensor<128x128xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * !tt.memdesc<128x256xf8E5M2, #shared> -> tensor<128x256xf32, #mma1> tt.return } } @@ -193,7 +193,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: prmt.b32 // CHECK: prmt.b32 tt.func @cvt_mma_to_dot_fp8(%a: tensor<128x64xf8E5M2, #mma>) { - %opA = triton_gpu.convert_layout %a : tensor<128x64xf8E5M2, #mma> -> tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> + %opA = triton_gpu.convert_layout %a : tensor<128x64xf8E5M2, #mma> -> tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> tt.return } } diff --git a/test/TritonGPU/dot-operands.mlir b/test/TritonGPU/dot-operands.mlir index 5fc02aa5e3d9..070e1a556d7f 100644 --- a/test/TritonGPU/dot-operands.mlir +++ b/test/TritonGPU/dot-operands.mlir @@ -164,8 +164,8 @@ tt.func @update_kwidth_slice( #shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { // CHECK: tt.func @mma_v3_reg_operand_A -// CHECK: %[[A:.+]] = triton_gpu.convert_layout %{{.*}} : tensor<128x64xf16, #mma> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> -// CHECK: triton_nvidia_gpu.warp_group_dot %[[A]], {{.*}} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * !tt.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma> +// CHECK: %[[A:.+]] = triton_gpu.convert_layout %{{.*}} : tensor<128x64xf16, #mma> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> +// CHECK: triton_nvidia_gpu.warp_group_dot %[[A]], {{.*}} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !tt.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma> tt.func @mma_v3_reg_operand_A(%arg0: tensor<128x64xf16, #mma>, %arg1: !tt.memdesc<64x64xf16, #shared>, %arg2: tensor<128x64xf32, #mma>) -> 
tensor<128x64xf32, #mma>{ %A = triton_gpu.local_alloc %arg0 : (tensor<128x64xf16, #mma>) -> !tt.memdesc<128x64xf16, #shared1> %r = triton_nvidia_gpu.warp_group_dot %A, %arg1, %arg2 : !tt.memdesc<128x64xf16, #shared1> * !tt.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma> @@ -180,8 +180,8 @@ tt.func @mma_v3_reg_operand_A(%arg0: tensor<128x64xf16, #mma>, %arg1: !tt.memdes #shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { // CHECK: tt.func @mma_v3_reg_operand_A_fp8 -// CHECK: %[[A:.+]] = triton_gpu.convert_layout %{{.*}} : tensor<128x64xf8E5M2, #mma> -> tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> -// CHECK: triton_nvidia_gpu.warp_group_dot %[[A]], {{.*}} : tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * !tt.memdesc<64x64xf8E5M2, #shared> -> tensor<128x64xf32, #mma> +// CHECK: %[[A:.+]] = triton_gpu.convert_layout %{{.*}} : tensor<128x64xf8E5M2, #mma> -> tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> +// CHECK: triton_nvidia_gpu.warp_group_dot %[[A]], {{.*}} : tensor<128x64xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * !tt.memdesc<64x64xf8E5M2, #shared> -> tensor<128x64xf32, #mma> tt.func @mma_v3_reg_operand_A_fp8(%arg0: tensor<128x64xf8E5M2, #mma>, %arg1: !tt.memdesc<64x64xf8E5M2, #shared>, %arg2: tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma>{ %A = triton_gpu.local_alloc %arg0 : (tensor<128x64xf8E5M2, #mma>) -> !tt.memdesc<128x64xf8E5M2, #shared1> %r = triton_nvidia_gpu.warp_group_dot %A, %arg1, %arg2 : !tt.memdesc<128x64xf8E5M2, #shared1> * !tt.memdesc<64x64xf8E5M2, #shared> -> tensor<128x64xf32, #mma> @@ -222,10 +222,10 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { // CHECK: tt.func @mma_v3_reg_push_elementwise // CHECK: %[[A_LOADED:.*]] = tt.load %{{.*}} : tensor<128x64x!tt.ptr, #blocked> -// CHECK: %[[A_MEMDESC:.*]] = triton_gpu.local_alloc %[[A_LOADED]] : (tensor<128x64xbf16, #blocked>) -> !tt.memdesc<128x64xbf16, #shared> -// CHECK: %[[A_REG:.*]] = triton_gpu.local_load %[[A_MEMDESC]] : !tt.memdesc<128x64xbf16, #shared> -> tensor<128x64xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> -// CHECK: %[[A_CASTED:.*]] = tt.fp_to_fp %[[A_REG]] : tensor<128x64xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> -// CHECK: %[[R:.*]] = triton_nvidia_gpu.warp_group_dot %[[A_CASTED]], %{{.*}}, %{{.*}} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * !tt.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma> +// CHECK: %[[A_MEMDESC:.*]] = triton_gpu.local_alloc %[[A_LOADED]] : (tensor<128x64xbf16, #blocked>) -> !tt.memdesc<128x64xbf16, #shared1> +// CHECK: %[[A_REG:.*]] = triton_gpu.local_load %[[A_MEMDESC]] : !tt.memdesc<128x64xbf16, #shared1> -> tensor<128x64xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> +// CHECK: %[[A_CASTED:.*]] = tt.fp_to_fp %[[A_REG]] : tensor<128x64xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> +// CHECK: %[[R:.*]] = 
triton_nvidia_gpu.warp_group_dot %[[A_CASTED]], %{{.*}}, %{{.*}} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !tt.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma> tt.func @mma_v3_reg_push_elementwise(%pa: tensor<128x64x!tt.ptr, #blocked>, %dotb: !tt.memdesc<64x64xf16, #shared>, %dotc: tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma>{ %a_bf16 = tt.load %pa : tensor<128x64x!tt.ptr, #blocked> %a = tt.fp_to_fp %a_bf16 : tensor<128x64xbf16, #blocked> -> tensor<128x64xf16, #blocked> diff --git a/test/TritonGPU/invalid-attributes.mlir b/test/TritonGPU/invalid-attributes.mlir index c8b3c2ef6b0b..78e9140afd1b 100644 --- a/test/TritonGPU/invalid-attributes.mlir +++ b/test/TritonGPU/invalid-attributes.mlir @@ -2,7 +2,7 @@ // expected-error@+2 {{triton_gpu.dot_op opIdx paramenter can be 0 or 1, got: 2}} #blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}> -#dot_op = #triton_gpu.dot_op<{opIdx = 2, parent = #blocked}> +#dot_op = #triton_gpu.dot_op<{opIdx = 2, parent = #blocked, kWidth = 2}> // ----- @@ -12,19 +12,25 @@ // ----- -// expected-error@+2 {{triton_gpu.dot_op kWidth parameter can only be non-zero for Ampere MMA parent}} +// expected-error@+2 {{triton_gpu.dot_op kWidth parameter can only be non-zero for Ampere or Hopper MMA parent}} #mma = #triton_gpu.nvidia_mma<{versionMajor = 1, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> #dot_op = #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}> // ----- -// expected-error@+2 {{triton_gpu.dot_op kWidth parameter is mandatory for Ampere MMA parent}} +// expected-error@+2 {{triton_gpu.dot_op kWidth parameter is mandatory for Ampere or Hopper MMA parent}} #mma = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> #dot_op = #triton_gpu.dot_op<{opIdx = 0, parent = #mma}> // ----- -// expected-error@+2 {{triton_gpu.dot_op kWidth parameter can only be non-zero for Ampere MMA parent}} +// expected-error@+2 {{triton_gpu.dot_op kWidth parameter is mandatory for Ampere or Hopper MMA parent}} +#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> +#dot_op = #triton_gpu.dot_op<{opIdx = 0, parent = #mma}> + +// ----- + +// expected-error@+2 {{triton_gpu.dot_op opIdx parameter must be 0 for Hopper MMA parent}} #mma = #triton_gpu.nvidia_mma<{versionMajor = 3, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> #dot_op = #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}> diff --git a/test/TritonGPU/loop-pipeline-hopper.mlir b/test/TritonGPU/loop-pipeline-hopper.mlir index d391be688c23..2c2182154d6a 100644 --- a/test/TritonGPU/loop-pipeline-hopper.mlir +++ b/test/TritonGPU/loop-pipeline-hopper.mlir @@ -398,8 +398,8 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : %21 = triton_nvidia_gpu.warp_group_dot %19, %20, %cst_2 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> %22 = arith.truncf %21 : tensor<128x16xf32, #mma1> to tensor<128x16xf16, #mma1> %23 = tt.trans %20 {order=array} : !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> !tt.memdesc<16x64xf16, #shared, 
#triton_gpu.shared_memory> - %24 = triton_gpu.convert_layout %22 : tensor<128x16xf16, #mma1> -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> - %25 = triton_nvidia_gpu.warp_group_dot %24, %23, %arg4 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * !tt.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf32, #mma> + %24 = triton_gpu.convert_layout %22 : tensor<128x16xf16, #mma1> -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 2}>> + %25 = triton_nvidia_gpu.warp_group_dot %24, %23, %arg4 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 2}>> * !tt.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf32, #mma> %26 = tt.addptr %arg5, %cst : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> scf.yield %25, %26 : tensor<128x64xf32, #mma>, tensor<64x16x!tt.ptr, #blocked> } @@ -481,7 +481,7 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : %c0_i64 = arith.constant 0 : i64 %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> %cst_3 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> - %cst_4 = arith.constant dense<1.000000e+00> : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> + %cst_4 = arith.constant dense<1.000000e+00> : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 2}>> %c1_i32 = arith.constant 1 : i32 %c8_i32 = arith.constant 8 : i32 %0 = tt.addptr %arg0, %c0_i64 : !tt.ptr, i64 @@ -519,7 +519,7 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : %l = tt.load %arg5 : tensor<64x16x!tt.ptr, #blocked> %c = triton_gpu.local_alloc %l : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> %23 = tt.trans %c {order=array} : !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> !tt.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> - %25 = triton_nvidia_gpu.warp_group_dot %cst_4, %23, %arg4 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * !tt.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf32, #mma> + %25 = triton_nvidia_gpu.warp_group_dot %cst_4, %23, %arg4 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 2}>> * !tt.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf32, #mma> %26 = tt.addptr %arg5, %cst : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> scf.yield %25, %26, %21 : tensor<128x64xf32, #mma>, tensor<64x16x!tt.ptr, #blocked>, tensor<128x16xf32, #mma1> } @@ -624,7 +624,7 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : %c0_i64 = arith.constant 0 : i64 %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> %cst_3 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> - %cst_4 = arith.constant dense<1.000000e+00> : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> + %cst_4 = arith.constant dense<1.000000e+00> : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 2}>> %c1_i32 = arith.constant 1 : i32 %c8_i32 = arith.constant 8 : i32 @@ -685,7 +685,7 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : // This dot can be async even though %prev_dot2 is not used directly by an // async dot, because that use follows the synchronous dot above. 
%prev_dot2.1 = arith.addf %prev_dot2, %prev_dot2 : tensor<128x64xf32, #mma> - %dot2 = triton_nvidia_gpu.warp_group_dot %cst_4, %23, %prev_dot2.1 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * !tt.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf32, #mma> + %dot2 = triton_nvidia_gpu.warp_group_dot %cst_4, %23, %prev_dot2.1 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 2}>> * !tt.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf32, #mma> %26 = tt.addptr %arg5, %cst : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> scf.yield %dot2, %26, %dot1.1, %dot0 : tensor<128x64xf32, #mma>, tensor<64x16x!tt.ptr, #blocked>, tensor<128x16xf32, #mma1>, tensor<128x16xf32, #mma1> } diff --git a/test/TritonGPU/pipeline-hopper-remove-wait.mlir b/test/TritonGPU/pipeline-hopper-remove-wait.mlir index 74fd2e05551b..a7064ea82204 100644 --- a/test/TritonGPU/pipeline-hopper-remove-wait.mlir +++ b/test/TritonGPU/pipeline-hopper-remove-wait.mlir @@ -113,7 +113,7 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : %115 = triton_nvidia_gpu.warp_group_dot %113, %114, %cst :!tt.memdesc<128x128xf16, #shared> * !tt.memdesc<128x64xf16, #shared1> -> tensor<128x64xf32, #mma> %116 = arith.truncf %115 : tensor<128x64xf32, #mma> to tensor<128x64xf16, #mma> %117 = triton_gpu.local_alloc %112 : (tensor<64x128xf16, #blocked>) -> !tt.memdesc<64x128xf16, #shared> - %118 = triton_gpu.convert_layout %116 : tensor<128x64xf16, #mma> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> + %118 = triton_gpu.convert_layout %116 : tensor<128x64xf16, #mma> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> // The first dot gets converted to dot-async + wait. The second one // doesn't have a wait because the first wait is sufficient. 
// CHECK: triton_nvidia_gpu.warp_group_dot @@ -121,7 +121,7 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : // CHECK: triton_nvidia_gpu.warp_group_dot // CHECK-NOT: triton_nvidia_gpu.warp_group_dot_wait // CHECK: scf.yield - %119 = triton_nvidia_gpu.warp_group_dot %118, %117, %arg23 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * !tt.memdesc<64x128xf16, #shared> -> tensor<128x128xf32, #mma1> + %119 = triton_nvidia_gpu.warp_group_dot %118, %117, %arg23 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !tt.memdesc<64x128xf16, #shared> -> tensor<128x128xf32, #mma1> %120 = arith.mulf %arg24, %arg25 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> %121 = arith.addf %120, %arg25 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> %122 = arith.extsi %c0_i32 : i32 to i64 diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp index ea9efec56717..e3ee42ab2de5 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp @@ -70,7 +70,6 @@ class MMA16816SmemLoader { int elemBytes; int mmaElemBytes; bool isHopper; - bool isHopperWidthChange; ConversionPatternRewriter &rewriter; const Location &loc; MLIRContext *ctx{}; @@ -456,7 +455,11 @@ MMA16816SmemLoader::MMA16816SmemLoader( perPhase(perPhase), maxPhase(maxPhase), elemBytes(elemBytes), mmaElemBytes(mmaElemBytes), isHopper(isHopper), rewriter(rewriter), loc(loc), ctx(rewriter.getContext()) { - isHopperWidthChange = isHopper && (mmaElemBytes != elemBytes); + // If the current elemType width is different from the MMA elemType width, i.e. + // width-changing casting is done later in DotOp Layout... then, in the case of + // Hopper, the number of bytes held by each thread after loading will no longer + // be 32B. Hence this flag is required to stipulate different logic. 
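  // Worked example (hypothetical values, consistent with getLoadMatrixFn
  // further down in this file, which passes mmaElemBytes = 4 / kWidth): if an
  // i8 tensor is loaded from shared memory and only upcast to f16 once in
  // DotOp layout, then elemBytes = 1 while kWidth = 2 gives mmaElemBytes = 2,
  // so isHopperWidthChange is true; vecWidth becomes 4 / mmaElemBytes = 2
  // rather than 4 / elemBytes = 4, and canUseLdmatrix is disabled.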
+ bool isHopperWidthChange = isHopper && (mmaElemBytes != elemBytes); contiguousMatShape = matShape[order[0]]; stridedMatShape = matShape[order[1]]; @@ -545,19 +548,27 @@ std::vector unpackInt(const std::vector &inValues, Type elTy, Value composeValuesToDotOperandLayoutStruct( const ValueTable &vals, int batch, int n0, int n1, const LLVMTypeConverter *typeConverter, Location loc, - ConversionPatternRewriter &rewriter, Type elTy, bool unpack) { + ConversionPatternRewriter &rewriter, Type elTy, bool isHopper) { std::vector elems; for (int b = 0; b < batch; ++b) for (int m = 0; m < n0; ++m) - for (int k = 0; k < n1; ++k) { - elems.push_back(vals.at({b, 2 * m, 2 * k})); - elems.push_back(vals.at({b, 2 * m, 2 * k + 1})); - elems.push_back(vals.at({b, 2 * m + 1, 2 * k})); - elems.push_back(vals.at({b, 2 * m + 1, 2 * k + 1})); - } + for (int k = 0; k < n1; ++k) + if (isHopper) { + // Hopper expects opposite ordering + elems.push_back(vals.at({b, 2 * m, 2 * k})); + elems.push_back(vals.at({b, 2 * m + 1, 2 * k})); + elems.push_back(vals.at({b, 2 * m, 2 * k + 1})); + elems.push_back(vals.at({b, 2 * m + 1, 2 * k + 1})); + } else { + elems.push_back(vals.at({b, 2 * m, 2 * k})); + elems.push_back(vals.at({b, 2 * m, 2 * k + 1})); + elems.push_back(vals.at({b, 2 * m + 1, 2 * k})); + elems.push_back(vals.at({b, 2 * m + 1, 2 * k + 1})); + } + assert(!elems.empty()); - if (unpack) { + if (isHopper) { elems = unpackInt(elems, elTy, rewriter, loc, typeConverter); } @@ -644,7 +655,7 @@ Value loadArg(ConversionPatternRewriter &rewriter, Location loc, bool isHopper = mmaLayout.getVersionMajor() == 3; auto shapePerCTA = getShapePerCTA(descTy); int bitwidth = descTy.getElementTypeBitWidth(); - int mmaBitwidth = isHopper ? 32 / encoding.getKWidth() : bitwidth; + int mmaBitwidth = isHopper ? (32 / encoding.getKWidth()) : bitwidth; ValueTable vals; int mmaInstrM = 16, mmaInstrN = 8, mmaInstrK = 4 * 64 / mmaBitwidth; diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp index 4dbcfba29526..995a510131b7 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp @@ -278,16 +278,17 @@ DotOpMmaV3SmemLoader loadB(const LLVMTypeConverter *typeConverter, // batch (only 1 batch for Hopper currently) // matM (m-index of the "warp matrix") // matK (k-index of the "warp matrix") -// quadM (m-index of the "quad" in the core matrix) // quadK (k-index of the "quad" in the core matrix) +// quadM (m-index of the "quad" in the core matrix) // vecIdx (index of the element in the quad; this is always along the k-dim) // // This ordering is decided when a tensor in DotOpEnc is lowered into llvm. // For WGMMA this happens in both SharedToDotOperand and MMAToDotOperand. // Thus, both lowerings must obey this above ordering for the below code to be correct. // -// Additionally, note that WGMMA expects quadK ordered before quadM (i.e. -// iterate along m-dim first); see loadI and mmaI. +// Additionally, note that WGMMA expects quadK ordered before quadM, i.e. the layout +// is quadM-major. This is opposite to Ampere's ordering for ldmatrix and dotOp. 
+// (see SharedToDotOperandMMAv2.cpp) llvm::SmallVector loadReg(ConversionPatternRewriter &rewriter, Location loc, const SmallVector &elements, @@ -305,24 +306,20 @@ llvm::SmallVector loadReg(ConversionPatternRewriter &rewriter, } Type elementType = elements[0].getType(); int numElemsPer32Bits = 32 / elementType.getIntOrFloatBitWidth(); - assert(numElements == 4 * numElemsPer32Bits); // For FP16 and BF16 we need to pack accumulator into 32-bit integers. - llvm::SmallVector mmaOut(4); + int num32BitValues = numElements / numElemsPer32Bits; + llvm::SmallVector mmaOut(num32BitValues); Type packTy = vec_ty(elementType, numElemsPer32Bits); - for (int quadK = 0; quadK < 2; quadK++) - for (int quadM = 0; quadM < 2; quadM++) { - int loadI = quadM * 2 + quadK; - int mmaI = quadK * 2 + quadM; - Value pack = rewriter.create(loc, packTy); - for (int j = 0; j < numElemsPer32Bits; ++j) { - Value element = elements[startIndex + loadI * numElemsPer32Bits + j]; - pack = insert_element(packTy, pack, element, i32_val(j)); - } - pack = bitcast(pack, rewriter.getIntegerType(32)); - mmaOut[mmaI] = pack; + for (int i = 0; i < num32BitValues; ++i) { + Value pack = rewriter.create(loc, packTy); + for (int j = 0; j < numElemsPer32Bits; ++j) { + Value element = elements[startIndex + i * numElemsPer32Bits + j]; + pack = insert_element(packTy, pack, element, i32_val(j)); } - + pack = bitcast(pack, rewriter.getIntegerType(32)); + mmaOut[i] = pack; + } return mmaOut; } From da07d160e6457c02007ef309be9cab35e0beae9d Mon Sep 17 00:00:00 2001 From: Gary Geng Date: Mon, 9 Sep 2024 20:31:20 +0000 Subject: [PATCH 06/18] Rewrite OptimizeDotOperands logic and add tests --- .../Transforms/OptimizeDotOperands.cpp | 196 ++++++++++-------- test/TritonGPU/dot-operands.mlir | 34 ++- 2 files changed, 133 insertions(+), 97 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index ee80ace7ec27..88bde1a8d5a9 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -1,3 +1,4 @@ +#include "mlir/IR/IRMapping.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Pass/PassManager.h" #include "mlir/Support/LogicalResult.h" @@ -82,6 +83,31 @@ bool canHoistDotOpEncV3(Operation* op) { return true; } +auto cloneSlice(PatternRewriter& rewriter, const SetVector& slice) { + IRMapping sliceMap; + SetVector newSlice; + for (Operation *op : slice) { + auto newOp = rewriter.clone(*op); + newSlice.insert(newOp); + sliceMap.map(op, newOp); + for (auto [result, newResult] : llvm::zip(op->getResults(), newOp->getResults())) { + assert(result != newResult); + sliceMap.map(result, newResult); + } + } + + for (auto [op, newOp] : sliceMap.getOperationMap()) + for (auto [oprIdx, operand] : llvm::enumerate(newOp->getOperands())) { + auto defOp = operand.getDefiningOp(); + if (!slice.contains(defOp)) + continue; + + newOp->setOperand(oprIdx, sliceMap.lookup(operand)); + } + + return std::make_tuple(newSlice, sliceMap); +} + // Given // convert(trans(src)) #dot_operand -> // convert(local_load(trans(alloc(src)))) @@ -355,13 +381,21 @@ struct MMAV3UseRegOperand } }; -// TODO(ggengnv) more tests (multiple elt-wise ops) and document +// MMAV3's analog of HoistLayoutConversion, for operand A only; will make WarpGroupDot +// accept operand A in registers instead of shmem. +// +// local_alloc(elementwise(x)) -> +// elementwise(convert(x, #dot_operand)). 
+// +// Whereas (MMAV2) HoistLayoutConversion hoists thru one op at a time and requires +// multiple passes will directly hoist the convert to the right place in one pass. struct MMAV3HoistLayoutConversion : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(triton::nvidia_gpu::WarpGroupDotOp dotOp, PatternRewriter &rewriter) const override { + // Can only hoist operand 0 auto alloc = dotOp.getOperand(0).getDefiningOp(); if (!alloc || !alloc.getSrc()) return failure(); @@ -374,138 +408,116 @@ struct MMAV3HoistLayoutConversion return failure(); // Performs checks for early stop - NvidiaMmaEncodingAttr dstEnc; Type inputEltTy; - { - auto srcEnc = dyn_cast(getEncoding(alloc.getSrc())); - dstEnc = - dyn_cast(getEncoding(dotOp.getResult())); - // Want: A's Encoding to be Blocked and D's encoding to be NvidiaMmA v3 - if (!srcEnc || !dstEnc || dstEnc.getVersionMajor() != 3) - return failure(); - - auto src = alloc.getSrc().getDefiningOp(); - - // Value passed to alloc must have Tensor arguments and single Tensor result - if (!src || src->getNumOperands() == 0 || src->getNumResults() != 1) - return failure(); - if (!all_of(src->getOperandTypes(), - [](Type ty) { return isa(ty); })) - return failure(); - auto srcTy = dyn_cast(src->getResult(0).getType()); - if (!srcTy) - return failure(); - inputEltTy = srcTy.getElementType(); + auto srcEnc = dyn_cast(getEncoding(alloc.getSrc())); + auto dstEnc = + dyn_cast(getEncoding(dotOp.getResult())); + // Want: A's Encoding to be Blocked and D's encoding to be NvidiaMmaV3 + if (!srcEnc || !dstEnc || dstEnc.getVersionMajor() != 3) + return failure(); + auto src = alloc.getSrc().getDefiningOp(); + // Value passed to alloc must have Tensor arguments and single Tensor result + if (!src || src->getNumOperands() == 0 || src->getNumResults() != 1) + return failure(); + if (!all_of(src->getOperandTypes(), + [](Type ty) { return isa(ty); })) + return failure(); + auto srcTy = dyn_cast(src->getResult(0).getType()); + if (!srcTy) + return failure(); + inputEltTy = srcTy.getElementType(); - if (!canHoistDotOpEncV3(src)) - return failure(); - } + // Check src itself can be hoisted through + if (!canHoistDotOpEncV3(src)) + return failure(); + // Obtain backward slice SetVector slice; BackwardSliceOptions opt; opt.omitBlockArguments = true; opt.filter = [&](Operation *op) { - return (op->getParentRegion() == alloc->getParentRegion()) && !isa(op) - && (op->getNumOperands() != 0); // Ensures all ops in slice have operands + // Stop before Load, ConstantOp, or LocalLoad (which is unlikely) + return (op->getParentRegion() == alloc->getParentRegion()) + && !isa(op) + && (op->getNumOperands() != 0); }; - getBackwardSlice(alloc.getOperation(), &slice, opt); + // Verify slice can be hoisted through + if (slice.empty()) + return failure(); + auto isBlockedRankedTensor = [&](auto val) { - return isa(getEncoding(val)) && isa(val.getType()); + return isa(getEncoding(val)) && + isa(val.getType()); }; - SmallVector frontierOps; for (Operation *currOp : slice) { if (!canHoistDotOpEncV3(currOp)) return failure(); // We previously ensured that all ops in slice have at least one operand - bool isFrontier = false; for (auto operand : currOp->getOperands()) { - auto op = operand.getDefiningOp(); - if (!slice.contains(op)) { - if (!isa(op)) + auto defOp = operand.getDefiningOp(); + if (!slice.contains(defOp)) { + if (!isa(defOp)) { return failure(); + } - isFrontier = true; - } - } - - if (isFrontier) { - auto res = currOp->getResult(0); - if 
(!isBlockedRankedTensor(res)) - return failure(); - - if (!llvm::all_of(currOp->getOperands(), isBlockedRankedTensor)) - return failure(); + auto res = currOp->getResult(0); + if (!isBlockedRankedTensor(res)) { + return failure(); + } - frontierOps.push_back(currOp); + if (!llvm::all_of(currOp->getOperands(), isBlockedRankedTensor)) { + return failure(); + } + } } } - // Nothing to hoist through - if (frontierOps.empty()) - return failure(); + auto [newSlice, sliceMap] = cloneSlice(rewriter, slice); - // convert A operand + // For each frontierOp (i.e. op whose defining op is not in slice): + // load/constant; frontierOp; [hoistableElementwiseOps...]; local_alloc; warp_group_dot + // -> load/constant; convert_layout; frontierOp; [hoistableOps...]; warp_group_dot auto dotOperandEnc = DotOperandEncodingAttr::get( dotOp.getContext(), /*opIdx=*/0, dstEnc, inputEltTy); + for (auto op : newSlice) { + // Convert operands + for (auto [oprIdx, operand] : llvm::enumerate(op->getOperands())) { + auto defOp = operand.getDefiningOp(); + + // Not frontier; no need to convert operand + if (newSlice.contains(defOp)) { + op->moveAfter(defOp); + continue; + } - // For each frontierOp: - // load; frontierOp; [hoistableOps...]; local_alloc; warp_group_dot - // -> load; local_alloc; local_load; convert_layout; frontierOp; [hoistableOps...]; warp_group_dot - // or... - // constant; frontierOp; [hoistableOps...]; warp_group_dot - // -> constant; convert_layout; frontierOp; [hoistableOps...]; warp_group_dot - for (Operation *frontierOp : frontierOps) { - auto frontierTy = dyn_cast(frontierOp->getResult(0).getType()); - - SmallVector newOperands; - for (auto operand : frontierOp->getOperands()) { // We checked earlier that all operands are ranked tensors. auto operandTy = cast(operand.getType()); auto operandEltTy = operandTy.getElementType(); - ConvertLayoutOp cvt; - Type cvtTy = RankedTensorType::get( operandTy.getShape(), operandTy.getElementType(), dotOperandEnc); + auto cvt = rewriter.create(defOp->getLoc(), cvtTy, operand); - if (isa(operand.getDefiningOp())) { - auto oldAllocTy = alloc.getType(); - auto oldAllocEnc = cast(oldAllocTy.getEncoding()); - - auto newAllocEnc = SharedEncodingAttr::get( - oldAllocEnc.getContext(), dotOperandEnc, operandTy.getShape(), - getOrder(operandTy.getEncoding()), - getCTALayout(operandTy.getEncoding()), - operandTy.getElementType().getIntOrFloatBitWidth(), /*needTrans=*/false); - - auto newAllocTy = MemDescType::get(operandTy.getShape(), operandEltTy, - newAllocEnc, oldAllocTy.getMemorySpace()); - auto localAlloc = rewriter.create(alloc.getLoc(), newAllocTy, operand); - auto localLoad = rewriter.create(alloc.getLoc(), operandTy, localAlloc); - cvt = rewriter.create(alloc.getLoc(), cvtTy, localLoad); - } else { - assert(isa(operand.getDefiningOp())); - cvt = rewriter.create(alloc.getLoc(), cvtTy, operand); - } - - newOperands.push_back(cvt); + op->setOperand(oprIdx, cvt); + op->moveAfter(cvt); } - auto newFrontier = rewriter.clone(*frontierOp); - for (int i = 0; i < newOperands.size(); i++) - newFrontier->setOperand(i, newOperands[i]); - newFrontier->getResult(0).setType(RankedTensorType::get( - frontierTy.getShape(), frontierTy.getElementType(), dotOperandEnc)); + // Convert result + auto resTy = dyn_cast(op->getResult(0).getType()); - rewriter.replaceOp(frontierOp, newFrontier); + op->getResult(0).setType(RankedTensorType::get( + resTy.getShape(), resTy.getElementType(), dotOperandEnc)); } - // replace LHS operand with its parent (in dotOpEnc) - 
rewriter.modifyOpInPlace(dotOp, [&]() { dotOp.setOperand(0, alloc.getSrc()); }); + // replace LHS operand with alloc's parent (cloned) + auto newDotOperand = sliceMap.lookup(alloc.getSrc()); + rewriter.modifyOpInPlace(dotOp, [&]() { + dotOp.setOperand(0, newDotOperand); + }); return success(); } diff --git a/test/TritonGPU/dot-operands.mlir b/test/TritonGPU/dot-operands.mlir index 070e1a556d7f..ab70a081a6bd 100644 --- a/test/TritonGPU/dot-operands.mlir +++ b/test/TritonGPU/dot-operands.mlir @@ -214,17 +214,15 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : // ----- - #blocked = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> #mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> #shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> #shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { // CHECK: tt.func @mma_v3_reg_push_elementwise -// CHECK: %[[A_LOADED:.*]] = tt.load %{{.*}} : tensor<128x64x!tt.ptr, #blocked> -// CHECK: %[[A_MEMDESC:.*]] = triton_gpu.local_alloc %[[A_LOADED]] : (tensor<128x64xbf16, #blocked>) -> !tt.memdesc<128x64xbf16, #shared1> -// CHECK: %[[A_REG:.*]] = triton_gpu.local_load %[[A_MEMDESC]] : !tt.memdesc<128x64xbf16, #shared1> -> tensor<128x64xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> -// CHECK: %[[A_CASTED:.*]] = tt.fp_to_fp %[[A_REG]] : tensor<128x64xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> +// CHECK: %[[A_BLOCK:.*]] = tt.load %{{.*}} : tensor<128x64x!tt.ptr, #blocked> +// CHECK: %[[A_DOTOP:.*]] = triton_gpu.convert_layout %[[A_BLOCK]] : tensor<128x64xbf16, #blocked> -> tensor<128x64xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> +// CHECK: %[[A_CASTED:.*]] = tt.fp_to_fp %[[A_DOTOP]] : tensor<128x64xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> // CHECK: %[[R:.*]] = triton_nvidia_gpu.warp_group_dot %[[A_CASTED]], %{{.*}}, %{{.*}} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !tt.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma> tt.func @mma_v3_reg_push_elementwise(%pa: tensor<128x64x!tt.ptr, #blocked>, %dotb: !tt.memdesc<64x64xf16, #shared>, %dotc: tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma>{ %a_bf16 = tt.load %pa : tensor<128x64x!tt.ptr, #blocked> @@ -235,3 +233,29 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : } } +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, 
"triton_gpu.threads-per-warp" = 32 : i32} { +// CHECK: tt.func @mma_v3_reg_push_elementwise_chained +// CHECK: %[[CST_DOTOP:.*]] = arith.constant dense<0.000000e+00> : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> +// CHECK: %[[A_BLOCK:.*]] = tt.load %{{.*}} : tensor<128x64x!tt.ptr, #blocked> +// CHECK: %[[A_DOTOP:.*]] = triton_gpu.convert_layout %[[A_BLOCK]] : tensor<128x64xi8, #blocked> -> tensor<128x64xi8, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> +// CHECK: %[[A_CASTED:.*]] = arith.sitofp %[[A_DOTOP]] : tensor<128x64xi8, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> to tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> +// CHECK: %[[A_SCALED:.*]] = arith.mulf %[[A_CASTED]], %[[CST_DOTOP]] : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> +// CHECK: %[[A_NEGATED:.*]] = arith.negf %[[A_SCALED]] : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> +// CHECK: %[[R:.*]] = triton_nvidia_gpu.warp_group_dot %[[A_NEGATED]], %{{.*}}, %{{.*}} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !tt.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma> + tt.func @mma_v3_reg_push_elementwise_chained(%pa: tensor<128x64x!tt.ptr, #blocked>, %dotb: !tt.memdesc<64x64xf16, #shared>, %dotc: tensor<128x64xf32, #mma>) -> tensor<128x64xf32, #mma>{ + %cst = arith.constant dense<0.000000e+00> : tensor<128x64xf16, #blocked> + %a_i8 = tt.load %pa : tensor<128x64x!tt.ptr, #blocked> + %a_f16 = arith.sitofp %a_i8 : tensor<128x64xi8, #blocked> to tensor<128x64xf16, #blocked> + %a_scaled = arith.mulf %a_f16, %cst : tensor<128x64xf16, #blocked> + %a_negated = arith.negf %a_scaled : tensor<128x64xf16, #blocked> + %dota = triton_gpu.local_alloc %a_negated: (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared1> + %r = triton_nvidia_gpu.warp_group_dot %dota, %dotb, %dotc : !tt.memdesc<128x64xf16, #shared1> * !tt.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma> + tt.return %r : tensor<128x64xf32, #mma> + } +} From ab36a0fba6b6e8e2ba5a3d04b0718adc8a491ba1 Mon Sep 17 00:00:00 2001 From: Gary Geng Date: Fri, 20 Sep 2024 19:25:33 +0000 Subject: [PATCH 07/18] Improve comments --- .../Dialect/TritonGPU/IR/TritonGPUAttrDefs.td | 4 ++++ .../TritonGPU/Transforms/OptimizeDotOperands.cpp | 13 +++++++------ .../SharedToDotOperandMMAv2.cpp | 2 +- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td index c03ed737aa2d..92fd3995227f 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -1310,6 +1310,10 @@ The parent field is the layout of d. kWidth defines number of consecutive elements stored by one thread along k dimension. Some layouts do not use this parameter, either because they have a fixed number of elements along the K dim, or they use all elements of the tensor along the K dim. + +We require kWidth to be provided for Hopper because the dtype at loading might be +different from the dtype at WGMMA, due to casting. The kWidth is determined by the +dtype at WGMMA. 
}]; let parameters = ( diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index 88bde1a8d5a9..4f88410a5b92 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -55,7 +55,6 @@ bool canHoistDotOpEncV2(Operation* op, DotOperandEncodingAttr& dotOpEnc) { bool canHoistDotOpEncV3(Operation* op) { // Only consider custom conversions or arith ops. - // TODO(jlebar): Is this too restrictive? if (!isa(op) && !isPureUnaryInlineAsm(op) && op->getDialect()->getTypeID() != TypeID::get()) return false; @@ -388,7 +387,7 @@ struct MMAV3UseRegOperand // elementwise(convert(x, #dot_operand)). // // Whereas (MMAV2) HoistLayoutConversion hoists thru one op at a time and requires -// multiple passes will directly hoist the convert to the right place in one pass. +// multiple passes, this will directly hoist the convert to the right place in one pass. struct MMAV3HoistLayoutConversion : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -407,7 +406,7 @@ struct MMAV3HoistLayoutConversion if (!isa(getEncoding(dotOp.getOperand(0)))) return failure(); - // Performs checks for early stop + // Step 1: Performs checks for early stop Type inputEltTy; auto srcEnc = dyn_cast(getEncoding(alloc.getSrc())); auto dstEnc = @@ -431,7 +430,7 @@ struct MMAV3HoistLayoutConversion if (!canHoistDotOpEncV3(src)) return failure(); - // Obtain backward slice + // Step 2: Obtain slice of ops between load/constant and local_alloc SetVector slice; BackwardSliceOptions opt; opt.omitBlockArguments = true; @@ -443,7 +442,7 @@ struct MMAV3HoistLayoutConversion }; getBackwardSlice(alloc.getOperation(), &slice, opt); - // Verify slice can be hoisted through + // Step 3: Verify slice can be hoisted through if (slice.empty()) return failure(); @@ -476,8 +475,10 @@ struct MMAV3HoistLayoutConversion } } + // Step 4: Clone slice auto [newSlice, sliceMap] = cloneSlice(rewriter, slice); + // Step 5: Modify slice (add convert_layouts and change encoding of results) // For each frontierOp (i.e. 
op whose defining op is not in slice): // load/constant; frontierOp; [hoistableElementwiseOps...]; local_alloc; warp_group_dot // -> load/constant; convert_layout; frontierOp; [hoistableOps...]; warp_group_dot @@ -513,7 +514,7 @@ struct MMAV3HoistLayoutConversion resTy.getShape(), resTy.getElementType(), dotOperandEnc)); } - // replace LHS operand with alloc's parent (cloned) + // Step 6: replace LHS operand with alloc's parent in cloned slice auto newDotOperand = sliceMap.lookup(alloc.getSrc()); rewriter.modifyOpInPlace(dotOp, [&]() { dotOp.setOperand(0, newDotOperand); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp index e3ee42ab2de5..de8d254c54ff 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp @@ -554,7 +554,7 @@ Value composeValuesToDotOperandLayoutStruct( for (int m = 0; m < n0; ++m) for (int k = 0; k < n1; ++k) if (isHopper) { - // Hopper expects opposite ordering + // WGMMA.cpp expects different (m-major) ordering elems.push_back(vals.at({b, 2 * m, 2 * k})); elems.push_back(vals.at({b, 2 * m + 1, 2 * k})); elems.push_back(vals.at({b, 2 * m, 2 * k + 1})); From 3b4ffc21fd2f549c58380b8bc74c986fa9756dea Mon Sep 17 00:00:00 2001 From: Gary Geng Date: Mon, 23 Sep 2024 23:04:14 +0000 Subject: [PATCH 08/18] Improve documentation and refactor --- .../Dialect/TritonGPU/IR/TritonGPUAttrDefs.td | 8 +- lib/Dialect/TritonGPU/IR/Dialect.cpp | 9 +- .../Transforms/OptimizeDotOperands.cpp | 153 ++++++++++-------- python/test/unit/language/test_core.py | 2 + .../SharedToDotOperandMMAv2.cpp | 14 +- .../DotOpToLLVM/MMAv2.cpp | 4 +- 6 files changed, 111 insertions(+), 79 deletions(-) diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td index 92fd3995227f..3a5d03caaad8 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -1215,7 +1215,7 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is: SmallVector getMMAv1Rep(int opIdx) const; SmallVector getMMAv1ShapePerWarp(int opIdx) const; int getMMAv1Vec(int opIdx) const; - SmallVector getMMAv2Rep(ArrayRef shape, + SmallVector getMMAv2OrV3Rep(ArrayRef shape, int bitwidth, int opIdx) const; bool supportReduction() const { @@ -1324,7 +1324,7 @@ dtype at WGMMA. ); let builders = [ - // Specially for MMAV1(Volta) + // For MMAV2 and V3 AttrBuilder<(ins "unsigned":$opIdx, "Attribute":$parent, "Type":$eltTy), [{ @@ -1332,8 +1332,8 @@ dtype at WGMMA. 
if (!parentAttr || (!parentAttr.isAmpere() && !parentAttr.isHopper())) return $_get(context, opIdx, parent, 0); unsigned bitwidth = eltTy.getIntOrFloatBitWidth(); - unsigned MMAv2kWidth = 32 / bitwidth; - return $_get(context, opIdx, parent, MMAv2kWidth); + unsigned kWidth = 32 / bitwidth; + return $_get(context, opIdx, parent, kWidth); }]> ]; diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 54ee7be91150..d045b613f500 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -1961,7 +1961,7 @@ SmallVector NvidiaMmaEncodingAttr::getMMAv1ShapePerWarp(int opIdx) const { int NvidiaMmaEncodingAttr::getMMAv1Vec(int opIdx) const { return 2 * getMMAv1Rep(opIdx)[opIdx]; } -SmallVector NvidiaMmaEncodingAttr::getMMAv2Rep(ArrayRef shape, +SmallVector NvidiaMmaEncodingAttr::getMMAv2OrV3Rep(ArrayRef shape, int bitwidth, int opIdx) const { assert(isAmpere() || isHopper()); @@ -1996,14 +1996,15 @@ unsigned NvidiaMmaEncodingAttr::getTotalElemsPerThreadForOperands( if (isHopper()) { assert(opIdx == 0); auto instrMNK = getInstrShape(); - auto wpt = getWarpsPerCTA(); - int repM = ceil(shapePerCTA[0], instrMNK[0] * wpt[0]); + int repM = ceil(shapePerCTA[0], instrMNK[0] * warpsPerCTAM); int repK = ceil(shapePerCTA[1], instrMNK[2]); + // For each WGMMA instr, a 2x2 matrix fragment is loaded. Each thread holds + // kWidth elements for each quadrant. WGMMA is repeated repM * repK times. return 4 * kWidth * repM * repK; } // A100 if (isAmpere()) { - auto rep = getMMAv2Rep(shapePerCTA, eltTy.getIntOrFloatBitWidth(), opIdx); + auto rep = getMMAv2OrV3Rep(shapePerCTA, eltTy.getIntOrFloatBitWidth(), opIdx); if (opIdx == 0) return 4 * rep[0] * rep[1] * rep[2]; if (opIdx == 1) diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index 4f88410a5b92..c55aea243f8b 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -21,13 +21,13 @@ namespace { // Returns whether we can hoist DotOp Encoding through `op`. // Roughly, whether op is elementwise and thus threads don't need -// to exchange elements. But some ops are not current supported even though +// to exchange elements. But some ops are not currently supported even though // they meet that criterion. bool canHoistDotOpEncV2(Operation* op, DotOperandEncodingAttr& dotOpEnc) { // Only consider custom conversions or arith ops. // TODO(jlebar): Is this too restrictive? if (!isa(op) && !isPureUnaryInlineAsm(op) && - op->getDialect()->getTypeID() != TypeID::get()) + !isa(op->getDialect())) return false; // Quick handling to fix loading issues when computing the original @@ -53,10 +53,28 @@ bool canHoistDotOpEncV2(Operation* op, DotOperandEncodingAttr& dotOpEnc) { return true; } +// Analog of canHoistDotOpEncV2, but for MMAv3 (WGMMA where operand A +// is in registers). 
bool canHoistDotOpEncV3(Operation* op) { + // Must have exactly one result and at least one operand + if (op->getNumOperands() == 0 || op->getNumResults() != 1) + return false; + + auto isBlockedOrDotOpRankedTensor = [](Type ty) { + auto tensorTy = dyn_cast(ty); + if (!tensorTy) + return false; + return isa(tensorTy.getEncoding()); + }; + + // Operands and results must be of RankedTensorType and Blocked or DotOp + if (!(all_of(op->getOperandTypes(), isBlockedOrDotOpRankedTensor) && + all_of(op->getResultTypes(), isBlockedOrDotOpRankedTensor))) + return false; + // Only consider custom conversions or arith ops. if (!isa(op) && !isPureUnaryInlineAsm(op) && - op->getDialect()->getTypeID() != TypeID::get()) + !isa(op->getDialect())) return false; // Currently, these instructions are not supported during lowering of @@ -74,17 +92,22 @@ bool canHoistDotOpEncV3(Operation* op) { // Don't hoist through u1 -> fp casts as they aren't supported in // ElementwiseOpToLLVM::reorderValues(). - if (isa(op)) { - if (oprType.isInteger(1)) - return false; - } + if (isa(op) && oprType.isInteger(1)) + return false; return true; } +// Helper to perform a "deep" clone of the given slice (i.e., set of ops), +// returning a tuple (newSlice, sliceMap), where newSlice is the cloned slice, +// and sliceMap the IRMapping that maps the ops and result values of the +// original slice to those in the cloned slice. auto cloneSlice(PatternRewriter& rewriter, const SetVector& slice) { IRMapping sliceMap; SetVector newSlice; + + // First pass: clone ops; the result values are cloned as well, but the operands still + // refer to the original result values for (Operation *op : slice) { auto newOp = rewriter.clone(*op); newSlice.insert(newOp); @@ -95,6 +118,7 @@ auto cloneSlice(PatternRewriter& rewriter, const SetVector& slice) } } + // Second pass: replace operand references in cloned ops to point to cloned values for (auto [op, newOp] : sliceMap.getOperationMap()) for (auto [oprIdx, operand] : llvm::enumerate(newOp->getOperands())) { auto defOp = operand.getDefiningOp(); @@ -383,11 +407,18 @@ struct MMAV3UseRegOperand // MMAV3's analog of HoistLayoutConversion, for operand A only; will make WarpGroupDot // accept operand A in registers instead of shmem. // -// local_alloc(elementwise(x)) -> -// elementwise(convert(x, #dot_operand)). +// Before: load #blocked; (elementwise #blocked)+; local_alloc; warp_group_dot +// After: load #blocked; convert_layout #dot_op; (elementwise #dot_op)+; warp_group_dot // -// Whereas (MMAV2) HoistLayoutConversion hoists thru one op at a time and requires -// multiple passes, this will directly hoist the convert to the right place in one pass. +// Whereas (MMAV2) HoistLayoutConversion hoists thru one elementwise op at a time and +// requires multiple passes, this pattern will directly hoist the convert to the right +// place in one pass. +// +// Or, to be more precise, this pattern deletes the local_alloc op and inserts a +// convert_layout op after each load that warp_group_dot uses; so this is not simply hoisting +// a convert_layout op up as in V2, but can be considered as first changing local_alloc to +// convert_layout and then hoisting, which results in WGMMA now accepting operand A in DotOp +// layout rather than Shared. 
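// A concrete sketch of the rewrite (the op sequence mirrors the
// mma_v3_reg_push_elementwise tests in this series; SSA names are
// illustrative only):
//
//   Before:  %a = tt.load %ptr                  : ... #blocked
//            %c = arith.sitofp %a               : ... #blocked
//            %s = triton_gpu.local_alloc %c     : ... -> !tt.memdesc<..., #shared>
//            triton_nvidia_gpu.warp_group_dot %s, %b, %acc
//
//   After:   %a  = tt.load %ptr                 : ... #blocked
//            %a2 = triton_gpu.convert_layout %a : #blocked -> #dot_op
//            %c  = arith.sitofp %a2             : ... #dot_op
//            triton_nvidia_gpu.warp_group_dot %c, %b, %acc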
struct MMAV3HoistLayoutConversion : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -397,45 +428,34 @@ struct MMAV3HoistLayoutConversion // Can only hoist operand 0 auto alloc = dotOp.getOperand(0).getDefiningOp(); if (!alloc || !alloc.getSrc()) - return failure(); + return rewriter.notifyMatchFailure(dotOp, + "operand A must be produced by local_alloc"); auto getEncoding = [](Value v) { return cast(v.getType()).getEncoding(); }; if (!isa(getEncoding(dotOp.getOperand(0)))) - return failure(); + return rewriter.notifyMatchFailure(dotOp, + "requires Shared encoding for operand A"); // Step 1: Performs checks for early stop - Type inputEltTy; auto srcEnc = dyn_cast(getEncoding(alloc.getSrc())); - auto dstEnc = - dyn_cast(getEncoding(dotOp.getResult())); - // Want: A's Encoding to be Blocked and D's encoding to be NvidiaMmaV3 - if (!srcEnc || !dstEnc || dstEnc.getVersionMajor() != 3) - return failure(); - auto src = alloc.getSrc().getDefiningOp(); - // Value passed to alloc must have Tensor arguments and single Tensor result - if (!src || src->getNumOperands() == 0 || src->getNumResults() != 1) - return failure(); - if (!all_of(src->getOperandTypes(), - [](Type ty) { return isa(ty); })) - return failure(); - auto srcTy = dyn_cast(src->getResult(0).getType()); - if (!srcTy) - return failure(); - inputEltTy = srcTy.getElementType(); + if (!srcEnc) + return rewriter.notifyMatchFailure(alloc, + "requires src to have Blocked encoding"); - // Check src itself can be hoisted through - if (!canHoistDotOpEncV3(src)) - return failure(); + auto dstEnc = dyn_cast(getEncoding(dotOp.getResult())); + if (!dstEnc || dstEnc.getVersionMajor() != 3) + return rewriter.notifyMatchFailure(dotOp, + "requires result in NvidiaMma encoding"); // Step 2: Obtain slice of ops between load/constant and local_alloc SetVector slice; BackwardSliceOptions opt; opt.omitBlockArguments = true; opt.filter = [&](Operation *op) { - // Stop before Load, ConstantOp, or LocalLoad (which is unlikely) + // Stop before Load, ConstantOp, or LocalLoad return (op->getParentRegion() == alloc->getParentRegion()) && !isa(op) && (op->getNumOperands() != 0); @@ -444,33 +464,24 @@ struct MMAV3HoistLayoutConversion // Step 3: Verify slice can be hoisted through if (slice.empty()) - return failure(); - - auto isBlockedRankedTensor = [&](auto val) { - return isa(getEncoding(val)) && - isa(val.getType()); - }; + return rewriter.notifyMatchFailure(dotOp, "nothing to hoist through"); + // We define frontierOp as an op outside this slice whose result is used by an op in + // this slice. We must eventually convert the result of all frontierOps to + // DotOperandEncoding. This is done via the insertion of ConvertLayout after each + // frontierOp. + // We currently support frontierOp to be load or constant. 
for (Operation *currOp : slice) { if (!canHoistDotOpEncV3(currOp)) - return failure(); + return rewriter.notifyMatchFailure(currOp, "cannot hoist through"); // We previously ensured that all ops in slice have at least one operand for (auto operand : currOp->getOperands()) { auto defOp = operand.getDefiningOp(); if (!slice.contains(defOp)) { - if (!isa(defOp)) { - return failure(); - } - - auto res = currOp->getResult(0); - if (!isBlockedRankedTensor(res)) { - return failure(); - } - - if (!llvm::all_of(currOp->getOperands(), isBlockedRankedTensor)) { - return failure(); - } + // ensure frontierOp is load or constant + if (!isa(defOp)) + return rewriter.notifyMatchFailure(defOp, "must be load or constant"); } } } @@ -478,43 +489,51 @@ struct MMAV3HoistLayoutConversion // Step 4: Clone slice auto [newSlice, sliceMap] = cloneSlice(rewriter, slice); - // Step 5: Modify slice (add convert_layouts and change encoding of results) - // For each frontierOp (i.e. op whose defining op is not in slice): - // load/constant; frontierOp; [hoistableElementwiseOps...]; local_alloc; warp_group_dot - // -> load/constant; convert_layout; frontierOp; [hoistableOps...]; warp_group_dot + // Step 5: Modify the cloned slice to have dotOp encoding. + // Before: load #blocked; (elementwise #blocked)+; local_alloc; warp_group_dot + // After: load #blocked; convert_layout #dot_op; (elementwise #dot_op)+; warp_group_dot + // + // Specifically, this step will change all value types from #blocked to #dot_op + // encoding in the cloned slice, and for those values produced by frontierOps (i.e., + // outside the slice), we will insert convert_layout's after the frontierOp. + auto srcTy = cast(alloc.getSrc().getType()); + Type inputEltTy = srcTy.getElementType(); auto dotOperandEnc = DotOperandEncodingAttr::get( dotOp.getContext(), /*opIdx=*/0, dstEnc, inputEltTy); + for (auto op : newSlice) { - // Convert operands + // Step 5a: If any operand is defined by a frontierOp, we must insert a + // convert_layout(#dot_op) after the frontierOp and before currOp for (auto [oprIdx, operand] : llvm::enumerate(op->getOperands())) { + auto defOp = operand.getDefiningOp(); - // Not frontier; no need to convert operand - if (newSlice.contains(defOp)) { - op->moveAfter(defOp); + // defOp is not frontier (i.e. it's within slice); no need to convert the + // layout of its result + if (newSlice.contains(defOp)) continue; - } - // We checked earlier that all operands are ranked tensors. + // We checked earlier that all operands are ranked tensors auto operandTy = cast(operand.getType()); auto operandEltTy = operandTy.getElementType(); Type cvtTy = RankedTensorType::get( operandTy.getShape(), operandTy.getElementType(), dotOperandEnc); + rewriter.setInsertionPoint(op); auto cvt = rewriter.create(defOp->getLoc(), cvtTy, operand); op->setOperand(oprIdx, cvt); - op->moveAfter(cvt); } - // Convert result + // Step 5b: Change the result to have DotOp rather than Blocked encoding auto resTy = dyn_cast(op->getResult(0).getType()); - op->getResult(0).setType(RankedTensorType::get( resTy.getShape(), resTy.getElementType(), dotOperandEnc)); } - // Step 6: replace LHS operand with alloc's parent in cloned slice + // Step 6: replace LHS operand with alloc's parent in the cloned slice + // This changes the warpGroupDot to accept a DotOp tensor as operand A instead of + // a Shared memdesc. 
auto newDotOperand = sliceMap.lookup(alloc.getSrc()); rewriter.modifyOpInPlace(dotOp, [&]() { dotOp.setOperand(0, newDotOperand); diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 3d1cbc5a82f0..6718a96df4ad 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -11,6 +11,8 @@ import pytest import torch import os +os.environ['TRITON_ALWAYS_COMPILE'] = '1' +os.environ['MLIR_ENABLE_DUMP'] = '1' import inspect from numpy.random import RandomState diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp index de8d254c54ff..a15f486c7dd9 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp @@ -550,11 +550,18 @@ Value composeValuesToDotOperandLayoutStruct( const LLVMTypeConverter *typeConverter, Location loc, ConversionPatternRewriter &rewriter, Type elTy, bool isHopper) { std::vector elems; + // Existing convention for the ordering of quad values in llvm.struct + // is m-major for Hopper and k-major for Ampere, even though both Ampere + // and Hopper MMA's expect m-major ordering in PTX. + // + // To unify the ordering conventions would potentially require touching + // `ConvertLayoutOpToLLVM.cpp`, `ElementwiseOpToLLVM.cpp`, `MMAv2.cpp`, + // `WGMMA.cpp`, and possibly others. For now, we are using an if-check + // here to route to the correct ordering. for (int b = 0; b < batch; ++b) for (int m = 0; m < n0; ++m) for (int k = 0; k < n1; ++k) if (isHopper) { - // WGMMA.cpp expects different (m-major) ordering elems.push_back(vals.at({b, 2 * m, 2 * k})); elems.push_back(vals.at({b, 2 * m + 1, 2 * k})); elems.push_back(vals.at({b, 2 * m, 2 * k + 1})); @@ -655,6 +662,9 @@ Value loadArg(ConversionPatternRewriter &rewriter, Location loc, bool isHopper = mmaLayout.getVersionMajor() == 3; auto shapePerCTA = getShapePerCTA(descTy); int bitwidth = descTy.getElementTypeBitWidth(); + // For Hopper WGMMA, the sum of bitwidth of the elements in each quad should add + // up to 32. We use kWidth to compute the element bitwidth of the input to WGMMA, + // which could be different from `bitwidth` due to later casting. int mmaBitwidth = isHopper ? 
(32 / encoding.getKWidth()) : bitwidth; ValueTable vals; @@ -662,7 +672,7 @@ Value loadArg(ConversionPatternRewriter &rewriter, Location loc, int matShapeM = 8, matShapeN = 8, matShapeK = 2 * 64 / mmaBitwidth; auto numRep = - mmaLayout.getMMAv2Rep(shapePerCTA, mmaBitwidth, encoding.getOpIdx()); + mmaLayout.getMMAv2OrV3Rep(shapePerCTA, mmaBitwidth, encoding.getOpIdx()); int kWidth = encoding.getKWidth(); auto warpsPerCTA = mmaLayout.getWarpsPerCTA(); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp index af897ef546dd..928b46cbbd90 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp @@ -318,10 +318,10 @@ LogicalResult convertDot(const LLVMTypeConverter *typeConverter, int bitwidth = aTensorTy.getElementType().getIntOrFloatBitWidth(); auto dotOpA = cast(aTensorTy.getEncoding()); auto repA = cast(dotOpA.getParent()) - .getMMAv2Rep(aShapePerCTA, bitwidth, dotOpA.getOpIdx()); + .getMMAv2OrV3Rep(aShapePerCTA, bitwidth, dotOpA.getOpIdx()); auto dotOpB = cast(bTensorTy.getEncoding()); auto repB = cast(dotOpB.getParent()) - .getMMAv2Rep(bShapePerCTA, bitwidth, dotOpB.getOpIdx()); + .getMMAv2OrV3Rep(bShapePerCTA, bitwidth, dotOpB.getOpIdx()); assert(repA[2] == repB[1]); assert(repA[0] == repB[0]); From cdf2ae0ea7545559d50750b95f55b443a25ab532 Mon Sep 17 00:00:00 2001 From: Gary Geng Date: Mon, 23 Sep 2024 23:10:11 +0000 Subject: [PATCH 09/18] Rename SharedToDotOperandMMAv2 -> ...v2OrV3 --- .../nvidia/lib/TritonNVIDIAGPUToLLVM/CMakeLists.txt | 2 +- .../lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp | 8 ++++---- ...otOperandMMAv2.cpp => SharedToDotOperandMMAv2OrV3.cpp} | 4 ++-- .../lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp | 2 +- 4 files changed, 8 insertions(+), 8 deletions(-) rename third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/{SharedToDotOperandMMAv2.cpp => SharedToDotOperandMMAv2OrV3.cpp} (99%) diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/CMakeLists.txt b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/CMakeLists.txt index 197901d8555c..37c3bdc7d45c 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/CMakeLists.txt +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/CMakeLists.txt @@ -1,6 +1,6 @@ add_triton_library(TritonNVIDIAGPUToLLVM ConvertLayoutOpToLLVM/SharedToDotOperandMMAv1.cpp - ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp + ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp ConvertLayoutOpToLLVM.cpp DotOpToLLVM/MMAv1.cpp DotOpToLLVM/MMAv2.cpp diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 1f046e10a1f2..5b5c89f84413 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -34,13 +34,13 @@ Value convertLayout(int opIdx, Value tensor, const SharedMemoryObject &smemObj, } // namespace SharedToDotOperandMMAv1 -namespace SharedToDotOperandMMAv2 { +namespace SharedToDotOperandMMAv2OrV3 { Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, Location loc, Value tensor, DotOperandEncodingAttr bEncoding, const SharedMemoryObject &smemObj, const LLVMTypeConverter *typeConverter, Value thread); -} // namespace SharedToDotOperandMMAv2 +} // namespace SharedToDotOperandMMAv2OrV3 namespace { @@ -96,11 +96,11 @@ struct 
LocalLoadOpConversion if (mmaLayout.isHopper()) { // tensor core v3 assert(dotOperandLayout.getOpIdx() == 0); - res = SharedToDotOperandMMAv2::convertLayout( + res = SharedToDotOperandMMAv2OrV3::convertLayout( 0, rewriter, loc, src, dotOperandLayout, smemObj, typeConverter, getThreadId(rewriter, loc)); } else if (mmaLayout.isAmpere()) { // tensor core v2 - res = SharedToDotOperandMMAv2::convertLayout( + res = SharedToDotOperandMMAv2OrV3::convertLayout( dotOperandLayout.getOpIdx(), rewriter, loc, src, dotOperandLayout, smemObj, typeConverter, getThreadId(rewriter, loc)); } else if (mmaLayout.isVolta() && isMMA) { // tensor core v1 diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp similarity index 99% rename from third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp rename to third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp index a15f486c7dd9..2d4b30ba2c7d 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp @@ -832,7 +832,7 @@ getExpandedSharedMemoryObject(ConversionPatternRewriter &rewriter, Location loc, return expandedSmemObj; } -namespace SharedToDotOperandMMAv2 { +namespace SharedToDotOperandMMAv3 { Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, Location loc, Value tensor, DotOperandEncodingAttr encoding, const SharedMemoryObject &smemObj, @@ -853,4 +853,4 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, expandedSmemObj, typeConverter, thread, false); } } -} // namespace SharedToDotOperandMMAv2 +} // namespace SharedToDotOperandMMAv3 diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp index 995a510131b7..cfc487c59ecb 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp @@ -288,7 +288,7 @@ DotOpMmaV3SmemLoader loadB(const LLVMTypeConverter *typeConverter, // // Additionally, note that WGMMA expects quadK ordered before quadM, i.e. the layout // is quadM-major. This is opposite to Ampere's ordering for ldmatrix and dotOp. 
-// (see SharedToDotOperandMMAv2.cpp) +// (see SharedToDotOperandMMAv2OrV3.cpp) llvm::SmallVector loadReg(ConversionPatternRewriter &rewriter, Location loc, const SmallVector &elements, From 7a8ac2ec74cfb851d559965d2b417fcaf02fced4 Mon Sep 17 00:00:00 2001 From: Gary Geng Date: Tue, 24 Sep 2024 11:15:57 -0700 Subject: [PATCH 10/18] Remove debug flags in test_core.py --- python/test/unit/language/test_core.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 6718a96df4ad..3d1cbc5a82f0 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -11,8 +11,6 @@ import pytest import torch import os -os.environ['TRITON_ALWAYS_COMPILE'] = '1' -os.environ['MLIR_ENABLE_DUMP'] = '1' import inspect from numpy.random import RandomState From d898568fc66a1488f5bc2b5408c33540473fc4aa Mon Sep 17 00:00:00 2001 From: Gary Geng Date: Tue, 24 Sep 2024 22:45:46 +0000 Subject: [PATCH 11/18] Fix bad rename --- .../ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp index 2d4b30ba2c7d..9897e1b17e6e 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp @@ -832,7 +832,7 @@ getExpandedSharedMemoryObject(ConversionPatternRewriter &rewriter, Location loc, return expandedSmemObj; } -namespace SharedToDotOperandMMAv3 { +namespace SharedToDotOperandMMAv2OrV3 { Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, Location loc, Value tensor, DotOperandEncodingAttr encoding, const SharedMemoryObject &smemObj, @@ -853,4 +853,4 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, expandedSmemObj, typeConverter, thread, false); } } -} // namespace SharedToDotOperandMMAv3 +} // namespace SharedToDotOperandMMAv2OrV3 From 3bf5ddc394dbc4bf73568fa6d2dd404e1ee088df Mon Sep 17 00:00:00 2001 From: Gary Geng Date: Thu, 19 Sep 2024 16:46:08 +0000 Subject: [PATCH 12/18] Initial changes for pipelining --- .../Pipeliner/MatmulLoopPipeline.cpp | 94 +++++++++++-------- 1 file changed, 56 insertions(+), 38 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp index e920de798289..3e6ee52231b1 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp @@ -44,7 +44,8 @@ struct LoadInfo { ttg::SharedEncodingAttr sharedEncoding = nullptr; // Blocked encoding is used for loads not used by the dot. 
ttg::BlockedEncodingAttr blockedEncoding = nullptr; - bool loadIsMMAV3 = false; + bool loadIsMMAv3Shared = false; + bool loadIsMMAv3Registers = false; int distToUse = 0; bool usedByDot = false; }; @@ -102,7 +103,7 @@ static void createAsyncCopy(scf::ForOp &forOp, tt::LoadOp loadOp, Value alloc, Operation *wait = builder.create(loc, commmit->getResult(0), 0); - bool isMMV3Load = loadToInfo[loadOp].loadIsMMAV3; + auto loadIsMMAv3Shared = loadToInfo[loadOp].loadIsMMAv3Shared; auto [stage, cluster] = schedule[loadOp]; schedule.erase(loadOp); schedule.insert(copy, stage, cluster); @@ -113,7 +114,7 @@ static void createAsyncCopy(scf::ForOp &forOp, tt::LoadOp loadOp, Value alloc, loadOffsets[0] = extractIdx; auto viewLoad = builder.create(loc, subviewTy, alloc, loadOffsets); - if (isMMV3Load) { + if (loadIsMMAv3Shared) { auto alloc = cast((*loadOp->getUsers().begin())); replaceUsesAndPropagateType(builder, alloc, viewLoad.getResult()); alloc.erase(); @@ -178,7 +179,7 @@ static void createTMAAsyncCopy( Operation *copy = builder.create( loc, loadOp.getDescPtr(), loadOp.getIndices(), barrier, view, pred); - bool isMMV3Load = loadToInfo[loadOp].loadIsMMAV3; + auto loadIsMMAv3Shared = loadToInfo[loadOp].loadIsMMAv3Shared; auto [stage, cluster] = schedule[loadOp]; schedule.erase(loadOp); schedule.insert(copy, stage, cluster); @@ -189,7 +190,7 @@ static void createTMAAsyncCopy( loadOffsets[0] = extractIdx; auto viewLoad = builder.create(loc, subviewTy, alloc, loadOffsets); - if (isMMV3Load) { + if (loadIsMMAv3Shared) { auto alloc = cast((*loadOp->getUsers().begin())); replaceUsesAndPropagateType(builder, alloc, viewLoad.getResult()); alloc.erase(); @@ -275,7 +276,7 @@ getBlockedEncoding(tt::LoadOp loadOp, tt::ModuleAxisInfoAnalysis &axisInfo) { } static std::optional -getSharedEncoding(Operation *loadOp, bool isMMAV3) { +getSharedEncoding(Operation *loadOp, bool isMMAV3Shared) { auto ty = cast(loadOp->getResultTypes()[0]); auto ctaLayout = ttg::getCTALayout(ty.getEncoding()); auto blockedOrder = ttg::getOrder(ty.getEncoding()); @@ -290,7 +291,7 @@ getSharedEncoding(Operation *loadOp, bool isMMAV3) { } else { order = blockedOrder; } - if (isMMAV3) { + if (isMMAV3Shared) { return ttg::SharedEncodingAttr::get(ty.getContext(), ty.getShape(), order, ctaLayout, ty.getElementType()); } @@ -370,34 +371,43 @@ loadOpsToIndirectionLevelAndUse(scf::ForOp forOp) { return loadOpToIndLevelAndUse; } -static bool loadIsMMAv3(Operation *loadOp) { - if (!loadOp->hasOneUse()) - return false; - auto alloc = dyn_cast(*loadOp->getUsers().begin()); - if (!alloc) - return false; - auto sharedEnc = cast(alloc.getType().getEncoding()); - if (!sharedEnc.getHasLeadingOffset()) - return false; +enum class MMALoadType { + SharedV3, + Registers, // may be v2 or v3 + DoNotPipeline, // could be a valid shared/registers MMA operand, but skip pipelining +}; - // In case LHS is in registers, don't pipeline for now TODO(ggengnv) is this necessary? - auto op = *alloc->getUsers().begin(); - if (auto localLoad = dyn_cast(op)) { - auto resTy = cast(localLoad->getResultTypes()[0]); - if (!resTy || isa(resTy.getEncoding())) - return false; - } +static MMALoadType getMMALoadType(Operation *loadOp) { + if (!loadOp->hasOneUse()) + return MMALoadType::DoNotPipeline; + + if (auto alloc = dyn_cast(*loadOp->getUsers().begin())) { + auto sharedEnc = cast(alloc.getType().getEncoding()); + + // MMA V3 case. 
+ auto newOrder = sharedEnc.getOrder(); + auto ty = cast(loadOp->getResultTypes()[0]); + auto oldOrder = ttg::getOrder(ty.getEncoding()); + + // The operand of MMAv3 is in SharedEncoding and its order should not + // be changed after FuseTranspositions Pass. So we only pipeline the + // load if the order of the loaded BlockedEncoding is the same as the + // order of the SharedEncoding it is converted to. + return oldOrder == newOrder ? MMALoadType::SharedV3 : MMALoadType::DoNotPipeline; + } else if (auto cvt = dyn_cast(*loadOp->getUsers().begin())) { + auto resTy = dyn_cast(cvt->getResultTypes()[0]); + if (!resTy) { + return MMALoadType::DoNotPipeline; + } - // MMA V3 case. - auto newOrder = sharedEnc.getOrder(); - auto ty = cast(loadOp->getResultTypes()[0]); - auto oldOrder = ttg::getOrder(ty.getEncoding()); + if (isa(resTy.getEncoding())) { + return MMALoadType::Registers; + } - // The operand of MMAv3 is in SharedEncoding and its order should not - // be changed after FuseTranspositions Pass. So we only pipeline the - // load if the order of the loaded BlockedEncoding is the same as the - // order of the SharedEncoding it is converted to. - return oldOrder == newOrder; + return MMALoadType::DoNotPipeline; + } else { + return MMALoadType::DoNotPipeline; + } } static llvm::MapVector @@ -438,15 +448,22 @@ assignMemoryLayouts(llvm::SmallVector> } if (use->hasTrait()) { + auto mmaLoadType = getMMALoadType(op); + auto dot = dyn_cast(use); + auto warpGroupDot = dyn_cast(use); + loadInfo.usedByDot = true; - if (loadIsMMAv3(op)) { - loadInfo.loadIsMMAV3 = true; + loadInfo.loadIsMMAv3Shared = mmaLoadType == MMALoadType::SharedV3; + loadInfo.loadIsMMAv3Registers = (mmaLoadType == MMALoadType::Registers) + && warpGroupDot; + + if (loadInfo.loadIsMMAv3Shared) { loadInfo.sharedEncoding = getSharedEncoding(op, /*loadIsMMAv3=*/true).value_or(nullptr); } else if (isa(op)) { loadInfo.sharedEncoding = getSharedEncoding(op, /*loadIsMMAv3=*/true).value_or(nullptr); - } else if (auto dot = dyn_cast(use)) { + } else if (loadInfo.loadIsMMAv3Registers || dot) { loadInfo.sharedEncoding = getSharedEncIfAllUsersAreDotEnc(op->getResult(0)).value_or(nullptr); } @@ -467,7 +484,7 @@ assignMemoryLayouts(llvm::SmallVector> // encoding. if (!loadInfo.sharedEncoding && !isa(use)) { loadInfo.sharedEncoding = - getSharedEncoding(op, /*isMMAV3=*/loadInfo.loadIsMMAV3) + getSharedEncoding(op, /*isMMAV3=*/loadInfo.loadIsMMAv3Shared) .value_or(nullptr); if (auto loadOp = dyn_cast(op)) { loadInfo.blockedEncoding = getBlockedEncoding(loadOp, axisInfoAnalysis); @@ -904,7 +921,7 @@ static void createTMABarrierAndWait( if (it != loadToInfo.end()) { // Special case for MMAv3 loads, we can ignore the alloc and only // consider uses of the alloc op since it will be removed. - if (it->second.loadIsMMAV3) { + if (it->second.loadIsMMAv3Shared) { auto alloc = cast( (*loadInfo->loadOp->getUsers().begin())); if (alloc->getBlock() == loadBlock) { @@ -995,7 +1012,8 @@ createAsyncOps(scf::ForOp &forOp, tt::CoarseSchedule &schedule, return lhs.distToUse < rhs.distToUse; })->distToUse; bool hasMMAV3 = - llvm::any_of(loadToInfo, [](auto &kv) { return kv.second.loadIsMMAV3; }); + llvm::any_of(loadToInfo, [](auto &kv) { + return kv.second.loadIsMMAv3Shared || kv.second.loadIsMMAv3Registers; }); if (hasMMAV3) { // For MMAv3, we need an extra buffer as this is assumed in the wgmma // pipelining post-processing. 
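To make the classification rule introduced by this patch concrete, the following is a minimal, self-contained C++ sketch of the same decision. The LoadUser struct and its fields are illustrative stand-ins only; the actual pass inspects the tt.load's single triton_gpu.local_alloc or convert_layout user and its encodings, as shown in the diff above, and the hasLeadingOffset guard mirrors the check added by the last patch in this series.

#include <cassert>
#include <vector>

enum class MMALoadType { SharedV3, Registers, DoNotPipeline };

// Hypothetical, simplified view of the single user of a pipelined load.
struct LoadUser {
  enum Kind { LocalAlloc, ConvertLayout, Other } kind;
  // For LocalAlloc: properties of the destination SharedEncoding.
  bool sharedHasLeadingOffset = false;
  bool sharedOrderMatchesLoad = false;
  // For ConvertLayout: whether the result carries a DotOperand encoding.
  bool resultIsDotOperand = false;
};

MMALoadType classify(const std::vector<LoadUser> &users) {
  if (users.size() != 1)             // the load must have exactly one use
    return MMALoadType::DoNotPipeline;
  const LoadUser &u = users.front();
  switch (u.kind) {
  case LoadUser::LocalAlloc:         // operand stays in shared memory (WGMMA)
    if (!u.sharedHasLeadingOffset)   // guard added by the final "Fix check" patch
      return MMALoadType::DoNotPipeline;
    return u.sharedOrderMatchesLoad ? MMALoadType::SharedV3
                                    : MMALoadType::DoNotPipeline;
  case LoadUser::ConvertLayout:      // operand A hoisted into registers
    return u.resultIsDotOperand ? MMALoadType::Registers
                                : MMALoadType::DoNotPipeline;
  default:
    return MMALoadType::DoNotPipeline;
  }
}

int main() {
  // load -> local_alloc with leading offset and matching order: pipeline as SharedV3.
  assert(classify({{LoadUser::LocalAlloc, true, true, false}}) == MMALoadType::SharedV3);
  // load -> convert_layout to a dot-operand layout: pipeline with A in registers.
  assert(classify({{LoadUser::ConvertLayout, false, false, true}}) == MMALoadType::Registers);
  // Any other user is left unpipelined.
  assert(classify({{LoadUser::Other}}) == MMALoadType::DoNotPipeline);
  return 0;
}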
From be4e1e388bee41a9e76f1cf0c16bd3acd68d73aa Mon Sep 17 00:00:00 2001 From: Gary Geng Date: Thu, 19 Sep 2024 20:01:31 +0000 Subject: [PATCH 13/18] Add pipeline test --- test/TritonGPU/loop-pipeline-hopper.mlir | 64 ++++++++++++++++++++++++ 1 file changed, 64 insertions(+) diff --git a/test/TritonGPU/loop-pipeline-hopper.mlir b/test/TritonGPU/loop-pipeline-hopper.mlir index 2c2182154d6a..206bdf29ed14 100644 --- a/test/TritonGPU/loop-pipeline-hopper.mlir +++ b/test/TritonGPU/loop-pipeline-hopper.mlir @@ -931,3 +931,67 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : tt.return %17#0 : tensor<128x16xf32, #mma1> } } + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16, 16]}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> +module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +// CHECK-LABEL: dot_lhs_registers + tt.func @dot_lhs_registers(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma> { + %cst = arith.constant dense<0> : tensor<64x16xi32, #blocked> + %c0_i32 = arith.constant 0 : i32 + %cst_0 = arith.constant dense<0> : tensor<1x16xi32, #blocked> + %cst_1 = arith.constant dense<0> : tensor<128x1xi32, #blocked1> + %c0_i64 = arith.constant 0 : i64 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma> + %cst_3 = arith.constant dense<0> : tensor<128x64xi32, #blocked1> + %cst_4 = arith.constant dense<2.0> : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %0 = tt.addptr %arg0, %c0_i64 : !tt.ptr, i64 + %1 = tt.addptr %arg1, %c0_i64 : !tt.ptr, i64 + %2 = tt.splat %1 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked1> + %3 = tt.addptr %2, %cst_1 : tensor<128x1x!tt.ptr, #blocked1>, tensor<128x1xi32, #blocked1> + %4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %6 = tt.broadcast %3 : tensor<128x1x!tt.ptr, #blocked1> -> tensor<128x64x!tt.ptr, #blocked1> + %7 = tt.broadcast %5 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1> + %8 = tt.addptr %6, %7 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %9 = tt.load %8 : tensor<128x64x!tt.ptr, #blocked1> + %10 = tt.splat %0 : !tt.ptr -> tensor<1x16x!tt.ptr, #blocked> + %11 = tt.addptr %10, %cst_0 : tensor<1x16x!tt.ptr, #blocked>, tensor<1x16xi32, #blocked> + %12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %14 = tt.broadcast %11 : tensor<1x16x!tt.ptr, #blocked> -> tensor<64x16x!tt.ptr, #blocked> + %15 = 
tt.broadcast %13 : tensor<64x1xi32, #blocked> -> tensor<64x16xi32, #blocked> + %16 = tt.addptr %14, %15 : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> + // CHECK: scf.for + // CHECK: triton_gpu.local_load + // CHECK: triton_gpu.async_wait {{.*}} {num = 2 : i32} + // CHECK: triton_nvidia_gpu.warp_group_dot + // CHECK-NEXT: triton_nvidia_gpu.warp_group_dot_wait {{.*}} {pendings = 1 : i32} + // CHECK: triton_gpu.async_copy_global_to_local + // CHECK: triton_gpu.async_commit_group + // CHECK: triton_gpu.async_copy_global_to_local + // CHECK: triton_gpu.async_commit_group + // CHECK: scf.yield + %17:3 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2, %arg5 = %8, %arg6 = %16) -> (tensor<128x16xf32, #mma>, tensor<128x64x!tt.ptr, #blocked1>, + tensor<64x16x!tt.ptr, #blocked>) : i32 { + %a_block = tt.load %arg5 : tensor<128x64x!tt.ptr, #blocked1> + %b_block = tt.load %arg6 : tensor<64x16x!tt.ptr, #blocked> + %a_dotop = triton_gpu.convert_layout %a_block : tensor<128x64xf16, #blocked1> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %a_dotop_mul = arith.mulf %a_dotop, %cst_4 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %b_smem = triton_gpu.local_alloc %b_block : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> + %21 = triton_nvidia_gpu.warp_group_dot %a_dotop_mul, %b_smem, %arg4 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma> + %25 = tt.addptr %arg5, %cst_3 : tensor<128x64x!tt.ptr, #blocked1>, tensor<128x64xi32, #blocked1> + %26 = tt.addptr %arg6, %cst : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> + scf.yield %21, %25, %26 : tensor<128x16xf32, #mma>, tensor<128x64x!tt.ptr, #blocked1>, tensor<64x16x!tt.ptr, #blocked> + } + tt.return %17#0 : tensor<128x16xf32, #mma> + } +} + From 34b46c6a3b5c5531fd17a6eab7259573f1fdbff9 Mon Sep 17 00:00:00 2001 From: Gary Geng Date: Mon, 23 Sep 2024 18:14:43 +0000 Subject: [PATCH 14/18] Refactor MatmulLoopPipeline --- .../Pipeliner/MatmulLoopPipeline.cpp | 25 +++++++++++-------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp index 3e6ee52231b1..87e8be14ed58 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp @@ -44,8 +44,8 @@ struct LoadInfo { ttg::SharedEncodingAttr sharedEncoding = nullptr; // Blocked encoding is used for loads not used by the dot. 
ttg::BlockedEncodingAttr blockedEncoding = nullptr; - bool loadIsMMAv3Shared = false; - bool loadIsMMAv3Registers = false; + bool isMMAv3Shared = false; + bool isMMAv3Registers = false; int distToUse = 0; bool usedByDot = false; }; @@ -103,7 +103,7 @@ static void createAsyncCopy(scf::ForOp &forOp, tt::LoadOp loadOp, Value alloc, Operation *wait = builder.create(loc, commmit->getResult(0), 0); - auto loadIsMMAv3Shared = loadToInfo[loadOp].loadIsMMAv3Shared; + auto loadIsMMAv3Shared = loadToInfo[loadOp].isMMAv3Shared; auto [stage, cluster] = schedule[loadOp]; schedule.erase(loadOp); schedule.insert(copy, stage, cluster); @@ -179,7 +179,7 @@ static void createTMAAsyncCopy( Operation *copy = builder.create( loc, loadOp.getDescPtr(), loadOp.getIndices(), barrier, view, pred); - auto loadIsMMAv3Shared = loadToInfo[loadOp].loadIsMMAv3Shared; + auto loadIsMMAv3Shared = loadToInfo[loadOp].isMMAv3Shared; auto [stage, cluster] = schedule[loadOp]; schedule.erase(loadOp); schedule.insert(copy, stage, cluster); @@ -453,17 +453,20 @@ assignMemoryLayouts(llvm::SmallVector> auto warpGroupDot = dyn_cast(use); loadInfo.usedByDot = true; - loadInfo.loadIsMMAv3Shared = mmaLoadType == MMALoadType::SharedV3; - loadInfo.loadIsMMAv3Registers = (mmaLoadType == MMALoadType::Registers) + loadInfo.isMMAv3Shared = mmaLoadType == MMALoadType::SharedV3; + loadInfo.isMMAv3Registers = (mmaLoadType == MMALoadType::Registers) && warpGroupDot; - if (loadInfo.loadIsMMAv3Shared) { + if (loadInfo.isMMAv3Shared) { loadInfo.sharedEncoding = getSharedEncoding(op, /*loadIsMMAv3=*/true).value_or(nullptr); } else if (isa(op)) { loadInfo.sharedEncoding = getSharedEncoding(op, /*loadIsMMAv3=*/true).value_or(nullptr); - } else if (loadInfo.loadIsMMAv3Registers || dot) { + } else if (loadInfo.isMMAv3Registers || dot) { + // if warpGroupDot, we must now have operand A in registers since + // loadIsMMAv3Shared is false from above if-check + loadInfo.sharedEncoding = getSharedEncIfAllUsersAreDotEnc(op->getResult(0)).value_or(nullptr); } @@ -484,7 +487,7 @@ assignMemoryLayouts(llvm::SmallVector> // encoding. if (!loadInfo.sharedEncoding && !isa(use)) { loadInfo.sharedEncoding = - getSharedEncoding(op, /*isMMAV3=*/loadInfo.loadIsMMAv3Shared) + getSharedEncoding(op, /*isMMAV3=*/loadInfo.isMMAv3Shared) .value_or(nullptr); if (auto loadOp = dyn_cast(op)) { loadInfo.blockedEncoding = getBlockedEncoding(loadOp, axisInfoAnalysis); @@ -921,7 +924,7 @@ static void createTMABarrierAndWait( if (it != loadToInfo.end()) { // Special case for MMAv3 loads, we can ignore the alloc and only // consider uses of the alloc op since it will be removed. - if (it->second.loadIsMMAv3Shared) { + if (it->second.isMMAv3Shared) { auto alloc = cast( (*loadInfo->loadOp->getUsers().begin())); if (alloc->getBlock() == loadBlock) { @@ -1013,7 +1016,7 @@ createAsyncOps(scf::ForOp &forOp, tt::CoarseSchedule &schedule, })->distToUse; bool hasMMAV3 = llvm::any_of(loadToInfo, [](auto &kv) { - return kv.second.loadIsMMAv3Shared || kv.second.loadIsMMAv3Registers; }); + return kv.second.isMMAv3Shared || kv.second.isMMAv3Registers; }); if (hasMMAV3) { // For MMAv3, we need an extra buffer as this is assumed in the wgmma // pipelining post-processing. 
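The patch that follows keys off isMMAv3Registers to keep the global-to-shared copies coalesced by clipping the blocked layout's sizePerThread to the shared encoding's vec. As a rough standalone illustration of that arithmetic: the element size, vec, and sizePerThread values below are assumed for the example, and clipSizePerThread is only a stand-in for the in-pass llvm::transform over the encoding attribute.

#include <algorithm>
#include <cstdio>
#include <iterator>
#include <vector>

// Element-wise clip, mirroring the rule used by the pass: never let a thread
// own a contiguous run longer than one shared-memory vector (one LDGSTS).
std::vector<unsigned> clipSizePerThread(const std::vector<unsigned> &sizePerThread,
                                        unsigned sharedVec) {
  std::vector<unsigned> clipped;
  std::transform(sizePerThread.begin(), sizePerThread.end(),
                 std::back_inserter(clipped),
                 [&](unsigned s) { return std::min(s, sharedVec); });
  return clipped;
}

int main() {
  const unsigned elemBytes = 1;                // assumed fp8 load (upcast later)
  const unsigned sharedVec = 8;                // assumed SharedEncoding vec
  std::vector<unsigned> sizePerThread{1, 16};  // assumed blocked layout, k-dim = 16

  // Without clipping: each thread owns 16 * 1B = 16B of contiguous global data,
  // but each LDGSTS writes only vec * 1B = 8B of shared memory, so the copy is
  // split into two 8B transactions per thread with strided addresses.
  std::printf("per-thread contiguous bytes: %u, bytes per LDGSTS: %u\n",
              sizePerThread[1] * elemBytes, sharedVec * elemBytes);

  // With clipping, sizePerThread becomes [1, 8]: one 8B LDGSTS per thread, and
  // neighbouring threads cover the adjacent 8-element chunks, so the warp's
  // global accesses stay contiguous (coalesced).
  auto clipped = clipSizePerThread(sizePerThread, sharedVec);
  std::printf("clipped sizePerThread: [%u, %u]\n", clipped[0], clipped[1]);
  return 0;
}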
From 579f8f2da1d6b6a70c37dbbc6d14d621193cf07b Mon Sep 17 00:00:00 2001 From: Gary Geng Date: Fri, 20 Sep 2024 20:42:10 +0000 Subject: [PATCH 15/18] Improve coalescing for global to local copy --- .../Pipeliner/MatmulLoopPipeline.cpp | 68 +++++++++++++++---- 1 file changed, 56 insertions(+), 12 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp index 87e8be14ed58..f674324b75c0 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp @@ -66,26 +66,70 @@ static void createAsyncCopy(scf::ForOp &forOp, tt::LoadOp loadOp, Value alloc, Value src = loadOp.getPtr(); Value mask = loadOp.getMask(); Value other = loadOp.getOther(); + tt::MemDescType allocTy = cast(alloc.getType()); + + auto convertBlockLayout = [&](Value val, ttg::BlockedEncodingAttr enc) { + auto ty = cast(val.getType()); + auto newTy = + RankedTensorType::get(ty.getShape(), ty.getElementType(), enc); + auto cvt = + builder.create(loc, newTy, val); + return cvt.getResult(); + }; + if (!isExpensiveLoadOrStore(loadOp) && loadToInfo[loadOp].blockedEncoding) { // For inexpensive loads that do not directly feed into dot ops // we want to use optimal layout for the data. ttg::BlockedEncodingAttr encoding = loadToInfo[loadOp].blockedEncoding; - auto convertBlockLayout = [&](Value src) { - auto ty = cast(src.getType()); - auto newTy = - RankedTensorType::get(ty.getShape(), ty.getElementType(), encoding); - auto cvt = - builder.create(loadOp->getLoc(), newTy, src); - return cvt.getResult(); - }; - src = convertBlockLayout(src); + src = convertBlockLayout(src, encoding); if (mask) - mask = convertBlockLayout(mask); + mask = convertBlockLayout(mask, encoding); if (other) - other = convertBlockLayout(other); + other = convertBlockLayout(other, encoding); + } else if (loadToInfo[loadOp].isMMAv3Registers) { + // If the following are true... + // 1) Operand A is for WGMMA and is to be loaded in registers + // 2) We upcast operand A in registers before the WGMMA + // (downcasting is not yet supporting) + // + // ...then the SharedEncoding vec will be less than BlockedEncoding's + // sizePerThread, for k-dim. E.g. if shared vec is 8 and sizePerThread + // for k is 16, then AsyncCopyGlobalToLocal will generate two 8B-LDGSTS + // for each contiguous 16B global data owned by each thread. This breaks + // coalescing. + // + // The fix is to clip the BlockedEnc's sizePerThread using SharedEnc's vec. 
+ auto tensorTy = cast(src.getType()); + auto blockEnc = cast(tensorTy.getEncoding()); + auto sharedEnc = cast(allocTy.getEncoding()); + auto sharedVec = sharedEnc.getVec(); + + SmallVector newSizePerThread; + llvm::transform(blockEnc.getSizePerThread(), + std::back_inserter(newSizePerThread), + [&](auto size) { return std::min(size, sharedVec); }); + + if (newSizePerThread != blockEnc.getSizePerThread()) { + auto mod = loadOp->getParentOfType(); + int numWarps = ttg::TritonGPUDialect::getNumWarps(mod); + int threadsPerWarp = ttg::TritonGPUDialect::getThreadsPerWarp(mod); + auto newBlockEnc = ttg::BlockedEncodingAttr::get( + loadOp.getContext(), + tensorTy.getShape(), + newSizePerThread, + blockEnc.getOrder(), + numWarps, + threadsPerWarp, + blockEnc.getCTALayout()); + + src = convertBlockLayout(src, newBlockEnc); + if (mask) + mask = convertBlockLayout(mask, newBlockEnc); + if (other) + other = convertBlockLayout(other, newBlockEnc); + } } - tt::MemDescType allocTy = cast(alloc.getType()); SmallVector copyOffsets(allocTy.getRank(), zero); copyOffsets[0] = insertIdx; Attribute sharedMemorySpace = From ba48bf1e3298a5c740a1db41cd826507d573d2ea Mon Sep 17 00:00:00 2001 From: Gary Geng Date: Fri, 20 Sep 2024 20:45:13 +0000 Subject: [PATCH 16/18] fix typo --- .../TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp index f674324b75c0..8988f118e01a 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp @@ -90,13 +90,13 @@ static void createAsyncCopy(scf::ForOp &forOp, tt::LoadOp loadOp, Value alloc, // If the following are true... // 1) Operand A is for WGMMA and is to be loaded in registers // 2) We upcast operand A in registers before the WGMMA - // (downcasting is not yet supporting) + // (downcasting is not yet supported) // // ...then the SharedEncoding vec will be less than BlockedEncoding's - // sizePerThread, for k-dim. E.g. if shared vec is 8 and sizePerThread - // for k is 16, then AsyncCopyGlobalToLocal will generate two 8B-LDGSTS + // sizePerThread for k-dim. E.g. if shared vec is 8 and sizePerThread + // for k is 16, then AsyncCopyGlobalToLocal will generate two 8B-LDGSTS's // for each contiguous 16B global data owned by each thread. This breaks - // coalescing. + // coalescing (i.e. results 2x the minimum required transactions) // // The fix is to clip the BlockedEnc's sizePerThread using SharedEnc's vec. 
auto tensorTy = cast(src.getType()); From 693a71917d28647f4808eaed32b06fc6fd9c16a5 Mon Sep 17 00:00:00 2001 From: Gary Geng Date: Mon, 23 Sep 2024 23:35:56 +0000 Subject: [PATCH 17/18] Remove old comment --- .../TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp | 3 --- 1 file changed, 3 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp index 8988f118e01a..7fefad02bdc0 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp @@ -508,9 +508,6 @@ assignMemoryLayouts(llvm::SmallVector> loadInfo.sharedEncoding = getSharedEncoding(op, /*loadIsMMAv3=*/true).value_or(nullptr); } else if (loadInfo.isMMAv3Registers || dot) { - // if warpGroupDot, we must now have operand A in registers since - // loadIsMMAv3Shared is false from above if-check - loadInfo.sharedEncoding = getSharedEncIfAllUsersAreDotEnc(op->getResult(0)).value_or(nullptr); } From 56eefde39d44191278cb4e3cbc9f729564d26914 Mon Sep 17 00:00:00 2001 From: Gary Geng Date: Wed, 9 Oct 2024 22:19:27 +0000 Subject: [PATCH 18/18] Fix check in getMMALoadType --- .../TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp index 7fefad02bdc0..5d2ec2794df3 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp @@ -428,6 +428,9 @@ static MMALoadType getMMALoadType(Operation *loadOp) { if (auto alloc = dyn_cast(*loadOp->getUsers().begin())) { auto sharedEnc = cast(alloc.getType().getEncoding()); + if (!sharedEnc.getHasLeadingOffset()) + return MMALoadType::DoNotPipeline; + // MMA V3 case. auto newOrder = sharedEnc.getOrder(); auto ty = cast(loadOp->getResultTypes()[0]);