
Commit a9a715c

jiawenliu64 authored and facebook-github-bot committed
Enable FP8 Triton dequantized block-wise kernel (#3788)
Summary:
Pull Request resolved: #3788

X-link: facebookresearch/FBGEMM#875

Enable FP8 Triton dequantized block-wise kernel, which is required to upcast with block-wise quantized all2all.

Reviewed By: sunfish2010

Differential Revision: D70872110

fbshipit-source-id: fa842baa49c72b67e6c12c375f469dae3219827a
1 parent ba25044 commit a9a715c
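
For context, a minimal sketch of the round trip this kernel enables, using the quantize_fp8_block / dequantize_fp8_block APIs exercised by the new test below. The tensor shape, block sizes, and the omitted collective exchange are illustrative assumptions, not part of this change.

# Sketch only: shapes, block sizes, and the omitted collective call are illustrative.
import torch

from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import (
    dequantize_fp8_block,
    quantize_fp8_block,
)

# Block-wise quantize a BF16 tensor to FP8 plus one scale per (1 x 256) block.
x = torch.randn(128, 1024, dtype=torch.bfloat16, device="cuda")
x_fp8, x_scale = quantize_fp8_block(x, block_m=1, block_k=256)

# In the all2all use case, x_fp8 and x_scale would be exchanged across ranks here.

# Upcast the FP8 blocks back to BF16 with the new Triton kernel.
x_bf16 = dequantize_fp8_block(x_fp8, x_scale, block_m=1, block_k=256)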

File tree

2 files changed: +101 −0 lines changed

fbgemm_gpu/experimental/gemm/test/fp8_gemm_test.py (+29)
fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py (+72)

fbgemm_gpu/experimental/gemm/test/fp8_gemm_test.py (+29)
@@ -14,6 +14,7 @@
 
 if torch.cuda.is_available():
     from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import (
+        dequantize_fp8_block,
         matmul_fp8_block,
         matmul_fp8_row,
         quantize_fp8_block,
@@ -274,6 +275,34 @@ def _test_quantize_fp8_block(
         _test_quantize_fp8_block((3, 6), (2, 8))
         _test_quantize_fp8_block((3, 6), (2, 8), use_scale_ub=True)
 
+    def test_dequantize_fp8_block(self) -> None:
+        def _test_dequantize_fp8_block(
+            shape: Tuple[int, int],
+            block_shape: Tuple[int, int],
+            use_scale_ub: bool = False,
+        ) -> None:
+            M, K = shape
+            BLOCK_M, BLOCK_K = block_shape
+            a = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
+
+            scale_ub = (
+                torch.tensor([1200], dtype=torch.float, device="cuda")
+                if use_scale_ub
+                else None
+            )
+
+            a_fp8, a_scale = quantize_fp8_block(
+                a, block_m=BLOCK_M, block_k=BLOCK_K, scale_ub=scale_ub
+            )
+            a_dequant = dequantize_fp8_block(
+                a_fp8, a_scale, block_m=BLOCK_M, block_k=BLOCK_K
+            )
+            self.assertTrue(torch.allclose(a, a_dequant, atol=2e-1, rtol=5e-2))
+
+        _test_dequantize_fp8_block((3, 1024), (1, 256))
+        _test_dequantize_fp8_block((11, 128), (1, 128))
+        _test_dequantize_fp8_block((11, 256), (1, 256), use_scale_ub=True)
+
     def test_matmul_fp8_block(self) -> None:
         def _test_matmul_fp8_block(
             shape: Tuple[int, int, int],

fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py (+72)
@@ -3097,3 +3097,75 @@ def _kernel_matmul_fp8_row_non_persistent(
         tl.store(C, acc, mask=mask)
     else:
         tl.atomic_add(C, acc, mask=mask)
+
+
+@triton.jit
+def _kernel_dequantize_fp8_block(
+    xq_ptr,
+    x_scale_ptr,
+    x_dequant_ptr,
+    M,
+    K,
+    BLOCK_M: tl.constexpr,
+    BLOCK_K: tl.constexpr,
+):
+    """
+    Kernel to dequantize FP8 tensor to BF16 tensor.
+    Args:
+        xq_ptr (tl.constexpr): Pointer to FP8 tensor.
+        x_scale_ptr (tl.constexpr): Pointer to FP8 scale tensor.
+        x_dequant_ptr (tl.constexpr): Pointer to BF16 tensor.
+        M (tl.constexpr): M dimension of input tensor.
+        K (tl.constexpr): K dimension of input tensor.
+        BLOCK_M (tl.constexpr): Block size for the M dimension.
+        BLOCK_K (tl.constexpr): Block size for the K dimension.
+    """
+    pid_m = tl.program_id(axis=0)
+    pid_k = tl.program_id(axis=1)
+    k = tl.cdiv(K, BLOCK_K)
+    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+    offs_k = pid_k * BLOCK_K + tl.arange(0, BLOCK_K)
+    offs = offs_m[:, None] * K + offs_k[None, :]
+    mask = (offs_m[:, None] < M) & (offs_k[None, :] < K)
+    xq = tl.load(xq_ptr + offs, mask=mask).to(tl.bfloat16)
+    x_scale = tl.load(x_scale_ptr + pid_m * k + pid_k)
+    x_dequant = xq * x_scale
+    tl.store(x_dequant_ptr + offs, x_dequant, mask=mask)
+
+
+def dequantize_fp8_block(
+    xq: torch.Tensor,
+    x_scale: torch.Tensor,
+    block_m: int = 256,
+    block_k: int = 256,
+) -> torch.Tensor:
+    """
+    Dequantize FP8 tensor to BF16 tensor.
+
+    Args:
+        xq (torch.Tensor): FP8 tensor to be dequantized.
+        x_scale (torch.Tensor): FP8 scale tensor.
+        block_m (int): Block size for the M dimension.
+        block_k (int): Block size for the K dimension.
+
+    Returns:
+        torch.Tensor: Dequantized BF16 tensor.
+    """
+
+    assert (
+        xq.is_contiguous() and x_scale.is_contiguous()
+    ), "Input tensors must be contiguous"
+    assert xq.dim() == 2 and x_scale.dim() == 2, "Input tensors must have 2 dimensions"
+    M, K = xq.size()
+    x_dequant = torch.empty_like(xq, dtype=torch.bfloat16)
+
+    def grid(meta):
+        return (
+            triton.cdiv(M, meta["BLOCK_M"]),
+            triton.cdiv(K, meta["BLOCK_K"]),
+        )
+
+    _kernel_dequantize_fp8_block[grid](
+        xq, x_scale, x_dequant, M, K, BLOCK_M=block_m, BLOCK_K=block_k  # pyre-ignore[6]
+    )
+    return x_dequant
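
A note on the scale layout the launcher above assumes: the kernel reads x_scale_ptr + pid_m * cdiv(K, BLOCK_K) + pid_k, i.e. a row-major 2D tensor with one scale per (block_m x block_k) tile, which is the layout exercised by the new test. A small shape-check sketch; the helper name is illustrative, not part of this change.

import math


def expected_scale_shape(M: int, K: int, block_m: int, block_k: int) -> tuple[int, int]:
    # One scale per (block_m x block_k) tile, laid out row-major over tiles,
    # matching the kernel's x_scale_ptr + pid_m * cdiv(K, BLOCK_K) + pid_k indexing.
    return (math.ceil(M / block_m), math.ceil(K / block_k))


# For the (11, 256) case in the new test with (1, 256) blocks:
assert expected_scale_shape(11, 256, 1, 256) == (11, 1)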
