Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 38 additions & 18 deletions fla/ops/deltaformer/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def parallel_deltaformer_fwd_kernel(
)
q = tl.load(q_blk_ptr, boundary_check=(0,))

for kv_i in range(0, T, BLOCK_T):
for kv_i in range(0, T-C, BLOCK_T):
k_blk_ptr = tl.make_block_ptr(
base=k_ptr + pid_h * D,
shape=(D, T),
Expand All @@ -179,10 +179,6 @@ def parallel_deltaformer_fwd_kernel(
k = tl.load(k_blk_ptr, boundary_check=(1,))
qk = tl.dot(q, k) * qk_scale

if kv_i >= T - C:
mask = (T - C - kv_i + rowid_block[:, None] - colid_block[None, :] < 1)
qk = tl.where(mask, -1e6, qk)

rowmax_i = tl.maximum(rowmax, tl.max(qk, axis=1))
qk -= rowmax_i[:, None]
p = tl.math.exp2(qk)
Expand All @@ -193,17 +189,41 @@ def parallel_deltaformer_fwd_kernel(
acc = acc * alpha[:, None]
rowmax = rowmax_i

if kv_i < T - C:
u_blk_ptr = tl.make_block_ptr(
base=u_ptr + pid_h * D,
shape=(T, D),
strides=(H * D, 1),
offsets=(kv_i, 0),
block_shape=(BLOCK_T, D),
order=(1, 0),
)
u = tl.load(u_blk_ptr, boundary_check=(0,))
acc = tl.dot(p.to(u_ptr.dtype.element_ty), u, acc)
u_blk_ptr = tl.make_block_ptr(
base=u_ptr + pid_h * D,
shape=(T, D),
strides=(H * D, 1),
offsets=(kv_i, 0),
block_shape=(BLOCK_T, D),
order=(1, 0),
)
u = tl.load(u_blk_ptr, boundary_check=(0,))
acc = tl.dot(p.to(u_ptr.dtype.element_ty), u, acc)

for kv_i in range(T-C, T, BLOCK_T):
k_blk_ptr = tl.make_block_ptr(
base=k_ptr + pid_h * D,
shape=(D, T),
strides=(1, H * D),
offsets=(0, kv_i),
block_shape=(D, BLOCK_T),
order=(0, 1),
)
k = tl.load(k_blk_ptr, boundary_check=(1,))
qk = tl.dot(q, k) * qk_scale

mask = (T - C - kv_i + rowid_block[:, None] - colid_block[None, :] < 1)
qk = tl.where(mask, -1e6, qk)

rowmax_i = tl.maximum(rowmax, tl.max(qk, axis=1))
qk -= rowmax_i[:, None]
p = tl.math.exp2(qk)

rowsum_i = tl.sum(p, axis=1)
alpha = tl.math.exp2(rowmax - rowmax_i)
rowsum = rowsum * alpha + rowsum_i
acc = acc * alpha[:, None]
rowmax = rowmax_i

lse = rowmax + tl.math.log2(rowsum)
lse_block_ptr = lse_ptr + pid_h + rowid_block * H
Expand All @@ -218,7 +238,7 @@ def parallel_deltaformer_fwd_kernel(
block_shape=(BLOCK_C, D),
order=(1, 0),
)
acc = acc / rowsum[:, None]
acc = acc / (rowsum[:, None] + 1e-9)

beta_ptr = tl.make_block_ptr(
base=beta_ptr + pid_h,
Expand Down Expand Up @@ -861,7 +881,7 @@ def _forward_impl(
betai = beta_full[b, i:i + Ci, :]

w, lse_chunk = parallel_deltaformer_chunk_fwd(qi, ki, vi, ui_prev, fa_scale, betai)
w = w * betai.unsqueeze(-1)
w = w * betai.unsqueeze(-1).to(torch.float32)
if need_aux:
wpad = torch.zeros(C, H, C, device=ko.device, dtype=ko.dtype)
wpad[:Ci, :, :Ci].copy_(w)
Expand Down
Loading