Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion evals/harness.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@

from __future__ import annotations

import fla # noqa
from lm_eval.__main__ import cli_evaluate
from lm_eval.api.registry import register_model
from lm_eval.models.huggingface import HFLM

import fla # noqa


@register_model('fla')
class FlashLinearAttentionLMWrapper(HFLM):
Expand Down
362 changes: 325 additions & 37 deletions fla/ops/gated_delta_product/chunk.py

Large diffs are not rendered by default.

43 changes: 30 additions & 13 deletions fla/ops/gated_delta_product/chunk_deltaproduct_h.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@ def chunk_gated_delta_product_bwd_kernel_dhu_blockdim64(
chunk_offsets,
scale,
T,
num_householder: tl.constexpr,
H: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
Expand All @@ -237,6 +238,7 @@ def chunk_gated_delta_product_bwd_kernel_dhu_blockdim64(
):
i_v, i_nh = tl.program_id(0), tl.program_id(1)
i_n, i_h = i_nh // H, i_nh % H

if IS_VARLEN:
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
T = eos - bos
Expand All @@ -245,7 +247,9 @@ def chunk_gated_delta_product_bwd_kernel_dhu_blockdim64(
else:
bos, eos = i_n * T, i_n * T + T
NT = tl.cdiv(T, BT)
boh = i_n * NT
# boh = i_n * tl.cdiv(T, BT)
# Jinha: update boh to match the chunk_gated_delta_product_fwd_kernel_h_blockdim64 implementation
boh = i_n * tl.cdiv(T // num_householder, BT)

# [BK, BV]
b_dh1 = tl.zeros([64, BV], dtype=tl.float32)
Expand Down Expand Up @@ -312,13 +316,13 @@ def chunk_gated_delta_product_bwd_kernel_dhu_blockdim64(
b_g_exp = None

p_dv = tl.make_block_ptr(dv, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_wo = tl.make_block_ptr(do, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_do = tl.make_block_ptr(do, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_dv2 = tl.make_block_ptr(dv2, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))

b_wo = tl.load(p_wo, boundary_check=(0, 1))
b_do = tl.load(p_do, boundary_check=(0, 1))
b_dv = tl.zeros([BT, BV], dtype=tl.float32)

# Update dv
# Update dv based on hidden state gradients
p_k = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 0), (BT, 64), (1, 0))
b_k = tl.load(p_k, boundary_check=(0, 1))
b_dv += tl.dot(b_k, b_dh1.to(b_k.dtype))
Expand All @@ -344,7 +348,8 @@ def chunk_gated_delta_product_bwd_kernel_dhu_blockdim64(
b_dv += tl.load(p_dv, boundary_check=(0, 1))

tl.store(p_dv2, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
# Update dh

# Update hidden state gradients
p_w = tl.make_block_ptr(w, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1))
p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1))
b_w = tl.load(p_w, boundary_check=(0, 1))
Expand All @@ -353,7 +358,8 @@ def chunk_gated_delta_product_bwd_kernel_dhu_blockdim64(
b_dh1 *= bg_last_exp
b_q = b_q * b_g_exp[None, :]
b_q = (b_q * scale).to(b_q.dtype)
b_dh1 += tl.dot(b_q, b_wo.to(b_q.dtype))-tl.dot(b_w, b_dv.to(b_w.dtype))
b_dh1 += tl.dot(b_q, b_do.to(b_q.dtype)) - tl.dot(b_w, b_dv.to(b_w.dtype))

if K > 64:
p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1))
p_w = tl.make_block_ptr(w, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1))
Expand All @@ -363,7 +369,8 @@ def chunk_gated_delta_product_bwd_kernel_dhu_blockdim64(
b_dh2 *= bg_last_exp
b_q = b_q * b_g_exp[None, :]
b_q = (b_q * scale).to(b_q.dtype)
b_dh2 += tl.dot(b_q, b_wo.to(b_q.dtype))-tl.dot(b_w, b_dv.to(b_w.dtype))
b_dh2 += tl.dot(b_q, b_do.to(b_q.dtype)) - tl.dot(b_w, b_dv.to(b_w.dtype))

if K > 128:
p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1))
p_w = tl.make_block_ptr(w, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1))
Expand All @@ -373,7 +380,8 @@ def chunk_gated_delta_product_bwd_kernel_dhu_blockdim64(
b_dh3 *= bg_last_exp
b_q = b_q * b_g_exp[None, :]
b_q = (b_q * scale).to(b_q.dtype)
b_dh3 += tl.dot(b_q, b_wo.to(b_q.dtype))-tl.dot(b_w, b_dv.to(b_w.dtype))
b_dh3 += tl.dot(b_q, b_do.to(b_q.dtype)) - tl.dot(b_w, b_dv.to(b_w.dtype))

if K > 192:
p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1))
p_w = tl.make_block_ptr(w, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1))
Expand All @@ -383,7 +391,7 @@ def chunk_gated_delta_product_bwd_kernel_dhu_blockdim64(
b_dh4 *= bg_last_exp
b_q = b_q * b_g_exp[None, :]
b_q = (b_q * scale).to(b_q.dtype)
b_dh4 += tl.dot(b_q, b_wo.to(b_q.dtype))-tl.dot(b_w, b_dv.to(b_w.dtype))
b_dh4 += tl.dot(b_q, b_do.to(b_q.dtype)) - tl.dot(b_w, b_dv.to(b_w.dtype))

if USE_INITIAL_STATE:
p_dh0 = tl.make_block_ptr(dh0, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
Expand Down Expand Up @@ -460,19 +468,25 @@ def chunk_gated_delta_product_bwd_dhu(
dv: torch.Tensor,
scale: float,
cu_seqlens: Optional[torch.LongTensor] = None,
chunk_size: int = 64, # SY: remove this argument and force chunk size 64?
chunk_size: int = 64,
num_householder: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
B, T, H, K, V = *q.shape, do.shape[-1]
assert T % num_householder == 0, "T must be divisible by num_householder"
T_true = T // num_householder

# N: the actual number of sequences in the batch with either equal or variable lengths
BT = 64
assert K <= 256, "current kernel does not support head dimension being larger than 256."

chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None
chunk_indices = prepare_chunk_indices(cu_seqlens // num_householder, chunk_size) if cu_seqlens is not None else None
if cu_seqlens is None:
N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
N, NT, chunk_offsets = B, triton.cdiv(T_true, BT), None
else:
N, NT, chunk_offsets = len(cu_seqlens) - 1, len(chunk_indices), prepare_chunk_offsets(cu_seqlens, BT)
N, NT, chunk_offsets = (
len(cu_seqlens) - 1, len(chunk_indices),
prepare_chunk_offsets(cu_seqlens // num_householder, BT)
)

dh = q.new_empty(B, NT, H, K, V)
dh0 = torch.empty_like(h0, dtype=torch.float32) if h0 is not None else None
Expand All @@ -494,9 +508,12 @@ def grid(meta): return (triton.cdiv(V, meta['BV']), N*H)
chunk_offsets=chunk_offsets,
scale=scale,
T=T,
num_householder=num_householder,
H=H,
K=K,
V=V,
BT=BT,
)
# could call chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64 instead
# after adjusting number of tokens
return dh, dh0, dv2
209 changes: 209 additions & 0 deletions fla/ops/gated_delta_product/chunk_deltaproduct_o.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,3 +152,212 @@ def grid(meta): return (triton.cdiv(V, meta['BV']), NT, B * H)
BT=BT,
)
return o

# @triton.heuristics({
# 'USE_G': lambda args: args['g'] is not None,
# 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
# })
# @triton.autotune(
# configs=[
# triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
# for BK in BKV_LIST
# for BV in BKV_LIST
# for num_warps in NUM_WARPS
# for num_stages in [2, 3, 4]
# ],
# key=['H', 'K', 'V', 'BT'],
# )
# @triton.jit(do_not_specialize=['T'])
# def chunk_gated_delta_product_bwd_kernel_o(
# q,
# k,
# v,
# h,
# g,
# do,
# dq,
# dk,
# dv,
# dh,
# cu_seqlens,
# chunk_indices,
# scale,
# T,
# num_householder: tl.constexpr,
# H: tl.constexpr,
# K: tl.constexpr,
# V: tl.constexpr,
# BT: tl.constexpr,
# BK: tl.constexpr,
# BV: tl.constexpr,
# USE_G: tl.constexpr,
# IS_VARLEN: tl.constexpr,
# ):
# # same parameters as forward pass
# i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
# i_b, i_h = i_bh // H, i_bh % H

# if IS_VARLEN:
# i_tg = i_t
# i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
# bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
# T = eos - bos
# NT = tl.cdiv(T, BT)
# else:
# NT = tl.cdiv(T, BT)
# i_tg = i_b * NT + i_t
# bos, eos = i_b * T, i_b * T + T

# # offset calculation
# q += (bos * H + i_h) * K
# k += (bos * H + i_h) * K
# v += (bos * H + i_h) * V
# do += (bos * H + i_h) * V
# dq += (bos * H + i_h) * K
# dk += (bos * H + i_h) * K
# dv += (bos * H + i_h) * V
# h += (i_tg * H + i_h).to(tl.int64) * K*V
# dh += (i_tg * H + i_h).to(tl.int64) * K*V

# b_dq = tl.zeros([BT, BK], dtype=tl.float32)
# b_dk = tl.zeros([BT, BK], dtype=tl.float32)
# b_dv = tl.zeros([BT, BV], dtype=tl.float32)
# b_ds = tl.zeros([BT, BT], dtype=tl.float32)

# # Compute gradients from hidden state
# for i_k in range(tl.cdiv(K, BK)):
# p_q = tl.make_block_ptr(q, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
# p_h = tl.make_block_ptr(h, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
# p_dh = tl.make_block_ptr(dh, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
# p_do = tl.make_block_ptr(do, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))

# # [BT, BK]
# b_q = tl.load(p_q, boundary_check=(0, 1))
# # [BK, BV]
# b_h = tl.load(p_h, boundary_check=(0, 1))
# b_dh = tl.load(p_dh, boundary_check=(0, 1))
# # [BT, BV]
# b_do = tl.load(p_do, boundary_check=(0, 1))

# # Compute gradients w.r.t. q: dq += do @ h^T
# b_dq += tl.dot(b_do, tl.trans(b_h))

# # Compute gradients w.r.t. h: dh += q^T @ do
# tl.store(p_dh, (b_dh + tl.dot(tl.trans(b_q), b_do)).to(p_dh.dtype.element_ty), boundary_check=(0, 1))

# # Process multiple Householder transformations
# for i_dp in range(num_householder):
# b_A = tl.zeros([BT, BT], dtype=tl.float32)

# # Compute attention matrix A = Q @ K^T for this Householder step
# for i_k in range(tl.cdiv(K, BK)):
# p_q = tl.make_block_ptr(q, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
# p_k = tl.make_block_ptr(k+i_dp*H*K, (K, T), (1, num_householder*H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))

# b_q = tl.load(p_q, boundary_check=(0, 1))
# b_k = tl.load(p_k, boundary_check=(0, 1))
# b_A += tl.dot(b_q, b_k)

# # Apply causal mask and gating
# o_t = i_t * BT + tl.arange(0, BT)
# m_t = o_t < T
# if USE_G:
# g += bos * H + i_h
# p_g = tl.make_block_ptr(g, (T,), (H,), (i_t * BT,), (BT,), (0,))
# b_g = tl.load(p_g, boundary_check=(0,))
# m_A = (o_t[:, None] >= o_t[None, :]) & (m_t[:, None] & m_t)
# b_A = tl.where(m_A, b_A * exp(b_g[:, None] - b_g[None, :]), 0)
# else:
# m_A = (o_t[:, None] >= o_t[None, :]) & (m_t[:, None] & m_t)
# b_A = tl.where(m_A, b_A, 0)

# # Load values for this Householder step
# p_v = tl.make_block_ptr(v+i_dp*H*V, (T, V), (H*V*num_householder, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
# p_do = tl.make_block_ptr(do, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
# b_v = tl.load(p_v, boundary_check=(0, 1))
# b_do = tl.load(p_do, boundary_check=(0, 1))

# # Gradient w.r.t. values: dv += A^T @ do
# b_dv += tl.dot(tl.trans(b_A.to(b_v.dtype)), b_do)

# # Gradient w.r.t. attention scores: ds = do @ v^T
# b_ds += tl.dot(b_do, tl.trans(b_v))

# # Apply scale and gating to score gradients
# b_ds = b_ds * scale
# if USE_G:
# b_ds = tl.where(m_A, b_ds * exp(b_g[:, None] - b_g[None, :]), 0)
# else:
# b_ds = tl.where(m_A, b_ds, 0)

# # Compute final gradients for each Householder step
# for i_dp in range(num_householder):
# for i_k in range(tl.cdiv(K, BK)):
# p_q = tl.make_block_ptr(q, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
# p_k = tl.make_block_ptr(k+i_dp*H*K, (K, T), (1, num_householder*H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
# p_dq = tl.make_block_ptr(dq, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
# p_dk = tl.make_block_ptr(dk+i_dp*H*K, (K, T), (1, num_householder*H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))

# b_q = tl.load(p_q, boundary_check=(0, 1))
# b_k = tl.load(p_k, boundary_check=(0, 1))

# # dq += ds @ k^T
# b_dq += tl.dot(b_ds, tl.trans(b_k))
# # dk += q^T @ ds
# b_dk = tl.dot(tl.trans(b_q), b_ds)

# tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
# tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))

# # Store value gradients
# for i_dp in range(num_householder):
# p_dv = tl.make_block_ptr(dv+i_dp*H*V, (T, V), (H*V*num_householder, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
# tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))


# def chunk_gated_delta_product_bwd_o(
# q: torch.Tensor,
# k: torch.Tensor,
# v: torch.Tensor,
# h: torch.Tensor,
# g: Optional[torch.Tensor] = None,
# do: torch.Tensor = None,
# scale: Optional[float] = None,
# cu_seqlens: Optional[torch.LongTensor] = None,
# chunk_size: int = 64,
# num_householder: int = 1,
# ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
# assert q.shape[1] * num_householder == k.shape[1], "q.shape[1] * num_householder must be equal to k.shape[1]"
# B, T, H, K, V = *q.shape, v.shape[-1]
# BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
# chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
# NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)

# dq = torch.zeros_like(q)
# dk = torch.zeros_like(k)
# dv = torch.zeros_like(v)
# dh = torch.zeros_like(h)

# def grid(meta): return (triton.cdiv(V, meta['BV']), NT, B * H)
# chunk_gated_delta_product_bwd_kernel_o[grid](
# q,
# k,
# v,
# h,
# g,
# do,
# dq,
# dk,
# dv,
# dh,
# cu_seqlens,
# chunk_indices,
# scale,
# T=T,
# num_householder=num_householder,
# H=H,
# K=K,
# V=V,
# BT=BT,
# )
# return dq, dk, dv, dh
8 changes: 4 additions & 4 deletions fla/ops/simple_gla/README.md
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
# Simple GLA

Gating mechanism in [Gated RFA](https://arxiv.org/abs/2103.02143), [Mamba2](https://arxiv.org/abs/2405.21060) and [YOCO](https://arxiv.org/abs/2405.05254) (a.k.a., Gated RetNet).
Gating mechanism in [Gated RFA](https://arxiv.org/abs/2103.02143), [Mamba2](https://arxiv.org/abs/2405.21060) and [YOCO](https://arxiv.org/abs/2405.05254) (a.k.a., Gated RetNet).

Compared to GLA, the gating is head-wise instead of elementwise.
As a result, we can adapt the RetNet kernel for training using matmul w/o numerical instability.
It is faster than GLA but has less expressive power.
Compared to GLA, the gating is head-wise instead of elementwise.
As a result, we can adapt the RetNet kernel for training using matmul w/o numerical instability.
It is faster than GLA but has less expressive power.
I will use it as a baseline for the GLA.

$S_{t+1} = g_{t+1} \odot S_{t} + K_{t+1} V_{t+1}^{\top}$ where $g$ is a scalar.
Loading
Loading