-
Notifications
You must be signed in to change notification settings - Fork 52
/
Copy pathexample_mla_decode.py
297 lines (269 loc) · 14.1 KB
/
example_mla_decode.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
import torch
import torch.nn.functional as F
import tilelang
from tilelang.autotuner import *
import tilelang.language as T
from einops import rearrange, einsum
import argparse
def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split):
scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504 # log2(e)
dtype = "float16"
accum_dtype = "float"
kv_group_num = heads // kv_head_num
VALID_BLOCK_H = min(block_H, kv_group_num)
assert kv_head_num == 1, "kv_head_num must be 1"
@T.macro
def flash_attn(
Q: T.Buffer([batch, heads, dim], dtype),
Q_pe: T.Buffer([batch, heads, pe_dim], dtype),
KV: T.Buffer([batch, seqlen_kv, kv_head_num, dim], dtype),
K_pe: T.Buffer([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
Output: T.Buffer([batch, heads, dim], dtype),
):
with T.Kernel(batch, heads // min(block_H, kv_group_num), threads=256) as (bx, by):
Q_shared = T.alloc_shared([block_H, dim], dtype)
S_shared = T.alloc_shared([block_H, block_N], dtype)
Q_pe_shared = T.alloc_shared([block_H, pe_dim], dtype)
KV_shared = T.alloc_shared([block_N, dim], dtype)
K_pe_shared = T.alloc_shared([block_N, pe_dim], dtype)
O_shared = T.alloc_shared([block_H, dim], dtype)
acc_s = T.alloc_fragment([block_H, block_N], accum_dtype)
acc_o = T.alloc_fragment([block_H, dim], accum_dtype)
scores_max = T.alloc_fragment([block_H], accum_dtype)
scores_max_prev = T.alloc_fragment([block_H], accum_dtype)
scores_scale = T.alloc_fragment([block_H], accum_dtype)
scores_sum = T.alloc_fragment([block_H], accum_dtype)
logsum = T.alloc_fragment([block_H], accum_dtype)
cur_kv_head = by // (kv_group_num // block_H)
T.use_swizzle(10)
T.annotate_layout({
O_shared: tilelang.layout.make_swizzled_layout(O_shared),
})
T.copy(Q[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_shared)
T.copy(Q_pe[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_pe_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = T.ceildiv(seqlen_kv, block_N)
for k in T.Pipelined(loop_range, num_stages=2):
T.copy(KV[bx, k * block_N:(k + 1) * block_N, cur_kv_head, :], KV_shared)
T.copy(K_pe[bx, k * block_N:(k + 1) * block_N, cur_kv_head, :], K_pe_shared)
T.clear(acc_s)
T.gemm(
Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
T.gemm(
Q_pe_shared,
K_pe_shared,
acc_s,
transpose_B=True,
policy=T.GemmWarpPolicy.FullCol)
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_H):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_H, block_N):
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
T.copy(acc_s, S_shared)
for i in T.Parallel(block_H):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
for i, j in T.Parallel(block_H, dim):
acc_o[i, j] *= scores_scale[i]
T.gemm(S_shared, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol)
for i, j in T.Parallel(block_H, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :])
@T.macro
def flash_attn_split(
Q: T.Buffer([batch, heads, dim], dtype),
Q_pe: T.Buffer([batch, heads, pe_dim], dtype),
KV: T.Buffer([batch, seqlen_kv, kv_head_num, dim], dtype),
K_pe: T.Buffer([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
glse: T.Buffer([batch, heads, num_split], dtype),
Output_partial: T.Buffer([batch, heads, num_split, dim], dtype),
):
with T.Kernel(
batch, heads // min(block_H, kv_group_num), num_split, threads=256) as (bx, by, bz):
Q_shared = T.alloc_shared([block_H, dim], dtype)
S_shared = T.alloc_shared([block_H, block_N], dtype)
Q_pe_shared = T.alloc_shared([block_H, pe_dim], dtype)
KV_shared = T.alloc_shared([block_N, dim], dtype)
K_pe_shared = T.alloc_shared([block_N, pe_dim], dtype)
O_shared = T.alloc_shared([block_H, dim], dtype)
acc_s = T.alloc_fragment([block_H, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_H, block_N], dtype)
acc_o = T.alloc_fragment([block_H, dim], accum_dtype)
scores_max = T.alloc_fragment([block_H], accum_dtype)
scores_max_prev = T.alloc_fragment([block_H], accum_dtype)
scores_scale = T.alloc_fragment([block_H], accum_dtype)
scores_sum = T.alloc_fragment([block_H], accum_dtype)
logsum = T.alloc_fragment([block_H], accum_dtype)
cur_kv_head = by // (kv_group_num // block_H)
T.use_swizzle(10)
T.annotate_layout({
O_shared: tilelang.layout.make_swizzled_layout(O_shared),
S_shared: tilelang.layout.make_swizzled_layout(S_shared),
})
T.copy(Q[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_shared)
T.copy(Q_pe[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_pe_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = T.ceildiv((seqlen_kv // num_split), block_N)
for k in T.Pipelined(loop_range, num_stages=2):
kv_start = (seqlen_kv // num_split) * bz + k * block_N
kv_end = (seqlen_kv // num_split) * bz + (k + 1) * block_N
T.copy(KV[bx, kv_start:kv_end, cur_kv_head, :], KV_shared)
T.copy(K_pe[bx, kv_start:kv_end, cur_kv_head, :], K_pe_shared)
T.clear(acc_s)
T.gemm(
Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
T.gemm(
Q_pe_shared,
K_pe_shared,
acc_s,
transpose_B=True,
policy=T.GemmWarpPolicy.FullCol)
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_H):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_H, block_N):
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
T.copy(acc_s, S_shared)
T.copy(S_shared, acc_s_cast)
for i in T.Parallel(block_H):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
for i, j in T.Parallel(block_H, dim):
acc_o[i, j] *= scores_scale[i]
T.gemm(acc_s_cast, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol)
for i, j in T.Parallel(block_H, dim):
acc_o[i, j] /= logsum[i]
for i in T.Parallel(block_H):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
T.copy(logsum, glse[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, bz])
T.copy(acc_o, O_shared)
T.copy(O_shared, Output_partial[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, bz, :])
@T.macro
def combine(
glse: T.Buffer([batch, heads, num_split], dtype),
Output_partial: T.Buffer([batch, heads, num_split, dim], dtype),
Output: T.Buffer([batch, heads, dim], dtype),
):
with T.Kernel(heads, batch, threads=128) as (by, bz):
po_local = T.alloc_fragment([dim], dtype)
o_accum_local = T.alloc_fragment([dim], accum_dtype)
lse_local_split = T.alloc_local([1], accum_dtype)
lse_logsum_local = T.alloc_local([1], accum_dtype)
lse_max_local = T.alloc_local([1], accum_dtype)
scale_local = T.alloc_local([1], accum_dtype)
T.annotate_layout({
lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i),
})
T.clear(lse_logsum_local)
T.clear(o_accum_local)
lse_max_local[0] = -T.infinity(accum_dtype)
for k in T.serial(num_split):
lse_max_local[0] = T.max(lse_max_local[0], glse[bz, by, k])
for k in T.Pipelined(num_split, num_stages=1):
lse_local_split[0] = glse[bz, by, k]
lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0])
lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0]
for k in T.serial(num_split):
for i in T.Parallel(dim):
po_local[i] = Output_partial[bz, by, k, i]
lse_local_split[0] = glse[bz, by, k]
scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0])
for i in T.Parallel(dim):
o_accum_local[i] += po_local[i] * scale_local[0]
for i in T.Parallel(dim):
Output[bz, by, i] = o_accum_local[i]
@T.prim_func
def main_split(
Q: T.Buffer([batch, heads, dim], dtype),
Q_pe: T.Buffer([batch, heads, pe_dim], dtype),
KV: T.Buffer([batch, seqlen_kv, kv_head_num, dim], dtype),
K_pe: T.Buffer([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
glse: T.Buffer([batch, heads, num_split], dtype),
Output_partial: T.Buffer([batch, heads, num_split, dim], dtype),
Output: T.Buffer([batch, heads, dim], dtype),
):
flash_attn_split(Q, Q_pe, KV, K_pe, glse, Output_partial)
combine(glse, Output_partial, Output)
@T.prim_func
def main_no_split(
Q: T.Buffer([batch, heads, dim], dtype),
Q_pe: T.Buffer([batch, heads, pe_dim], dtype),
KV: T.Buffer([batch, seqlen_kv, kv_head_num, dim], dtype),
K_pe: T.Buffer([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
glse: T.Buffer([batch, heads, num_split], dtype),
Output_partial: T.Buffer([batch, heads, num_split, dim], dtype),
Output: T.Buffer([batch, heads, dim], dtype),
):
flash_attn(Q, Q_pe, KV, K_pe, Output)
if num_split > 1:
return main_split
else:
return main_no_split
def ref_program(q, q_pe, kv, k_pe, glse, Output_partial):
# """
# Inputs:
# - q (Tensor): [batch, heads, dim]
# - q_pe (Tensor): [batch, heads, pe_dim]
# - kv (Tensor): [batch, seqlen_kv, kv_head_num, dim]
# - k_pe (Tensor): [batch, seqlen_kv, kv_head_num, pe_dim]
# - glse (Tensor): [batch, heads, num_split]
# - Output_partial (Tensor): [batch, heads, num_split, dim]
# Outputs:
# - output (Tensor): [batch, heads, dim]
# """
dim = q.shape[-1]
pe_dim = q_pe.shape[-1]
num_head_groups = q.shape[1] // kv.shape[2]
scale = (dim + pe_dim)**0.5
q = rearrange(
q, 'b (h g) d -> b g h d', g=num_head_groups) # [batch_size, num_head_groups, groups, dim]
q_pe = rearrange(
q_pe, 'b (h g) d -> b g h d',
g=num_head_groups) # [batch_size, num_head_groups, groups, pe_dim]
kv = rearrange(kv, 'b n h d -> b h n d') # [batch_size, groups, seqlen_kv, dim]
k_pe = rearrange(k_pe, 'b n h d -> b h n d') # [batch_size, num_head_groups, groups, pe_dim]
query = torch.concat([q, q_pe], dim=-1)
key = torch.concat([kv, k_pe], dim=-1)
scores = einsum(
query, key,
'b g h d, b h s d -> b g h s') # [batch_size, num_head_groups, groups, seqlen_kv]
attention = F.softmax(
scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv]
out = einsum(attention, kv,
'b g h s, b h s d -> b g h d') # [batch_size, num_head_groups, groups, dim]
out = rearrange(out, 'b g h d -> b (h g) d') # [batch_size, heads, dim]
return out
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=128, help='batch size')
parser.add_argument('--heads', type=int, default=128, help='q heads number')
parser.add_argument('--kv_heads', type=int, default=1, help='kv heads number')
parser.add_argument('--kv_ctx', type=int, default=8192, help='kv context length')
parser.add_argument('--dim', type=int, default=512, help='head dim')
parser.add_argument('--pe_dim', type=int, default=64, help='pe head dim')
args = parser.parse_args()
batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim
qk_flops = 2 * batch * heads * kv_ctx * (dim + pe_dim)
pv_flops = 2 * batch * heads * kv_ctx * dim
total_flops = qk_flops + pv_flops
BLOCK_N = 64
BLOCK_H = 64
num_split = 1
program = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split)
mod, params = tilelang.lower(program)
mod = tilelang.Profiler(mod, params, [6], tilelang.TensorSupplyType.Randn)
mod.assert_allclose(ref_program, rtol=0.01, atol=0.01)
print("All close")
latency = mod.do_bench(mod.func, n_warmup=10, n_repeat=10, profiler="torch")
print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))