
Commit 59342bb

[Release] Bump version to v0.1.1 (#107)
* Remove Torch CPP backend and update execution backend options
  - Remove TorchCPPKernelAdapter and related code from JIT modules
  - Update execution backend options in jit/__init__.py, kernel.py, and adapter/__init__.py
  - Remove "torch_cpp" from supported execution backend literals
  - Simplify backend validation and remove unused torch_cpp-related code
* lint fix
* Add block sparse attention implementations for TileLang and Triton
  - Implement block sparse attention kernels for TileLang and Triton
  - Add example scripts for block sparse attention with top-k and threshold-based masking
  - Include utility functions for generating sparse attention masks
  - Demonstrate causal attention with block-level sparsity
  - Add test cases to validate sparse attention implementations against PyTorch reference
* Bump version to 0.1.1
* Refactor block sparse attention examples for improved code quality
  - Apply consistent code formatting and style in TileLang and Triton block sparse attention implementations
  - Add ruff linter ignore comment for specific line in Triton implementation
  - Improve readability by adjusting indentation and line breaks
  - Standardize sparse mask generation and test function implementations
  - Minor optimizations in test case configurations
* lint
1 parent 0b1bcc5 commit 59342bb
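
The commit message mentions both top-k and threshold-based masking, while the TileLang diff below only tests the top-k path. As a hedged illustration of the threshold path, the sketch below reuses the `get_sparse_attn_mask_from_threshold` helper exactly as it appears in the diff; the small block-score tensor and the threshold value are invented for the example.

```python
import torch


def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=False):
    # Keep only block pairs whose pooled score exceeds the threshold.
    dense_mask = x > threshold
    if use_dense_for_last_block:
        # Optionally make the last two block rows fully dense.
        dense_mask[:, :, -2:, :] = True
    # Causal masking at block granularity.
    dense_mask.tril_()
    return dense_mask


# Illustrative block-level scores: 1 batch, 1 head, a 4x4 grid of blocks.
x_ds = torch.randn(1, 1, 4, 4)
block_mask = get_sparse_attn_mask_from_threshold(x_ds, threshold=0.0)
print(block_mask)  # lower-triangular boolean mask over blocks
```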

File tree

4 files changed (+165, -133 lines)

MANIFEST.in (+1)

@@ -3,6 +3,7 @@ include CMakeLists.txt
 include requirements.txt
 include requirements-test.txt
 include requirements-dev.txt
+include tilelang/jit/adapter/cython/cython_wrapper.pyx
 recursive-include src *
 recursive-include 3rdparty *
 recursive-exclude 3rdparty/clang* *
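
The new include line ensures the Cython wrapper source is bundled into source distributions. A quick sanity check after building an sdist might look like the sketch below; the archive path is hypothetical and depends on how the sdist was built.

```python
import tarfile

# Hypothetical sdist path; adjust to whatever `python -m build --sdist` produced.
SDIST = "dist/tilelang-0.1.1.tar.gz"

with tarfile.open(SDIST) as tar:
    members = tar.getnames()

# The MANIFEST.in entry added in this commit should pull this file into the archive.
assert any(name.endswith("tilelang/jit/adapter/cython/cython_wrapper.pyx") for name in members), \
    "cython_wrapper.pyx is missing from the sdist"
print("cython_wrapper.pyx is packaged")
```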

VERSION (+1, -1)

@@ -1 +1 @@
-0.1.0
+0.1.1

examples/blocksparse_attention/block_sparse_attn_tilelang.py (+21, -15)

@@ -7,24 +7,28 @@
 import tilelang.language as T
 import torch.nn.functional as F
 
+
 def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
     bsz, num_head, downsample_len, _ = x.shape
     # N_CTX = downsample_len * BLOCK
     sparse_index = torch.topk(x, topk, dim=-1).indices
-    dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False, dtype=torch.bool, device=x.device)
+    dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len],
+                            False,
+                            dtype=torch.bool,
+                            device=x.device)
     dense_mask.scatter_(-1, sparse_index, True)
     if use_dense_for_last_block:
-        dense_mask[:, :,-2:,:] = True
+        dense_mask[:, :, -2:, :] = True
     dense_mask.tril_()
-    return dense_mask
+    return dense_mask
 
 
 def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=False):
-    dense_mask = x > threshold
+    dense_mask = x > threshold
     if use_dense_for_last_block:
-        dense_mask[:, :,-2:,:] = True
+        dense_mask[:, :, -2:, :] = True
     dense_mask.tril_()
-    return dense_mask
+    return dense_mask
 
 
 def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal):

@@ -136,7 +140,7 @@ def main(
                 scores_sum = T.alloc_fragment([block_M], accum_dtype)
                 logsum = T.alloc_fragment([block_M], accum_dtype)
                 block_mask = T.alloc_local([downsample_len], block_mask_dtype)
-
+
                 T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
                 T.fill(acc_o, 0)
                 T.fill(logsum, 0)

@@ -165,6 +169,7 @@ def main(
 
     return kernel_func(block_M, block_N, num_stages, threads)
 
+
 def test_topk_sparse_attention():
     # Config
     BATCH, N_HEADS, SEQ_LEN, D_HEAD = 1, 1, 256, 64

@@ -177,13 +182,15 @@ def test_topk_sparse_attention():
     k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
     v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16)
 
-    sm_scale = 1.0 / (D_HEAD ** 0.5)
+    sm_scale = 1.0 / (D_HEAD**0.5)
 
     # Create sparse mask (downsampled to block level)
     downsample_factor = BLOCK
     downsample_len = math.ceil(SEQ_LEN / downsample_factor)
-    x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device='cuda', dtype=torch.bfloat16)
-    x_ds[:,:,:,0] = 100
+    x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len],
+                       device='cuda',
+                       dtype=torch.bfloat16)
+    x_ds[:, :, :, 0] = 100
     block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)
 
     # Run Triton kernel

@@ -194,25 +201,24 @@ def test_topk_sparse_attention():
 
     # Compute reference
     # Expand block mask to full attention matrix
-    full_mask = torch.kron(block_mask.float(),
-                           torch.ones(BLOCK, BLOCK, device='cuda'))
+    full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device='cuda'))
     full_mask = full_mask[..., :SEQ_LEN, :SEQ_LEN].bool()
     full_mask = full_mask & torch.tril(torch.ones_like(full_mask))  # Apply causal
-
+
     # PyTorch reference implementation
     attn = torch.einsum('bhsd,bhtd->bhst', q, k) * sm_scale
     attn = attn.masked_fill(~full_mask, float('-inf'))
     attn = F.softmax(attn, dim=-1)
     ref_output = torch.einsum('bhst,bhtd->bhsd', attn, v)
-
+
     print("ref_output", ref_output)
     print("tilelang_output", tilelang_output)
 
-
     # Verify accuracy
     assert torch.allclose(tilelang_output, ref_output, atol=1e-2, rtol=1e-2), \
         "TileLang output doesn't match reference"
     print("Pass topk sparse attention test with qlen == klen")
 
+
 if __name__ == "__main__":
     test_topk_sparse_attention()
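
For context on what the test above validates, the sketch below isolates just the mask pipeline: top-k block selection followed by `torch.kron` expansion to a token-level causal mask. It mirrors the helper and the reference check from the diff but runs on CPU and omits the TileLang kernel; the shapes follow the test configuration, while the TOPK value here is illustrative.

```python
import math

import torch

# Shapes taken from the test configuration above; the TOPK/BLOCK pairing is illustrative.
BATCH, N_HEADS, SEQ_LEN, BLOCK, TOPK = 1, 1, 256, 64, 2

downsample_len = math.ceil(SEQ_LEN / BLOCK)

# Block-level scores; boosting the first column guarantees each row keeps at least one block.
x_ds = torch.randn(BATCH, N_HEADS, downsample_len, downsample_len)
x_ds[:, :, :, 0] = 100

# Top-k selection at block granularity (same idea as get_sparse_attn_mask_from_topk).
sparse_index = torch.topk(x_ds, TOPK, dim=-1).indices
block_mask = torch.zeros_like(x_ds, dtype=torch.bool)
block_mask.scatter_(-1, sparse_index, True)
block_mask.tril_()

# Expand each block entry to a BLOCK x BLOCK tile, then re-apply the causal mask,
# exactly as the PyTorch reference in the test does.
full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK))
full_mask = full_mask[..., :SEQ_LEN, :SEQ_LEN].bool()
full_mask = full_mask & torch.tril(torch.ones_like(full_mask))

print(full_mask.shape)  # torch.Size([1, 1, 256, 256])
```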
