
Commit 1a74878

xuzhao9 authored and facebook-github-bot committed
Add tilelang install and flash_attention impl (#180)
Summary:
To install tilelang:
```
python install.py --tile
```
To test tilelang:
```
python run.py --op flash_attention --only tile,triton_tutorial_flash_v2 --metrics tflops

(Batch, Heads, SeqLen, Dhead)      tile-tflops    triton_tutorial_flash_v2-tflops
-------------------------------  -------------  ---------------------------------
(4, 48, 128, 64)                       40.7874                            57.0654
(4, 48, 256, 64)                       97.2592                           132.104
(4, 48, 512, 64)                      159.593                            256.794
(4, 48, 1024, 64)                     223.976                            331.811
(4, 48, 2048, 64)                     263.711                            345.755
(4, 48, 4096, 64)                     277.59                             354.224
(4, 48, 8192, 64)                     288.024                            350.316
(4, 48, 16384, 64)                    292.826                            351.27
average                               205.471                            272.418
```

Pull Request resolved: #180

Reviewed By: FindHao

Differential Revision: D72250350

Pulled By: xuzhao9

fbshipit-source-id: 42947fd43d3e3fadce7adb70141c061787ee9e48
1 parent 4a78815 commit 1a74878

File tree

9 files changed: +343 -23 lines changed

.ci/tritonbench/test-install.sh

-14
This file was deleted.

.gitignore

+1
@@ -19,3 +19,4 @@ torch_compile_debug/
 build/
 /*.csv
 *.hatchet
+autotuner.log

README.md

+2-1
@@ -53,11 +53,12 @@ We depend on the following projects as a source of customized Triton or CUTLASS
 * (CUDA, HIP) [kernels](https://github.com/triton-lang/kernels)
 * (CUDA, HIP) [generative-recommenders](https://github.com/facebookresearch/generative-recommenders)
 * (CUDA, HIP) [Liger-Kernel](https://github.com/linkedin/Liger-Kernel)
+* (CUDA, HIP) [tilelang](https://github.com/tile-ai/tilelang)
 * (CUDA) [xformers](https://github.com/facebookresearch/xformers)
 * (CUDA) [flash-attention](https://github.com/Dao-AILab/flash-attention)
 * (CUDA) [FBGEMM](https://github.com/pytorch/FBGEMM)
 * (CUDA) [ThunderKittens](https://github.com/HazyResearch/ThunderKittens)
-* (CUDA) [cutlass-kernels](https://github.com/ColfaxResearch/cutlass-kernels)
+
 
 
 ## License

docker/tritonbench-nightly.dockerfile

-4
@@ -60,10 +60,6 @@ RUN cd /workspace/tritonbench && \
 RUN cd /workspace/tritonbench && \
     bash .ci/tritonbench/install.sh
 
-# Test Tritonbench
-RUN cd /workspace/tritonbench && \
-    bash .ci/tritonbench/test-install.sh
-
 # Remove NVIDIA driver library - they are supposed to be mapped at runtime
 RUN sudo apt-get purge -y libnvidia-compute-550

install.py

+7-4
@@ -125,10 +125,10 @@ def setup_hip(args: argparse.Namespace):
     parser.add_argument("--tk", action="store_true", help="Install ThunderKittens")
     parser.add_argument("--liger", action="store_true", help="Install Liger-kernel")
     parser.add_argument("--xformers", action="store_true", help="Install xformers")
+    parser.add_argument("--tile", action="store_true", help="install tile lang")
     parser.add_argument(
         "--all", action="store_true", help="Install all custom kernel repos"
     )
-    parser.add_argument("--test", action="store_true", help="Run tests")
     args = parser.parse_args()
 
     if args.all and is_hip():
@@ -166,6 +166,7 @@ def setup_hip(args: argparse.Namespace):
     if args.fbgemm or args.fbgemm_all or args.all:
         logger.info("[tritonbench] installing FBGEMM...")
         install_fbgemm(genai=(not args.fbgemm_all))
+        test_fbgemm()
     if args.fa2 or args.all:
         logger.info("[tritonbench] installing fa2 from source...")
         install_fa2(compile=True)
@@ -182,6 +183,11 @@ def setup_hip(args: argparse.Namespace):
         from tools.tk.install import install_tk
 
         install_tk()
+    if args.tile:
+        logger.info("[tritonbench] installing tilelang...")
+        from tools.tilelang.install import install_tile
+
+        install_tile()
     if args.liger or args.all:
         logger.info("[tritonbench] installing liger-kernels...")
         install_liger()
@@ -191,6 +197,3 @@ def setup_hip(args: argparse.Namespace):
 
         install_xformers()
     logger.info("[tritonbench] installation complete!")
-    # run tests to check installation
-    if args.test:
-        test_fbgemm()
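
With the standalone `--test` flag removed, install checks now run inline (e.g. `test_fbgemm()` right after `install_fbgemm()`), and the tilelang installer verifies itself the same way. A minimal manual re-check after `python install.py --tile`, mirroring the import probe used by the new installer below:

```
# Re-runs the same probe that tools/tilelang/install.py uses to verify the
# install; exits non-zero if the tilelang package cannot be imported.
import subprocess
import sys

subprocess.check_call([sys.executable, "-c", "import tilelang"])
```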

tools/tilelang/install.py

+23
@@ -0,0 +1,23 @@
+import os
+import subprocess
+import sys
+
+REQUIREMENTS_FILE = os.path.join(
+    os.path.dirname(os.path.abspath(__file__)), "requirements.txt"
+)
+
+
+def install_requirements(requirements_txt: str):
+    # ignore dependencies to bypass reinstalling pytorch stable version
+    cmd = ["pip", "install", "-r", requirements_txt, "--no-deps"]
+    subprocess.check_call(cmd)
+
+
+def check_install():
+    cmd = [sys.executable, "-c", "import tilelang"]
+    subprocess.check_call(cmd)
+
+
+def install_tile():
+    install_requirements(REQUIREMENTS_FILE)
+    check_install()
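
For reference, the new module can also be driven directly, outside `install.py`; a minimal sketch, assuming it is run from the tritonbench repository root so the `tools.tilelang` package is importable:

```
# Installs the pinned packages from tools/tilelang/requirements.txt with
# --no-deps (so the existing PyTorch install is left untouched), then checks
# that `import tilelang` succeeds in a subprocess.
from tools.tilelang.install import install_tile

install_tile()
```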

tools/tilelang/requirements.txt

+3
@@ -0,0 +1,3 @@
+tilelang==0.1.3
+Cython
+decorator

tritonbench/operators/flash_attention/operator.py

+40
@@ -103,6 +103,15 @@
 except (ImportError, IOError, AttributeError, TypeError):
     HAS_XFORMERS = False
 
+try:
+    import tilelang
+
+    from .tilelang_mha import tilelang_mha
+
+    HAS_TILELANG = True
+except (ImportError, IOError, AttributeError, TypeError):
+    HAS_TILELANG = False
+
 # [Optional] colfax cutlass backend
 try:
     if not hasattr(torch.version, "git_version"):
@@ -467,6 +476,37 @@ def _inner():
 
         return _inner
 
+    @register_benchmark(enabled=HAS_TILELANG)
+    def tile(self, q, k, v):
+        # [B, H, S, D] -> [B, S, H, D]
+        q = q.transpose(1, 2).contiguous()
+        k = k.transpose(1, 2).contiguous()
+        v = v.transpose(1, 2).contiguous()
+        best_config = tilelang_mha(
+            self.BATCH,
+            self.H,
+            self.N_CTX,
+            self.D_HEAD,
+            self.causal,
+            self.dtype,
+            tune=True,
+        )[1]
+        func = tilelang_mha(
+            self.BATCH,
+            self.H,
+            self.N_CTX,
+            self.D_HEAD,
+            self.causal,
+            self.dtype,
+        )(*best_config)
+        jit_kernel = tilelang.compile(func, out_idx=[3])
+
+        def _inner():
+            o = jit_kernel(q, k, v)
+            return o
+
+        return _inner
+
     @register_benchmark(enabled=False, label=f"cudnn")
     def cudnn(self, q, k, v):
         os.environ["TORCH_CUDNN_SDPA_ENABLED"] = "1"
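
The new `tile` backend autotunes `tilelang_mha` once, instantiates the kernel with the best config, and JIT-compiles it so that only the kernel call is timed. The sketch below mirrors those calls outside the harness; the shapes, `causal=False`, `torch.float16`, and the absolute import path for `tilelang_mha` are illustrative assumptions rather than part of this commit:

```
import torch
import tilelang

# operator.py imports this module relatively (`from .tilelang_mha import tilelang_mha`).
from tritonbench.operators.flash_attention.tilelang_mha import tilelang_mha

B, H, S, D = 4, 48, 2048, 64  # one of the benchmarked configs above
q, k, v = (
    torch.randn(B, S, H, D, dtype=torch.float16, device="cuda") for _ in range(3)
)

# Autotune once, re-instantiate with the best config, then JIT-compile.
best_config = tilelang_mha(B, H, S, D, False, torch.float16, tune=True)[1]
func = tilelang_mha(B, H, S, D, False, torch.float16)(*best_config)
jit_kernel = tilelang.compile(func, out_idx=[3])  # same out_idx as the benchmark

o = jit_kernel(q, k, v)  # attention output in [B, S, H, D] layout
```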
