|
| 1 | +# Copyright (c) Microsoft Corporation. |
| 2 | +# Licensed under the MIT License. |
| 3 | +# ruff: noqa |
| 4 | +import torch |
| 5 | +from reference import naive_nsa_simple_inference |
| 6 | +import tilelang |
| 7 | +from tilelang import language as T |
| 8 | +import tilelang.testing |
| 9 | + |
| 10 | +tilelang.testing.set_random_seed(42) |
| 11 | + |
| 12 | + |
| 13 | +def native_sparse_attention( |
| 14 | + batch, |
| 15 | + heads, |
| 16 | + seq_len, # Length of K/V sequences (context window size) |
| 17 | + dim, # Embedding dimension per head |
| 18 | + scale=None, |
| 19 | + block_size=64, # Tile size for attention computation |
| 20 | + groups=1, # Grouped query attention (GQA) groups |
| 21 | + selected_blocks=16 # Number of blocks to select per attention head |
| 22 | +): |
| 23 | + if scale is None: |
| 24 | + scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) |
| 25 | + head_kv = heads // groups |
| 26 | + # Modified shapes for inference (q has seq_len=1) |
| 27 | + q_shape = [batch, 1, heads, dim] # Changed seq_len to 1 |
| 28 | + kv_shape = [batch, seq_len, head_kv, dim] |
| 29 | + block_indices_shape = [batch, 1, head_kv, selected_blocks] # Changed seq_len to 1 |
| 30 | + block_indices_dtype = "int32" |
| 31 | + dtype = "float16" |
| 32 | + accum_dtype = "float" |
| 33 | + block_S = block_size |
| 34 | + block_T = min(128, tilelang.math.next_power_of_2(dim)) |
| 35 | + |
| 36 | + NK = tilelang.cdiv(dim, block_T) |
| 37 | + NV = tilelang.cdiv(dim, block_T) |
| 38 | + assert NK == 1, "The key dimension can not be larger than 256" |
| 39 | + |
| 40 | + S = selected_blocks |
| 41 | + G = groups |
| 42 | + BS = block_S |
| 43 | + BK = BV = block_T |
| 44 | + num_stages = 0 |
| 45 | + threads = 32 |
| 46 | + |
| 47 | + @T.prim_func |
| 48 | + def native_sparse_attention( |
| 49 | + Q: T.Buffer(q_shape, dtype), # [batch, 1, heads, dim] |
| 50 | + K: T.Buffer(kv_shape, dtype), # [batch, seq_len, head_kv, dim] |
| 51 | + V: T.Buffer(kv_shape, dtype), # Same shape as K |
| 52 | + BlockIndices: T.Buffer(block_indices_shape, |
| 53 | + block_indices_dtype), # Selected block indices |
| 54 | + Output: T.Buffer(q_shape, dtype), # Output attention tensor |
| 55 | + ): |
| 56 | + with T.Kernel(1, NV, batch * head_kv, threads=threads) as (bx, by, bz): |
| 57 | + # Shared memory allocations for tile storage |
| 58 | + Q_shared = T.alloc_shared([G, BK], dtype) # Current query block |
| 59 | + K_shared = T.alloc_shared([BS, BK], dtype) # Current key block |
| 60 | + V_shared = T.alloc_shared([BS, BV], dtype) # Current value block |
| 61 | + O_shared = T.alloc_shared([G, BV], dtype) # Output accumulator |
| 62 | + |
| 63 | + # Attention computation buffers |
| 64 | + acc_s = T.alloc_fragment([G, BS], accum_dtype) # QK^T scores |
| 65 | + acc_s_cast = T.alloc_fragment([G, BS], dtype) # Casted scores for softmax |
| 66 | + acc_o = T.alloc_fragment([G, BV], accum_dtype) # Output accumulator |
| 67 | + scores_max = T.alloc_fragment([G], accum_dtype) |
| 68 | + scores_max_prev = T.alloc_fragment([G], accum_dtype) |
| 69 | + scores_scale = T.alloc_fragment([G], accum_dtype) |
| 70 | + scores_sum = T.alloc_fragment([G], accum_dtype) |
| 71 | + logsum = T.alloc_fragment([G], accum_dtype) |
| 72 | + |
| 73 | + i_v, i_bh = by, bz |
| 74 | + i_b, i_h = i_bh // head_kv, i_bh % head_kv |
| 75 | + |
| 76 | + NS = S |
| 77 | + # Copy Q for the single position |
| 78 | + T.copy(Q[i_b, 0, i_h * G:(i_h + 1) * G, :], Q_shared) # Changed i_t to 0 |
| 79 | + |
| 80 | + T.fill(acc_o, 0) |
| 81 | + T.fill(logsum, 0) |
| 82 | + T.fill(scores_max, -T.infinity(accum_dtype)) |
| 83 | + |
| 84 | + # Main attention computation loop over selected blocks |
| 85 | + for i in T.Pipelined(NS, num_stages=num_stages): |
| 86 | + i_s = BlockIndices[i_b, 0, i_h, i] * BS # Get block offset |
| 87 | + if i_s >= 0: # Skip invalid/padding blocks |
| 88 | + # Load current key block to shared memory |
| 89 | + T.copy(K[i_b, i_s:i_s + BS, i_h, :], K_shared) |
| 90 | + |
| 91 | + # Compute QK^T attention scores |
| 92 | + T.clear(acc_s) |
| 93 | + T.gemm( |
| 94 | + Q_shared, |
| 95 | + K_shared, |
| 96 | + acc_s, |
| 97 | + transpose_B=True, |
| 98 | + policy=T.GemmWarpPolicy.FullRow) |
| 99 | + |
| 100 | + # Online softmax with numerical stability |
| 101 | + # 1. Compute max for scaling |
| 102 | + # 2. Compute exponentials and sum |
| 103 | + # 3. Maintain running logsum for normalization |
| 104 | + T.copy(scores_max, scores_max_prev) |
| 105 | + T.fill(scores_max, -T.infinity(accum_dtype)) |
| 106 | + T.reduce_max(acc_s, scores_max, dim=1, clear=True) |
| 107 | + |
| 108 | + for i in T.Parallel(G): |
| 109 | + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) |
| 110 | + for i, j in T.Parallel(G, BS): |
| 111 | + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) |
| 112 | + T.reduce_sum(acc_s, scores_sum, dim=1) |
| 113 | + for i in T.Parallel(G): |
| 114 | + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] |
| 115 | + T.copy(acc_s, acc_s_cast) |
| 116 | + |
| 117 | + # Accumulate attention-weighted values |
| 118 | + T.copy(V[i_b, i_s:i_s + BS, i_h, i_v * BV:(i_v + 1) * BV], V_shared) |
| 119 | + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) |
| 120 | + |
| 121 | + # Final normalization and output |
| 122 | + for i, j in T.Parallel(G, BV): |
| 123 | + acc_o[i, j] /= logsum[i] # Normalize by logsum |
| 124 | + T.copy(acc_o, O_shared) |
| 125 | + T.copy(O_shared, Output[i_b, 0, i_h * G:(i_h + 1) * G, |
| 126 | + i_v * BV:(i_v + 1) * BV]) # Changed i_t to 0 |
| 127 | + |
| 128 | + return native_sparse_attention |
| 129 | + |
| 130 | + |
| 131 | +if __name__ == "__main__": |
| 132 | + B, SEQ_LEN, H, HQ, D, S, block_size, dtype = 2, 64, 1, 16, 16, 1, 32, torch.float16 |
| 133 | + groups = HQ // H |
| 134 | + SEQ_LEN_Q = 1 |
| 135 | + program = native_sparse_attention( |
| 136 | + batch=B, |
| 137 | + heads=HQ, |
| 138 | + seq_len=SEQ_LEN, |
| 139 | + dim=D, |
| 140 | + block_size=block_size, |
| 141 | + groups=HQ // H, |
| 142 | + selected_blocks=S, |
| 143 | + ) |
| 144 | + |
| 145 | + kernel = tilelang.compile(program, out_idx=-1) |
| 146 | + Q = torch.randn((B, SEQ_LEN_Q, HQ, D), dtype=dtype, device='cuda').requires_grad_(True) |
| 147 | + K = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device='cuda').requires_grad_(True) |
| 148 | + V = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device='cuda').requires_grad_(True) |
| 149 | + |
| 150 | + mask = torch.randint(0, 2, (B, SEQ_LEN, groups), device='cuda') |
| 151 | + DO = torch.randn((B, SEQ_LEN_Q, HQ, D), dtype=dtype, device='cuda') |
| 152 | + |
| 153 | + block_indices = torch.full((B, SEQ_LEN_Q, H, S), SEQ_LEN, dtype=torch.long, device='cuda') |
| 154 | + for b in range(B): |
| 155 | + for t in range(SEQ_LEN_Q): |
| 156 | + for h in range(H): |
| 157 | + i_i = torch.randperm(max(1, (t // block_size)))[:S] |
| 158 | + block_indices[b, t, h, :len(i_i)] = i_i |
| 159 | + block_indices = block_indices.sort(-1)[0] |
| 160 | + block_counts = torch.randint(1, S + 1, (B, SEQ_LEN_Q, H), device='cuda') |
| 161 | + |
| 162 | + out = kernel(Q, K, V, block_indices.to(torch.int32)) |
| 163 | + |
| 164 | + ref = naive_nsa_simple_inference( |
| 165 | + q=Q, |
| 166 | + k=K, |
| 167 | + v=V, |
| 168 | + block_indices=block_indices, |
| 169 | + block_counts=block_counts, |
| 170 | + block_size=block_size, |
| 171 | + ) |
| 172 | + print("out", out) |
| 173 | + print("ref", ref) |
| 174 | + torch.testing.assert_close(ref, out, atol=1e-2, rtol=1e-2) |
0 commit comments