
Commit d8a06c0

[Example] Implement NSA Decode tilelang examples (#168)
* [Refactor] Update BitBLAS Benchmark with TileLang Carver Imports and Roller Hints Generation
  - Replace BitBLAS imports with TileLang Carver imports in benchmark_matmul.py
  - Modify roller hints generation using new TileLang Carver template and utility functions
  - Update get_roller_hints_from_func to handle None cases and improve return logic
  - Adjust DefaultPolicy to handle different codegen dictionary formats

* [Refactor] Update Thread Binding and Import Statements in TileLang Kernels
  - Replace T.thread_binding() with T.get_thread_binding() across multiple kernel test files
  - Update import statements for MMA layout and macro generator in dequantize GEMM and FP8 examples
  - Move map_torch_type utility function to tilelang.utils.tensor
  - Remove unnecessary imports and improve code organization

* Refactor Native Sparse Attention Example with Enhanced Triton Kernel
  - Update parallel_nsa_fwd_kernel to support more flexible sparse attention computation
  - Add support for block counts and offsets in the Triton kernel
  - Modify kernel grid and computation logic for improved performance
  - Update example script to use naive_nsa_simple reference implementation
  - Improve type hints and kernel configuration

* Add Native Sparse Attention Examples with TileLang and Triton Implementations
  - Introduce new example scripts for native sparse attention:
    * example_tilelang_nsa_fwd.py: Forward pass implementation using TileLang
    * example_tilelang_nsa_decode.py: Decoding-specific sparse attention implementation
    * example_triton_nsa_fwd.py: Triton-based sparse attention forward pass
  - Update reference.py with naive implementations for sparse attention
  - Support different sparse attention scenarios including forward pass and inference
  - Add comprehensive testing and validation against reference implementations

* lint fix
1 parent 5a63e65 commit d8a06c0

5 files changed: +499 -181 lines

examples/native_sparse_attention/example_tilelang_nsa.py (-181 lines; this file was deleted)
examples/native_sparse_attention/example_tilelang_nsa_decode.py (new file, +174 lines)

@@ -0,0 +1,174 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ruff: noqa
import torch
from reference import naive_nsa_simple_inference
import tilelang
from tilelang import language as T
import tilelang.testing

tilelang.testing.set_random_seed(42)


def native_sparse_attention(
        batch,
        heads,
        seq_len,  # Length of K/V sequences (context window size)
        dim,  # Embedding dimension per head
        scale=None,
        block_size=64,  # Tile size for attention computation
        groups=1,  # Grouped query attention (GQA) groups
        selected_blocks=16  # Number of blocks to select per attention head
):
    if scale is None:
        scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
    head_kv = heads // groups
    # Modified shapes for inference (q has seq_len=1)
    q_shape = [batch, 1, heads, dim]  # Changed seq_len to 1
    kv_shape = [batch, seq_len, head_kv, dim]
    block_indices_shape = [batch, 1, head_kv, selected_blocks]  # Changed seq_len to 1
    block_indices_dtype = "int32"
    dtype = "float16"
    accum_dtype = "float"
    block_S = block_size
    block_T = min(128, tilelang.math.next_power_of_2(dim))

    NK = tilelang.cdiv(dim, block_T)
    NV = tilelang.cdiv(dim, block_T)
    assert NK == 1, "The key dimension can not be larger than 256"

    S = selected_blocks
    G = groups
    BS = block_S
    BK = BV = block_T
    num_stages = 0
    threads = 32

    @T.prim_func
    def native_sparse_attention(
            Q: T.Buffer(q_shape, dtype),  # [batch, 1, heads, dim]
            K: T.Buffer(kv_shape, dtype),  # [batch, seq_len, head_kv, dim]
            V: T.Buffer(kv_shape, dtype),  # Same shape as K
            BlockIndices: T.Buffer(block_indices_shape,
                                   block_indices_dtype),  # Selected block indices
            Output: T.Buffer(q_shape, dtype),  # Output attention tensor
    ):
        with T.Kernel(1, NV, batch * head_kv, threads=threads) as (bx, by, bz):
            # Shared memory allocations for tile storage
            Q_shared = T.alloc_shared([G, BK], dtype)  # Current query block
            K_shared = T.alloc_shared([BS, BK], dtype)  # Current key block
            V_shared = T.alloc_shared([BS, BV], dtype)  # Current value block
            O_shared = T.alloc_shared([G, BV], dtype)  # Output accumulator

            # Attention computation buffers
            acc_s = T.alloc_fragment([G, BS], accum_dtype)  # QK^T scores
            acc_s_cast = T.alloc_fragment([G, BS], dtype)  # Casted scores for softmax
            acc_o = T.alloc_fragment([G, BV], accum_dtype)  # Output accumulator
            scores_max = T.alloc_fragment([G], accum_dtype)
            scores_max_prev = T.alloc_fragment([G], accum_dtype)
            scores_scale = T.alloc_fragment([G], accum_dtype)
            scores_sum = T.alloc_fragment([G], accum_dtype)
            logsum = T.alloc_fragment([G], accum_dtype)

            i_v, i_bh = by, bz
            i_b, i_h = i_bh // head_kv, i_bh % head_kv

            NS = S
            # Copy Q for the single position
            T.copy(Q[i_b, 0, i_h * G:(i_h + 1) * G, :], Q_shared)  # Changed i_t to 0

            T.fill(acc_o, 0)
            T.fill(logsum, 0)
            T.fill(scores_max, -T.infinity(accum_dtype))

            # Main attention computation loop over selected blocks
            for i in T.Pipelined(NS, num_stages=num_stages):
                i_s = BlockIndices[i_b, 0, i_h, i] * BS  # Get block offset
                if i_s >= 0:  # Skip invalid/padding blocks
                    # Load current key block to shared memory
                    T.copy(K[i_b, i_s:i_s + BS, i_h, :], K_shared)

                    # Compute QK^T attention scores
                    T.clear(acc_s)
                    T.gemm(
                        Q_shared,
                        K_shared,
                        acc_s,
                        transpose_B=True,
                        policy=T.GemmWarpPolicy.FullRow)

                    # Online softmax with numerical stability
                    # 1. Compute max for scaling
                    # 2. Compute exponentials and sum
                    # 3. Maintain running logsum for normalization
                    T.copy(scores_max, scores_max_prev)
                    T.fill(scores_max, -T.infinity(accum_dtype))
                    T.reduce_max(acc_s, scores_max, dim=1, clear=True)

                    for i in T.Parallel(G):
                        scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
                    for i, j in T.Parallel(G, BS):
                        acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
                    T.reduce_sum(acc_s, scores_sum, dim=1)
                    for i in T.Parallel(G):
                        logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
                    T.copy(acc_s, acc_s_cast)

                    # Accumulate attention-weighted values
                    T.copy(V[i_b, i_s:i_s + BS, i_h, i_v * BV:(i_v + 1) * BV], V_shared)
                    T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)

            # Final normalization and output
            for i, j in T.Parallel(G, BV):
                acc_o[i, j] /= logsum[i]  # Normalize by logsum
            T.copy(acc_o, O_shared)
            T.copy(O_shared, Output[i_b, 0, i_h * G:(i_h + 1) * G,
                                    i_v * BV:(i_v + 1) * BV])  # Changed i_t to 0

    return native_sparse_attention


if __name__ == "__main__":
    B, SEQ_LEN, H, HQ, D, S, block_size, dtype = 2, 64, 1, 16, 16, 1, 32, torch.float16
    groups = HQ // H
    SEQ_LEN_Q = 1
    program = native_sparse_attention(
        batch=B,
        heads=HQ,
        seq_len=SEQ_LEN,
        dim=D,
        block_size=block_size,
        groups=HQ // H,
        selected_blocks=S,
    )

    kernel = tilelang.compile(program, out_idx=-1)
    Q = torch.randn((B, SEQ_LEN_Q, HQ, D), dtype=dtype, device='cuda').requires_grad_(True)
    K = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device='cuda').requires_grad_(True)
    V = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device='cuda').requires_grad_(True)

    mask = torch.randint(0, 2, (B, SEQ_LEN, groups), device='cuda')
    DO = torch.randn((B, SEQ_LEN_Q, HQ, D), dtype=dtype, device='cuda')

    block_indices = torch.full((B, SEQ_LEN_Q, H, S), SEQ_LEN, dtype=torch.long, device='cuda')
    for b in range(B):
        for t in range(SEQ_LEN_Q):
            for h in range(H):
                i_i = torch.randperm(max(1, (t // block_size)))[:S]
                block_indices[b, t, h, :len(i_i)] = i_i
    block_indices = block_indices.sort(-1)[0]
    block_counts = torch.randint(1, S + 1, (B, SEQ_LEN_Q, H), device='cuda')

    out = kernel(Q, K, V, block_indices.to(torch.int32))

    ref = naive_nsa_simple_inference(
        q=Q,
        k=K,
        v=V,
        block_indices=block_indices,
        block_counts=block_counts,
        block_size=block_size,
    )
    print("out", out)
    print("ref", ref)
    torch.testing.assert_close(ref, out, atol=1e-2, rtol=1e-2)
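
The driver validates the kernel against naive_nsa_simple_inference from reference.py, which is not part of this diff. As a rough mental model only, the reference presumably gathers the selected K/V blocks per KV head and runs ordinary softmax attention for the query heads in that group. The sketch below follows the call site above; the helper name, internals, and per-head loop structure are assumptions, not the actual reference.py code.

import torch


def naive_nsa_inference_sketch(q, k, v, block_indices, block_counts, block_size):
    # Hypothetical dense reference; mirrors the call-site shapes, not reference.py itself.
    # q: [B, 1, HQ, D]; k, v: [B, T, H, D] with HQ = H * G (GQA: G query heads per KV head)
    # block_indices: [B, 1, H, S] block ids per KV head; block_counts: [B, 1, H] valid entries
    B, _, HQ, D = q.shape
    H = k.shape[2]
    G = HQ // H
    scale = D ** -0.5
    out = torch.zeros_like(q)
    for b in range(B):
        for h in range(H):
            cnt = int(block_counts[b, 0, h])
            # Gather the token rows of the selected K/V blocks for this KV head
            rows = torch.cat([
                torch.arange(s * block_size, (s + 1) * block_size, device=q.device)
                for s in block_indices[b, 0, h, :cnt].tolist()
            ])
            k_sel = k[b, rows, h].float()               # [cnt * block_size, D]
            v_sel = v[b, rows, h].float()
            q_grp = q[b, 0, h * G:(h + 1) * G].float()  # [G, D] query heads sharing this KV head
            attn = torch.softmax((q_grp @ k_sel.T) * scale, dim=-1)
            out[b, 0, h * G:(h + 1) * G] = (attn @ v_sel).to(q.dtype)
    return out

The atol/rtol of 1e-2 in torch.testing.assert_close is a reasonable tolerance here: the kernel accumulates in float32 (accum_dtype) but stores float16 outputs, so small discrepancies against a float reference are expected.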
