diff --git a/gptqmodel/nn_modules/qlinear/gemm_hf_kernel.py b/gptqmodel/nn_modules/qlinear/gemm_hf_kernel.py
new file mode 100644
index 000000000..25adf0751
--- /dev/null
+++ b/gptqmodel/nn_modules/qlinear/gemm_hf_kernel.py
@@ -0,0 +1,286 @@
+# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai
+# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai
+# SPDX-License-Identifier: Apache-2.0
+# Contact: qubitium@modelcloud.ai, x.com/qubitium
+
+
+import torch
+import torch.nn as nn
+from transformers import PreTrainedModel
+
+from ...adapter.adapter import Adapter, Lora
+from ...models._const import DEVICE, PLATFORM
+from ...nn_modules.qlinear import BaseQuantLinear, PackableQuantLinear
+from ...utils.backend import BACKEND
+from ...utils.logger import setup_logger
+
+
+log = setup_logger()
+
+gemm_int4_forward_kernel = None
+gemm_int4_forward_kernel_exception = None
+
+try:
+    from kernels import get_kernel
+
+    gemm_int4_forward_kernel = get_kernel("kernels-community/quantization_gptq").gemm_int4_forward
+except Exception as exc:  # pragma: no cover - best effort fallback
+    gemm_int4_forward_kernel_exception = str(exc)
+    log.warning("Failed to load CPU gemm_4bit kernel: %s. Falling back to the non-kernel path. "
+                "Please make sure the `kernels` package (>=0.11.1) is installed: `pip install -U kernels`.", str(exc))
+
+
+class HFKernelLinear(PackableQuantLinear):
+    SUPPORTS_BITS = [4]
+    SUPPORTS_GROUP_SIZE = [16, 32, 64, 128]
+    SUPPORTS_DESC_ACT = [True, False]
+    SUPPORTS_SYM = [True, False]
+    SUPPORTS_SHARDS = True
+    SUPPORTS_TRAINING = True
+    SUPPORTS_AUTO_PADDING = True
+    SUPPORTS_IN_FEATURES_DIVISIBLE_BY = [1]
+    SUPPORTS_OUT_FEATURES_DIVISIBLE_BY = [1]
+    SUPPORTS_DEVICES = [DEVICE.CPU]
+    SUPPORTS_PLATFORM = [PLATFORM.ALL]
+    SUPPORTS_PACK_DTYPES = [torch.int32]
+    SUPPORTS_ADAPTERS = [Lora]
+
+    SUPPORTS_DTYPES = [torch.float16, torch.bfloat16]
+
+    REQUIRES_FORMAT_V2 = True
+
+    # for transformers/optimum tests compat
+    QUANT_TYPE = "hf_kernel"
+
+    def __init__(
+        self,
+        bits: int,
+        group_size: int,
+        sym: bool,
+        desc_act: bool,
+        in_features: int,
+        out_features: int,
+        bias: bool = False,
+        pack_dtype: torch.dtype = torch.int32,
+        adapter: Adapter = None,
+        register_buffers: bool = True,
+        **kwargs,
+    ):
+        super().__init__(
+            bits=bits,
+            group_size=group_size,
+            sym=sym,
+            desc_act=desc_act,
+            in_features=in_features,
+            out_features=out_features,
+            bias=bias,
+            pack_dtype=pack_dtype,
+            backend=kwargs.pop("backend", BACKEND.TORCH),
+            adapter=adapter,
+            register_buffers=register_buffers,
+            **kwargs)
+
+        self.transformed = False
+        self.dequant_dtype = torch.int8
+
+    @classmethod
+    def validate(cls, **args):
+        if gemm_int4_forward_kernel_exception is not None:
+            return False, ImportError(gemm_int4_forward_kernel_exception)
+
+        return cls._validate(**args)
+
+    def post_init(self):
+        super().post_init()
+        self.optimize()
+
+    def optimize(self):
+        if self.optimized:
+            return
+
+        super().optimize()
+
+    def _build_ret_idx(self) -> torch.Tensor:
+        existing = getattr(self, "ret_idx", None)
+        total = self.g_idx.shape[0]
+        if isinstance(existing, torch.Tensor) and existing.numel() == total:
+            return existing
+
+        device = self.g_idx.device
+        ret_idx = torch.zeros(total, dtype=torch.int32, device=device)
+        group_size = max(int(self.group_size), 1)
+        groups = total // group_size
+        remainder = total % group_size
+        g_idx = self.g_idx.to(torch.int32)
+        g_idx_2 = g_idx * group_size
+
+        if remainder > 0:
+            mask = g_idx == groups
+            if mask.any():
+                g_idx_2[mask] += torch.arange(remainder, device=device, dtype=torch.int32)
+
+        if groups > 0:
+            base = torch.arange(group_size, device=device, dtype=torch.int32)
+            for i in range(groups):
+                mask = g_idx == i
+                if not mask.any():
+                    continue
+                count = int(mask.sum().item())
+                g_idx_2[mask] += base[:count]
+
+        ret_idx[g_idx_2] = torch.arange(total, device=device, dtype=torch.int32)
+        self.ret_idx = ret_idx
+        return ret_idx
+
+    def train(self, mode: bool = True):
+        old_train = self.training
+        if mode == old_train:
+            return self
+
+        from ...utils.model import convert_gptq_v1_to_v2_format_module
+
+        if self.SUPPORTS_TRAINING_USE_TORCH_KERNEL:
+            # training starts
+            if mode:
+                # one time clone v1 qzeros and save both v1 and v2 qzeros in memory
+                if self.qzero_format() == 1:
+                    if not hasattr(self, "qzeros_data_v1"):
+                        self.qzeros_data_v1 = self.qzeros.data.clone()
+                        convert_gptq_v1_to_v2_format_module(self, bits=self.bits, pack_dtype=self.pack_dtype)
+                        self.qzeros_data_v2 = self.qzeros.data
+                    else:
+                        self.qzeros.data = self.qzeros_data_v2
+                        self.qzero_format(format=2)
+            # training switching to inference/eval
+            else:
+                if hasattr(self, "qzeros_data_v1"):
+                    # switch qzero back to v1 for inference/eval
+                    self.qzeros.data = self.qzeros_data_v1
+                    self.qzero_format(format=1)
+
+        return super().train(mode=mode)
+
+    def convert_weight_packed_zp(self, block_n: int = 32):
+        """
+        qweight: int4 weight (*, N, K) uint8 (values 0-15)
+        return: packed_weight uint8 (*, N, K/2) (low 4bit + high 4bit)
+        """
+        assert self.qweight.dtype == torch.uint8, "qweight must be uint8"
+        sizes = list(self.qweight.shape)
+        if len(sizes) < 2:
+            raise ValueError("qweight_final rank error")
+        N, K = sizes[-2], sizes[-1]
+        assert N % block_n == 0, "N must be divisible by block_n"
+        assert K % 2 == 0, "K must be even"
+        BLOCK_N = block_n
+        BIT_COUNT = 32  # (= 32 low + 32 high)
+        prefix = sizes[:-2]
+        new_shape = prefix + [N // BLOCK_N, BLOCK_N, K // 2, 2]
+        out_shape = prefix + [N, K // 2]
+        qw = self.qweight.reshape(new_shape)    # (..., N/B, B, K/2, 2)
+        qw = qw.transpose(-3, -2).contiguous()  # (..., N/B, K/2, B, 2)
+        qw = qw.reshape(-1, BIT_COUNT * 2)      # [-1, 64]
+        high = qw[:, BIT_COUNT:]                # high 32
+        low = qw[:, :BIT_COUNT]                 # low 32
+        packed = ((high << 4) | low).to(torch.uint8)  # combine
+        final_qweight = packed.reshape(out_shape)
+
+        self.qweight = final_qweight.contiguous()
+
+    def transform_cpu(self):
+        self.scales = self.scales.to(torch.bfloat16).contiguous()
+        # Unpack and reorder qweight
+        weight = torch.bitwise_and(
+            torch.bitwise_right_shift(
+                torch.unsqueeze(self.qweight, 1).expand(-1, self.pack_factor, -1),
+                self.wf_unsqueeze_neg_one  # self.wf.unsqueeze(-1)
+            ).to(torch.uint8),
+            self.maxq
+        )
+        ret_idx = self._build_ret_idx()
+        weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2]).index_select(0, ret_idx).t()
+        self.qweight = weight.contiguous()
+        zeros = torch.bitwise_right_shift(
+            torch.unsqueeze(self.qzeros, 2).expand(-1, -1, self.pack_factor),
+            self.wf_unsqueeze_zero  # self.wf.unsqueeze(0)
+        ).to(torch.uint8)
+        zeros = torch.bitwise_and(zeros, self.maxq).reshape(zeros.shape[0], -1)
+        self.qzeros = zeros.contiguous()
+
+    def transform(self, device):
+        if device == "cpu":
+            self.transform_cpu()
+            self.convert_weight_packed_zp()
+        else:
+            raise NotImplementedError
+
+    def forward(self, x: torch.Tensor):
+        out_shape = x.shape[:-1] + (self.out_features,)
+        x = x.reshape(-1, x.shape[-1])
+        if not self.training and not self.transformed and gemm_int4_forward_kernel is not None:
+            self.transform(x.device.type)
+            self.transformed = True
+
+        if self.transformed:
+            out = self._fused_op_forward(x).reshape(out_shape)
+        else:
+            # make sure dequant dtype matches input x
+            num_itr = self.g_idx.shape[0] // x.shape[-1]
+            weights = self.dequantize_weight(num_itr=num_itr).to(x.dtype)
+            out = torch.matmul(x, weights).reshape(out_shape)
+
+        # Add bias and adapter
+        if self.bias is not None:
+            out.add_(self.bias)
+        if self.adapter:
+            out = self.adapter.apply(x=x, out=out)
+
+        return out
+
+    @torch.no_grad
+    def _fused_op_forward(self, x):
+        x = x[:, self.ret_idx].contiguous()
+        if x.device.type == "cpu":
+            out = gemm_int4_forward_kernel(x, self.qweight, self.qzeros, self.scales, self.group_size)
+        else:
+            raise NotImplementedError
+
+        return out
+
+    # clear gptq only weights: useful in de-quantization
+    def _empty_gptq_only_weights(self):
+        self.qzeros = None
+        self.qweight = None
+        self.g_idx = None
+        self.scales = None
+
+def dequantize_model(model: PreTrainedModel):
+    for name, module in model.named_modules():
+        if isinstance(module, BaseQuantLinear) and not isinstance(module, HFKernelLinear):
+            raise ValueError(
+                "Only models loaded using HFKernelLinear are supported for dequantization. "
+                "Please load model using backend=BACKEND.HF_KERNEL"
+            )
+
+        if isinstance(module, HFKernelLinear):
+            # Create a new Linear layer with dequantized weights
+            new_module = nn.Linear(module.in_features, module.out_features)
+            new_module.weight = nn.Parameter(module.dequantize_weight().T.detach().to("cpu", torch.float16))
+            new_module.bias = torch.nn.Parameter(module.bias)
+
+            # Replace the module in the model
+            parent = model
+            if '.' in name:
+                parent_name, module_name = name.rsplit('.', 1)
+                parent = dict(model.named_modules())[parent_name]
+            else:
+                module_name = name
+
+            setattr(parent, module_name, new_module)
+
+    del model.config.quantization_config
+    return model
+
+
+__all__ = ["HFKernelLinear", "dequantize_model"]
diff --git a/gptqmodel/utils/backend.py b/gptqmodel/utils/backend.py
index c54d2449f..e86f27f78 100644
--- a/gptqmodel/utils/backend.py
+++ b/gptqmodel/utils/backend.py
@@ -22,6 +22,7 @@ class BACKEND(str, Enum):
     MARLIN = "marlin"  # FASTEST: marlin reduce ops in fp32 (higher precision -> more accurate, slightly slower)
     MARLIN_FP16 = "marlin_fp16"  # FASTEST and then some: marlin reduce ops in fp16 (lower precision -> less accurate, slightly faster)
     BITBLAS = "bitblas"  # EXTREMELY FAST: speed at the cost of 10+ minutes of AOT (ahead of time compilation with disk cache)
+    HF_KERNEL = "hf_kernel"  # FAST: optimized CPU kernel from HuggingFace kernels-community

     # qqq
     QQQ = "qqq"  # marlin based qqq kernel
diff --git a/gptqmodel/utils/importer.py b/gptqmodel/utils/importer.py
index 3203f7734..e394bb612 100644
--- a/gptqmodel/utils/importer.py
+++ b/gptqmodel/utils/importer.py
@@ -32,6 +32,7 @@
 from ..nn_modules.qlinear.torch_awq import AwqTorchQuantLinear
 from ..nn_modules.qlinear.torch_fused import TorchFusedQuantLinear
 from ..nn_modules.qlinear.torch_fused_awq import TorchFusedAwqQuantLinear
+from ..nn_modules.qlinear.gemm_hf_kernel import HFKernelLinear
 from ..nn_modules.qlinear.tritonv2 import TRITON_AVAILABLE, TRITON_INSTALL_HINT, TritonV2QuantLinear
 from ..quantization import FORMAT, METHOD
 from ..utils.logger import setup_logger
@@ -55,6 +56,7 @@
     # BACKEND.EXLLAMA_EORA: ExllamaEoraQuantLinear,
     # BACKEND.EXLLAMA_V2: ExllamaV2QuantLinear,  # optimized for bs > 1
     BACKEND.EXLLAMA_V1: ExllamaQuantLinear,  # optimized for bs == 1
+    BACKEND.HF_KERNEL: HFKernelLinear,  # optimized CPU kernel from HuggingFace kernels-community
     BACKEND.TORCH_FUSED: TorchFusedQuantLinear,  # optimized for Intel XPU
     BACKEND.TRITON: TritonV2QuantLinear,  # good all around kernel that JIT compiles
     # BACKEND.CUDA: DynamicCudaQuantLinear,
@@ -80,8 +82,8 @@
 SUPPORTS_BACKEND_MAP = {
     METHOD.GPTQ: {
-        FORMAT.GPTQ: [BACKEND.MACHETE, BACKEND.MARLIN, BACKEND.EXLLAMA_V2, BACKEND.EXLLAMA_V1, BACKEND.TORCH_FUSED, BACKEND.TRITON, BACKEND.TORCH_FUSED, BACKEND.TORCH, BACKEND.MARLIN_FP16, BACKEND.EXLLAMA_EORA],
-        FORMAT.GPTQ_V2: [BACKEND.EXLLAMA_V2, BACKEND.EXLLAMA_V1, BACKEND.TORCH_FUSED, BACKEND.TRITON, BACKEND.TORCH],
+        FORMAT.GPTQ: [BACKEND.MACHETE, BACKEND.MARLIN, BACKEND.EXLLAMA_V2, BACKEND.EXLLAMA_V1, BACKEND.HF_KERNEL, BACKEND.TRITON, BACKEND.TORCH_FUSED, BACKEND.TORCH, BACKEND.MARLIN_FP16, BACKEND.EXLLAMA_EORA],
+        FORMAT.GPTQ_V2: [BACKEND.EXLLAMA_V2, BACKEND.EXLLAMA_V1, BACKEND.HF_KERNEL, BACKEND.TORCH_FUSED, BACKEND.TRITON, BACKEND.TORCH],
         FORMAT.MARLIN: [BACKEND.MARLIN, BACKEND.MARLIN_FP16],
         FORMAT.BITBLAS: [BACKEND.BITBLAS],
     },
@@ -425,6 +427,8 @@ def select_quant_linear(
         qlinear = AwqTorchQuantLinear
     elif backend == BACKEND.TORCH:
         qlinear = TorchQuantLinear
+    elif backend == BACKEND.HF_KERNEL:
+        qlinear = HFKernelLinear
     elif backend == BACKEND.TORCH_FUSED:
         qlinear = TorchFusedQuantLinear
     elif backend == BACKEND.TORCH_FUSED_AWQ:
diff --git a/pyproject.toml b/pyproject.toml
index 2b4ffcad8..d9724d6bc 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -53,6 +53,7 @@ dependencies = [
     "dill>=0.3.8", # datasets requirements
     "pypcre>=0.2.6",
     "torchao>=0.14.1", # fix bad transformers 4.57.1 breaking torchao compat
+    "kernels>=0.11.1", # For CPU kernels
     # "cython>=3.1.4", # required by hf-xet/hf-transfer
     # "flash-attn>=2.8.3", <-- install for lower vram usage
 ]
diff --git a/requirements.txt b/requirements.txt
index af06bd6f6..0b0159b35 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -19,3 +19,4 @@ datasets>=3.6.0
 pyarrow>=21.0
 dill>=0.3.8
 torchao>=0.14.1
+kernels>=0.11.1
diff --git a/tests/test_kernel_output_torch_fused.py b/tests/test_kernel_output_intel_cpu_xpu.py
similarity index 91%
rename from tests/test_kernel_output_torch_fused.py
rename to tests/test_kernel_output_intel_cpu_xpu.py
index 4dcfa1218..3ff0c687d 100644
--- a/tests/test_kernel_output_torch_fused.py
+++ b/tests/test_kernel_output_intel_cpu_xpu.py
@@ -14,6 +14,7 @@
 from gptqmodel import BACKEND, GPTQModel
 from gptqmodel.nn_modules.qlinear.torch import TorchQuantLinear
 from gptqmodel.nn_modules.qlinear.torch_fused import TorchFusedQuantLinear
+from gptqmodel.nn_modules.qlinear.gemm_hf_kernel import HFKernelLinear
 from gptqmodel.utils.model import find_modules


@@ -29,6 +30,7 @@ class TestKernelOutput(unittest.TestCase):
     target_qliner_map = {
         BACKEND.TORCH: TorchQuantLinear,
         BACKEND.TORCH_FUSED: TorchFusedQuantLinear,
+        BACKEND.HF_KERNEL: HFKernelLinear,
     }
     target = 'model.layers.6.self_attn.v_proj'
     device = "cpu"
@@ -71,6 +73,7 @@ def assert_on_mismatch(self, a: Tensor, b: Tensor, rtol=0.00005, atol=0.00005):
         (BACKEND.TORCH, 0.0000, 0.0005),
         # (BACKEND.TRITON, 0.0000, 0.0005),
         (BACKEND.TORCH_FUSED, r_tolerance, a_tolerance),
+        (BACKEND.HF_KERNEL, r_tolerance, a_tolerance),
     ])
     def test_kernel_output(self, backend: BACKEND, r_tolerance: float, a_tolerance: float):
         model = GPTQModel.load(self.model_path, backend=backend, device=self.device, dtype=self.dtype)
@@ -99,7 +102,7 @@ class TestKernelOutputXPUBFloat16(TestKernelOutputXPU):
     dtype = torch.bfloat16


-class TestTorchFusedKernelDevices(unittest.TestCase):
+class TestTorchFusedAndHFKernelDevices(unittest.TestCase):
     model_path = TestKernelOutput.model_path
     target_qliner_map = TestKernelOutput.target_qliner_map
     target = TestKernelOutput.target
@@ -149,23 +152,24 @@ def assert_on_mismatch(self, a: Tensor, b: Tensor, rtol=0.00005, atol=0.00005):
         torch.testing.assert_close(a, b, rtol=rtol, atol=atol)

     @parameterized.expand([
-        ("cpu", "cpu"),
-        ("xpu", "xpu:0"),
+        ("cpu", "cpu", BACKEND.TORCH_FUSED),
+        ("cpu", "cpu", BACKEND.HF_KERNEL),
+        ("xpu", "xpu:0", BACKEND.TORCH_FUSED),
     ])
-    def test_torch_fused_matches_cpu_reference(self, _name: str, device: str):
+    def test_backends_matches_cpu_reference(self, _name: str, device: str, backend: BACKEND):
         if device.startswith("xpu") and not _xpu_available():
             self.skipTest("Test requires XPU")

         model = GPTQModel.load(
             self.model_path,
-            backend=BACKEND.TORCH_FUSED,
+            backend=backend,
             device=device,
             dtype=self.dtype,
         )

         failures = []
         for idx, sample in enumerate(self.inputs):
             model_input = sample.to(model.device)
-            fused_out = self.forward(model, model_input, BACKEND.TORCH_FUSED)
+            fused_out = self.forward(model, model_input, backend)
             reference = self.reference_outputs[idx]
             try:
                 self.assert_on_mismatch(
@@ -181,7 +185,7 @@ def test_torch_fused_matches_cpu_reference(self, _name: str, device: str):
         table = tabulate(
             [
                 [
-                    BACKEND.TORCH_FUSED.name,
+                    backend.name,
                     str(self.dtype),
                     device,
                     len(self.inputs),
@@ -210,7 +214,7 @@ def test_torch_fused_matches_cpu_reference(self, _name: str, device: str):
         if failures:
             raise AssertionError(f"{len(failures)} mismatched samples on device {device}")

-class TestTorchFusedKernelDevicesWithBias(TestTorchFusedKernelDevices):
+class TestTorchFusedAndHFKernelDevicesWithBias(TestTorchFusedAndHFKernelDevices):
     model_path = "/monster/data/model/bloom-560m-gptqmodel-4bit"
     target = 'transformer.h.6.self_attention.query_key_value'
     k = 1024
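
Usage sketch (not part of the patch): loading a 4-bit GPTQ checkpoint on CPU through the new backend and, optionally, dequantizing it back to plain nn.Linear modules. The checkpoint path below is a placeholder, and the snippet assumes the loaded wrapper exposes the underlying transformers model as `.model`; backend, device, and dtype selection otherwise mirror the updated tests.

    import torch
    from gptqmodel import BACKEND, GPTQModel
    from gptqmodel.nn_modules.qlinear.gemm_hf_kernel import dequantize_model

    # Requires the HuggingFace `kernels` package (>=0.11.1); otherwise
    # HFKernelLinear.validate() surfaces the import error and the backend is unavailable.
    model = GPTQModel.load(
        "/path/to/gptq-4bit-checkpoint",  # placeholder path
        backend=BACKEND.HF_KERNEL,
        device="cpu",
        dtype=torch.bfloat16,
    )

    # Optional: replace every HFKernelLinear with a dequantized fp16 nn.Linear.
    dense_model = dequantize_model(model.model)  # assumes `.model` holds the transformers PreTrainedModel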