Fuse cumulative sum into FP8xINT4 Grouped Gemm #3812

Open · wants to merge 4 commits into base: main
129 changes: 96 additions & 33 deletions fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py
@@ -24,6 +24,7 @@
from fbgemm_gpu.experimental.gemm.triton_gemm.grouped_gemm import (
grouped_gemm_fp8_rowwise,
)
from fbgemm_gpu.experimental.gen_ai.quantize import quantize_int4_preshuffle
from tinygemm.utils import group_quantize_tensor

if torch.cuda.is_available() and torch.version.cuda:
@@ -1326,51 +1327,113 @@ def cuda(self) -> bool:


@register_quantize_op
class F8I4ShuffledGemm(QuantizeOpBase):
    def preprocess(self, x, w):
        # Prequantize and pack weights.
        wq, row_scale, group_scale = quantize_int4_preshuffle(w)
        return x, wq, row_scale, group_scale

    def quantize(self, x, wq, row_scale, group_scale):
        # Quantize the activation tensor; weights were packed in preprocess.
        xq, x_scale = quantize_fp8_row(x)
        return xq, wq, x_scale, row_scale, group_scale

    def compute(self, xq, wq, x_scale, row_scale, group_scale):
        # Handle batched cases by looping over each batch.
        if xq.dim() == 3:
            B, M, _ = xq.shape
            _, N, _ = wq.shape
            y = torch.empty((B, M, N), device=xq.device, dtype=torch.bfloat16)
            for i in range(B):
                y[i] = torch.ops.fbgemm.f8i4bf16_shuffled(
                    xq[i], wq[i], x_scale[i], row_scale[i], group_scale[i]
                )
            return y
        # Otherwise run the gemm normally.
        return torch.ops.fbgemm.f8i4bf16_shuffled(
            xq, wq, x_scale, row_scale, group_scale
        )

    def quantize_and_compute(self, x, wq, row_scale, group_scale):
        xq, wq, x_scale, row_scale, group_scale = self.quantize(
            x, wq, row_scale, group_scale
        )
        return self.compute(xq, wq, x_scale, row_scale, group_scale)

    @property
    def name(self) -> str:
        return "cutlass_f8i4_preshuffle"

    @property
    def hip(self) -> bool:
        # Not yet supported on AMD.
        return False

    @property
    def cuda(self) -> bool:
        return True
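
A minimal usage sketch (not part of this diff) for the preshuffled FP8xINT4 op above, assuming a CUDA build with the fbgemm ops registered and hypothetical shapes:

op = F8I4ShuffledGemm()
x = torch.randn(128, 512, device="cuda", dtype=torch.bfloat16)  # [M, K] activations
w = torch.randn(256, 512, device="cuda", dtype=torch.bfloat16)  # [N, K] weights
x, wq, row_scale, group_scale = op.preprocess(x, w)  # one-time weight quantization
y = op.quantize_and_compute(x, wq, row_scale, group_scale)  # [128, 256] bfloat16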

@register_quantize_op
class F8I4ShuffledGroupedGemm(QuantizeOpBase):
    """
    FP8 x Int4 mixed dtype grouped gemm with preshuffling.
    """

    def preprocess(self, x, w):
        assert isinstance(x, list) and isinstance(
            w, list
        ), "Only supported for grouped inputs."
        m_values = [i.shape[0] for i in x]
        # Convert m_values into offsets into grouped tensor.
        m_sizes = torch.tensor(m_values).to(dtype=torch.int64, device=x[0].device)
        # Quantize weights.
        # TODO Only rowwise scaling is currently supported. This needs to be fixed.
        K = x[0].shape[-1]
        wq, row_scale, group_scale = zip(
            *[quantize_int4_preshuffle(i, group_size=K) for i in w]
        )
        # Group weights as single tensor.
        wq = torch.stack(wq, dim=0).contiguous()
        row_scale = torch.stack(row_scale, dim=0).contiguous()
        group_scale = torch.stack(group_scale, dim=0).contiguous()
        # Also view input as flattened.
        x = torch.concat(x, dim=0).contiguous()
        # Return processed tensors.
        return x, wq, row_scale, group_scale, m_sizes

    def quantize(self, x, wq, row_scale, group_scale, m_sizes):
        B = x.shape[0]
        xq, x_scale = triton_quantize_fp8_row(x)
        x_scale = x_scale.view(B, -1)
        return xq, wq, x_scale, row_scale, group_scale, m_sizes

    def compute(self, xq, wq, x_scale, row_scale, group_scale, m_sizes):
        out = torch.ops.fbgemm.f8i4bf16_shuffled_grouped(
            xq, wq, x_scale, row_scale, group_scale, m_sizes
        )
        return out

    def quantize_and_compute(self, x, wq, row_scale, group_scale, m_sizes):
        xq, wq, x_scale, row_scale, group_scale, m_sizes = self.quantize(
            x, wq, row_scale, group_scale, m_sizes
        )
        return self.compute(xq, wq, x_scale, row_scale, group_scale, m_sizes)

    @property
    def name(self) -> str:
        if torch.version.cuda:
            return "cutlass_f8i4_grouped_preshuffle"
        else:
            return "ck_f8i4_grouped_preshuffle"

    @property
    def hip(self) -> bool:
        return False

    @property
    def cuda(self) -> bool:
        return True

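A similar sketch (not part of this diff) for the grouped path, where each group can have a different M but must share K; per the TODO in preprocess, group_size is set to K so only rowwise weight scaling is exercised. Shapes are hypothetical:

op = F8I4ShuffledGroupedGemm()
xs = [torch.randn(m, 512, device="cuda", dtype=torch.bfloat16) for m in (64, 128, 32)]  # per-group activations
ws = [torch.randn(256, 512, device="cuda", dtype=torch.bfloat16) for _ in range(3)]  # per-group weights
x, wq, row_scale, group_scale, m_sizes = op.preprocess(xs, ws)  # flatten activations, stack weights
y = op.quantize_and_compute(x, wq, row_scale, group_scale, m_sizes)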

119 changes: 119 additions & 0 deletions fbgemm_gpu/experimental/gen_ai/gen_ai/quantize.py
@@ -0,0 +1,119 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

# Helper functions for using FBGEMM quantized operators.

from typing import Tuple

import torch

from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import quantize_fp8_row


def pack_int4(x: torch.Tensor) -> torch.Tensor:
    # Given int8 x, pack adjacent int4 values into a single int8.
    low_x = x[:, ::2]
    high_x = x[:, 1::2]

    # High bits need to left shift, this also masks off extra bits.
    high_x = torch.bitwise_left_shift(high_x, 4)
    # Low bits need to have sign bits removed.
    low_x = torch.bitwise_and(low_x, 0xF)

    # Recombine into a single value with bitwise or.
    return torch.bitwise_or(low_x, high_x).contiguous()
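
A quick worked example (hypothetical values, not part of this diff) showing how two signed int4 values share one int8: the odd-indexed element lands in the high nibble and the even-indexed element in the low nibble.

x = torch.tensor([[1, 2], [-3, 7]], dtype=torch.int8)
packed = pack_int4(x)
# Row 0: (2 << 4) | (1 & 0xF) = 0x21 = 33
# Row 1: (7 << 4) | (-3 & 0xF) = 0x70 | 0x0D = 0x7D = 125
assert packed.tolist() == [[33], [125]]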


def int4_row_quantize(
    x: torch.Tensor,
    group_size: int = 128,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Helper function to quantize a tensor to int4 with groupwise scales.

    Args:
        x (Tensor): [N, K] Higher precision weight tensor to quantize.
        group_size (int): Number of elements to calculate group scale for.
    Returns:
        wq (Tensor): [N, K] Quantized int4 values stored one per int8 element
            (pack_int4 packs them to [N, K // 2]).
        group_scale (Tensor): [K / group_size, N] FP32 scale per group.
    """
    n_bit = 4  # Number of target bits.
    to_quant = x.reshape(-1, group_size).to(torch.float)

    max_val = torch.abs(to_quant).amax(dim=1, keepdim=True)
    max_int = 2 ** (n_bit - 1)
    min_int = -(2 ** (n_bit - 1))
    scales = max_val.clamp(min=1e-6) / max_int

    out = to_quant.div(scales).round().clamp_(min_int, max_int - 1)

    # Cast to int8 and restore shape.
    out = out.to(dtype=torch.int8).reshape(x.shape)

    # Scales should be in [num_groups, N] layout.
    scales = scales.view(x.shape[0], -1).t().contiguous()

    return out, scales
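
A small numeric illustration (hypothetical values, not part of this diff) of the symmetric rounding above, using a tiny group_size for readability:

x = torch.tensor([[0.5, -1.0, 2.0, 4.0]])
wq, scales = int4_row_quantize(x, group_size=4)
# max |x| = 4.0, so scale = 4.0 / 8 = 0.5; values round to [1, -2, 4, 8] and clamp to [-8, 7].
assert wq.tolist() == [[1, -2, 4, 7]]
assert scales.tolist() == [[0.5]]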


def quantize_int4_preshuffle(
    w: torch.Tensor, group_size: int = 128
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Quantizes an input weight tensor to int4 using preshuffling and scale packing.
    This function is intended to be used with fbgemm's mixed dtype kernels and is expected
    to be applied to weights ahead of time. As such, it is not perfectly optimized.

    Args:
        w (Tensor): [N, K] Higher precision weight tensor to quantize. May optionally have a batch dimension.
        group_size (int): Number of elements to calculate group scale for, must be at least 128.
    Returns:
        wq (Tensor): [N, K // 2] Quantized int4 weight tensor packed into int8 elements.
        row_scale (Tensor): [N] FP32 scale per row of the weight tensor.
        group_scale (Tensor): [K / group_size, 8, N] FP8 scale per group of the weight tensor.
    """

    def _quantize(w: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # Start by lowering weights to FP8 and producing row scales.
        wq, row_scale = quantize_fp8_row(w)

        # Now reduce to INT4.
        wq, group_scale = int4_row_quantize(wq, group_size)
        # Reduce group scale to FP8.
        group_scale = group_scale.to(torch.float8_e4m3fn)

        # Take quantized weights and pack them efficiently.
        wq = pack_int4(wq)

        # Finally pack weights and scales into efficient preshuffled format.
        wq, group_scale = torch.ops.fbgemm.preshuffle_i4(wq, group_scale)

        return wq, row_scale, group_scale

    if w.ndim >= 3:
        orig_shape = w.shape
        # Flatten to 3 dimensions then iterate over batches.
        w = w.view(-1, *w.shape[-2:])
        wq = []
        row_scale = []
        group_scale = []
        for batch in w.unbind(dim=0):
            wq_, row_scale_, group_scale_ = _quantize(batch)
            wq.append(wq_)
            row_scale.append(row_scale_)
            group_scale.append(group_scale_)
        wq = torch.stack(wq).view(*orig_shape[:-2], *wq[0].shape)
        row_scale = torch.stack(row_scale).view(*orig_shape[:-2], *row_scale[0].shape)
        group_scale = torch.stack(group_scale).view(
            *orig_shape[:-2], *group_scale[0].shape
        )
    else:
        wq, row_scale, group_scale = _quantize(w)
    return wq, row_scale, group_scale
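
A minimal end-to-end sketch (not part of this diff) of how these helpers compose with the non-grouped kernel exercised by the benchmark above, assuming a CUDA build with the fbgemm ops registered and hypothetical shapes:

w = torch.randn(256, 512, device="cuda", dtype=torch.bfloat16)  # [N, K] weights
wq, row_scale, group_scale = quantize_int4_preshuffle(w, group_size=128)
# wq packs two int4 values per int8 element in fbgemm's preshuffled layout.

x = torch.randn(64, 512, device="cuda", dtype=torch.bfloat16)  # [M, K] activations
xq, x_scale = quantize_fp8_row(x)
y = torch.ops.fbgemm.f8i4bf16_shuffled(xq, wq, x_scale, row_scale, group_scale)  # [64, 256] bfloat16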
@@ -47,6 +47,10 @@ at::Tensor f8i4bf16_rowwise_impl(

  int group_size = K / num_groups;

  // Return immediately if input is empty.
  if (M == 0 || N == 0 || K == 0) {
    return at::zeros({M, N}, XQ.options().dtype(at::kBFloat16));
  }
  auto Y = at::empty({M, N}, XQ.options().dtype(at::kBFloat16));

  using ElementInputA = INPUT_DTYPE;