[NVFP4][ModelOpt] FP4 Activation Quantization Support #86

Open · wants to merge 1 commit into base: main
80 changes: 29 additions & 51 deletions vllm/model_executor/layers/quantization/modelopt.py
@@ -6,8 +6,7 @@
from torch.nn import Module
from torch.nn.parameter import Parameter

from vllm._custom_ops import (cutlass_scaled_fp4_mm,
cutlass_scaled_mm_supports_fp4, scaled_fp4_quant)
from vllm._custom_ops import cutlass_scaled_mm_supports_fp4
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
@@ -18,8 +17,9 @@
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
apply_fp4_marlin_linear, is_fp4_marlin_supported,
prepare_fp4_layer_for_marlin, prepare_moe_fp4_layer_for_marlin)
is_fp4_marlin_supported, prepare_moe_fp4_layer_for_marlin)
from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ( # noqa: E501
dequantize_to_dtype, ref_nvfp4_quant)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
is_layer_skipped)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
@@ -387,66 +387,44 @@ def process_weights_after_loading(self, layer: Module) -> None:
weight_scale_2 = layer.weight_scale_2.max().to(torch.float32)
layer.weight_scale_2 = Parameter(weight_scale_2, requires_grad=False)

layer.alpha = Parameter(layer.input_scale * layer.weight_scale_2,
requires_grad=False)

# Swizzle the weight blockscale.
# contracting dimension is input dimension
# block_size = 16;
assert (layer.weight_scale.shape[1] % 16 == 0), (
"Expected weight_scale.dim(1) to be divisible by 16")
assert (layer.weight_scale.dtype == torch.float8_e4m3fn), (
"Weight Block scale must be represented as FP8-E4M3")
swizzled_weight_scale = self.swizzle_blockscale(layer.weight_scale)

layer.weight_scale_swizzled = Parameter(swizzled_weight_scale,
requires_grad=False)
layer.weight = Parameter(layer.weight.data, requires_grad=False)

if self.use_marlin:
prepare_fp4_layer_for_marlin(layer)
del layer.alpha
del layer.input_scale
del layer.weight_scale_swizzled

def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if self.use_marlin:
return apply_fp4_marlin_linear(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
weight_scale_2=layer.weight_scale_2,
workspace=layer.workspace,
size_n=layer.output_size_per_partition,
size_k=layer.input_size_per_partition,
bias=bias)

# for input only the contracting dimension has a constraint.
x_m, x_k = x.shape
block_size = 16
output_dtype = x.dtype
output_shape = [x.shape[0], layer.weight.shape[0]]

# quantize BF16 or FP16 to (FP4 and interleaved block scale)
s_quant = 1 / layer.input_scale
x_fp4, x_blockscale = scaled_fp4_quant(x, s_quant)

# validate dtypes of quantized input, input block scale,
# weight and weight_blockscale
assert (x_fp4.dtype == torch.uint8)
assert (layer.weight.dtype == torch.uint8)
assert (x_blockscale.dtype == torch.float8_e4m3fn)
assert (layer.weight_scale_swizzled.dtype == torch.float8_e4m3fn)
assert (layer.alpha.dtype == torch.float32)

out = cutlass_scaled_fp4_mm(x_fp4, layer.weight, x_blockscale,
layer.weight_scale_swizzled, layer.alpha,
output_dtype)
if bias is not None:
out = out + bias
return out.view(*output_shape)

# quantize input to (FP4 and interleaved block scale)
x_global_scale = 1 / layer.input_scale
x_fp4, x_blockscale = ref_nvfp4_quant(x, x_global_scale, block_size)

# dequantize input
x_fp4 = x_fp4.reshape(x_m, x_k // block_size, block_size)
x_blockscale = x_blockscale.unsqueeze(-1) / x_global_scale
x_dq = (x_fp4 * x_blockscale).reshape(x_m, x_k).to(output_dtype)
del x_fp4, x_blockscale

# dequantize weight
w_fp4 = layer.weight.data.view(torch.uint8)
w_blockscale = layer.weight_scale_swizzled.data
w_global_scale = 1 / layer.weight_scale_2
w_dq = dequantize_to_dtype(w_fp4, w_blockscale, w_global_scale,
output_dtype, x.device, block_size)

# matmul
out = torch.matmul(x_dq, w_dq.t())
del w_dq, x_dq
return out


class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
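With this change, `apply()` no longer calls `cutlass_scaled_fp4_mm` or the Marlin kernel: the activations are quantized with `ref_nvfp4_quant`, both operands are immediately dequantized back to the activation dtype, and the GEMM is a plain `torch.matmul`. The sketch below gathers that path into one standalone function; `emulated_nvfp4_linear` is a hypothetical name, and it assumes a 2-D input plus a layer object exposing the same attributes the diff uses (`weight`, `weight_scale_swizzled`, `weight_scale_2`, `input_scale`).

```python
import torch

from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import (  # noqa: E501
    dequantize_to_dtype, ref_nvfp4_quant)


def emulated_nvfp4_linear(layer, x: torch.Tensor,
                          block_size: int = 16) -> torch.Tensor:
    # Hypothetical standalone mirror of the emulation path in apply();
    # `layer` must expose weight, weight_scale_swizzled, weight_scale_2
    # and input_scale, and `x` must be 2-D.
    out_dtype = x.dtype
    m, k = x.shape

    # Quantize the activations to FP4 values plus per-block FP8 scales,
    # then undo both scales to get a dequantized float copy.
    x_global_scale = 1 / layer.input_scale
    x_fp4, x_blockscale = ref_nvfp4_quant(x, x_global_scale, block_size)
    x_dq = (x_fp4.reshape(m, k // block_size, block_size) *
            (x_blockscale.unsqueeze(-1) / x_global_scale)).reshape(m, k)

    # Dequantize the packed FP4 weight with its swizzled block scales and
    # the per-tensor weight_scale_2.
    w_dq = dequantize_to_dtype(layer.weight.data.view(torch.uint8),
                               layer.weight_scale_swizzled.data,
                               1 / layer.weight_scale_2, out_dtype,
                               x.device, block_size)

    # The "FP4 GEMM" is then just a matmul on the dequantized operands.
    return torch.matmul(x_dq.to(out_dtype), w_dq.t())
```

Compared with the removed fused-kernel path, this is hardware-agnostic but materializes both operands at full precision, trading memory and compute for not requiring FP4 GEMM support.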
vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py
@@ -1,10 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
import torch

__all__ = [
"break_fp4_bytes",
"dequantize_to_dtype",
]
from vllm.scalar_type import scalar_types

__all__ = ["break_fp4_bytes", "dequantize_to_dtype", "ref_nvfp4_quant"]

FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()

kE2M1ToFloat = torch.tensor([0., 0.5, 1., 1.5, 2., 3., 4., 6.],
dtype=torch.float32)
@@ -59,3 +60,44 @@ def dequantize_to_dtype(tensor_fp4,
# scale the tensor
out = (tensor_f32 * tensor_sf_dtype.unsqueeze(-1)).reshape(m, k)
return out.to(dtype)


def get_reciprocal(x):
if isinstance(x, torch.Tensor):
return torch.where(x == 0, torch.tensor(0.0, dtype=x.dtype), 1.0 / x)
elif isinstance(x, (float, int)):
return 0.0 if x == 0 else 1.0 / x
else:
raise TypeError("Input must be a float, int, or a torch.Tensor.")


def cast_to_fp4(x):
sign = torch.sign(x)
x = torch.abs(x)
x[(x >= 0.0) & (x <= 0.25)] = 0.0
x[(x > 0.25) & (x < 0.75)] = 0.5
x[(x >= 0.75) & (x <= 1.25)] = 1.0
x[(x > 1.25) & (x < 1.75)] = 1.5
x[(x >= 1.75) & (x <= 2.5)] = 2.0
x[(x > 2.5) & (x < 3.5)] = 3.0
x[(x >= 3.5) & (x <= 5.0)] = 4.0
x[x > 5.0] = 6.0
return x * sign


def ref_nvfp4_quant(x, global_scale, block_size):
assert global_scale.dtype == torch.float32
assert x.ndim == 2
m, n = x.shape
x = torch.reshape(x, (m, n // block_size, block_size))
vec_max = torch.max(torch.abs(x), dim=-1,
keepdim=True)[0].to(torch.float32)
scale = global_scale * (vec_max * get_reciprocal(FLOAT4_E2M1_MAX))
scale = torch.clamp(scale, max=448, min=-448)
scale = scale.to(torch.float8_e4m3fn).to(torch.float32)
output_scale = get_reciprocal(scale * get_reciprocal(global_scale))

scaled_x = x.to(torch.float32) * output_scale
clipped_x = torch.clamp(scaled_x, -6.0, 6.0).reshape(m, n)
# both outputs are float32
return cast_to_fp4(clipped_x), scale.squeeze(-1)
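
As a quick sanity check of the new helper, the snippet below (not part of the diff) quantizes a random activation tensor with `ref_nvfp4_quant` and dequantizes it the same way `apply()` does, then reports the round-trip error. The global-scale formula (FP8 E4M3 max 448 times FP4 E2M1 max 6, over the tensor amax) is an assumption standing in for the checkpoint-provided `input_scale`.

```python
import torch

from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import (  # noqa: E501
    ref_nvfp4_quant)

block_size = 16
x = torch.randn(4, 64, dtype=torch.float32)

# Assumed per-tensor global scale: FP8_E4M3_MAX * FP4_E2M1_MAX / amax(x).
# The real path uses 1 / layer.input_scale from the checkpoint instead.
x_global_scale = (448.0 * 6.0) / x.abs().amax().to(torch.float32)

x_fp4, x_blockscale = ref_nvfp4_quant(x, x_global_scale, block_size)

# Dequantize the same way apply() does: multiply by the block scale and
# divide out the global scale, then restore the original shape.
m, k = x.shape
x_dq = (x_fp4.reshape(m, k // block_size, block_size) *
        (x_blockscale.unsqueeze(-1) / x_global_scale)).reshape(m, k)

print("max abs round-trip error:", (x_dq - x).abs().max().item())
```

Errors here come from the FP4 rounding in `cast_to_fp4` plus the FP8-E4M3 rounding of the per-block scales.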