Commit 9e79b01

ngxson and compilade authored
convert: allow using quantized Mistral weight (#17889)
* convert: allow using quantized Mistral weight
* data_torch.ndim
* update dequant fn

Co-authored-by: compilade <[email protected]>
1 parent 2e9eab8 commit 9e79b01

File tree

1 file changed: +24 -4 lines changed


convert_hf_to_gguf.py

Lines changed: 24 additions & 4 deletions
@@ -383,6 +383,17 @@ def dequant_packed(w: Tensor, scale: Tensor, shape_tensor: Tensor, zero_point: T
                     s = self.model_tensors[name]
                     self.model_tensors[weight_name] = lambda w=w, s=s, bs=block_size: dequant_simple(w(), s(), bs)
                     tensors_to_remove.append(name)
+                if name.endswith(".activation_scale"):  # unused
+                    tensors_to_remove.append(name)
+                # mistral format
+                if name.endswith(".qscale_weight"):
+                    weight_name = name.removesuffix("qscale_weight") + "weight"
+                    w = self.model_tensors[weight_name]
+                    s = self.model_tensors[name]
+                    self.model_tensors[weight_name] = lambda w=w, s=s, bs=block_size: dequant_simple(w(), s(), bs)
+                    tensors_to_remove.append(name)
+                if name.endswith(".qscale_act"):
+                    tensors_to_remove.append(name)
         elif quant_method == "gptq":
             for name in self.model_tensors.keys():
                 if name.endswith(".qweight"):
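For context on the branch added above: .qscale_weight holds a scale tensor, and stripping that suffix yields the name of the fp8 weight it belongs to, so the Mistral format reuses dequant_simple, the same helper the pre-existing weight_scale_inv path calls. A minimal sketch of what a helper like dequant_simple plausibly does (an illustration under assumed semantics, not the actual implementation in convert_hf_to_gguf.py):

import torch
from torch import Tensor

def dequant_simple_sketch(weight: Tensor, scale: Tensor, block_size: list[int] | None) -> Tensor:
    # Upcast first: fp8 tensors must become float before arithmetic.
    weight = weight.float()
    scale = scale.float()
    if block_size is not None:
        # Expand each scale entry over its block so it lines up with the weight...
        for dim, size in enumerate(block_size):
            scale = scale.repeat_interleave(size, dim=dim)
        # ...then trim padding where the weight shape is not a multiple of the block size.
        scale = scale[tuple(slice(0, n) for n in weight.shape)]
    return weight * scale

Note also the lambda w=w, s=s, bs=block_size default arguments in the diff: binding the loop variables as defaults captures their current values, avoiding Python's late-binding closure pitfall where every deferred lambda would otherwise see only the last iteration's tensors.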
@@ -2854,13 +2865,10 @@ def set_gguf_parameters(self):
             self.gguf_writer.add_attn_temperature_scale(rope_params["llama_4_scaling_beta"])
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
-        # TODO: probably not worth supporting quantized weight, as official BF16 is also available
-        if name.endswith("weight_scale_inv"):
-            raise ValueError("This is a quantized weight, please use BF16 weight instead")
-
         name = name.replace("language_model.", "")
         if "multi_modal_projector" in name or "vision_tower" in name:
             return []
+
         return super().modify_tensors(data_torch, name, bid)
 
 
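Dropping this guard is what lets quantized checkpoints through: weight_scale_inv tensors no longer abort the conversion, because the shared fp8 dequantization path (extended in the first hunk) consumes the scale tensors during dequant_model, before modify_tensors runs. What remains here is pure name handling; a small illustration with hypothetical tensor names:

# Hypothetical tensor names, for illustration only.
for name in [
    "language_model.model.layers.0.self_attn.q_proj.weight",
    "vision_tower.patch_embed.weight",
]:
    stripped = name.replace("language_model.", "")
    skipped = "multi_modal_projector" in stripped or "vision_tower" in stripped
    print(f"{stripped} -> {'skipped' if skipped else 'converted'}")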
@@ -9898,6 +9906,18 @@ def __init__(self, *args, **kwargs):
         self.gguf_writer.add_architecture()
         self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
 
+    def dequant_model(self):
+        # transform quantization config into HF format
+        quant_config = self.hparams.get("quantization")
+        if quant_config is not None:
+            assert quant_config["qformat_weight"] == "fp8_e4m3"
+            self.hparams["quantization_config"] = {
+                "activation_scheme": "static",
+                "quant_method": "fp8",
+                "weight_block_size": None,
+            }
+        return super().dequant_model()
+
     @staticmethod
     def get_community_chat_template(vocab: MistralVocab, templates_dir: Path, is_mistral_format: bool):
         assert TokenizerVersion is not None and Tekkenizer is not None and SentencePieceTokenizer is not None, _mistral_import_error_msg
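The new dequant_model override is a thin adapter: it rewrites the "quantization" section of Mistral's native config into the HF-style "quantization_config" the shared dequantization code keys on, then defers to the base implementation. A hedged sketch of the mapping in isolation (the input dict shape beyond qformat_weight is an assumption):

def to_hf_quant_config(quant: dict | None) -> dict | None:
    # Mirrors the override above: only fp8_e4m3 weights are accepted.
    if quant is None:
        return None
    assert quant["qformat_weight"] == "fp8_e4m3"
    return {
        "activation_scheme": "static",  # activation scales exist but are dropped (.qscale_act)
        "quant_method": "fp8",          # routes into the existing fp8 dequant branch
        "weight_block_size": None,      # per-tensor scales, no block-wise layout
    }

# A minimal, assumed "quantization" section from a Mistral-format config:
print(to_hf_quant_config({"qformat_weight": "fp8_e4m3"}))
# -> {'activation_scheme': 'static', 'quant_method': 'fp8', 'weight_block_size': None}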
