diff --git a/lmdeploy/turbomind/deploy/converter.py b/lmdeploy/turbomind/deploy/converter.py
index 45bbf83dc1..c27372dd07 100644
--- a/lmdeploy/turbomind/deploy/converter.py
+++ b/lmdeploy/turbomind/deploy/converter.py
@@ -164,9 +164,16 @@ def get_tm_model(model_path,
 
     input_model_name = get_input_model_registered_name(model_path, engine_config.model_format)
     input_policy = get_input_policy(engine_config.model_format)
+
+    if engine_config.model_format == 'fp8' and not quant_config:
+        use_quant_online = True
+    else:
+        use_quant_online = False
+
     input_model = INPUT_MODELS.get(input_model_name)(model_path=model_path,
                                                      tokenizer_path=model_path,
-                                                     input_policy=input_policy)
+                                                     input_policy=input_policy,
+                                                     use_quant_online=use_quant_online)
 
     output_model_name, tm_cfg = \
         get_output_model_registered_name_and_config(
diff --git a/lmdeploy/turbomind/deploy/source_model/llama.py b/lmdeploy/turbomind/deploy/source_model/llama.py
index d88be26113..bcad0e45f3 100644
--- a/lmdeploy/turbomind/deploy/source_model/llama.py
+++ b/lmdeploy/turbomind/deploy/source_model/llama.py
@@ -5,6 +5,7 @@
 import torch
 
 from lmdeploy.archs import get_model_arch
+from lmdeploy.lite.quantization.weight.quant_utils import quant_blocked_fp8
 
 from ..config import RopeParam
 from ..loader import create_loader
@@ -23,7 +24,16 @@ class LlamaReader(BaseReader):
     attn_pattern = r'self_attn'
     ffn_pattern = r'mlp'
 
-    def __init__(self, new_params: dict, unused_params: dict, last_bin: bool, model_cfg: dict, policy):
+    proj_pattern = 'proj'
+    scale_inv_prefix = 'scale_inv'
+
+    def __init__(self,
+                 new_params: dict,
+                 unused_params: dict,
+                 last_bin: bool,
+                 model_cfg: dict,
+                 policy,
+                 use_quant_online: bool = False):
         super().__init__()
         self.params = unused_params
         self.params.update(new_params)
@@ -33,6 +43,22 @@ def __init__(self, new_params: dict, unused_params: dict, last_bin: bool, model_
         if tie_word_embeddings:
             self.output_weight_key = self.tok_embeddings_key
         self.processor = policy
+        if use_quant_online:
+            quant_params = self.quant_weight_fp8()
+            self.params.update(quant_params)
+
+    def quant_weight_fp8(self):
+        pattern_str = f'({self.attn_pattern}|{self.ffn_pattern}).*{self.proj_pattern}'
+        target_pattern = re.compile(pattern_str)
+
+        quant_params = {}
+        for name, weight in self.params.items():
+            if target_pattern.search(name):
+                q_weight, scale = quant_blocked_fp8(weight, torch.float8_e4m3fn, block_size=128)
+                quant_params[name] = q_weight
+                quant_params[f'{name}_{self.scale_inv_prefix}'] = scale.to(weight.dtype)
+
+        return quant_params
 
     def filter(self, pattern: str):
         params = []
@@ -104,6 +130,7 @@ class LlamaModel(BaseInputModel):
     def __init__(self, model_path: str, tokenizer_path: str, **kwargs: dict):
         super().__init__(model_path, tokenizer_path)
         self.policy = kwargs.get('input_policy')
+        self.use_quant_online = kwargs.get('use_quant_online', False)
         _, self.model_config = get_model_arch(model_path)
         self.model_config = self.model_config.to_dict()
 
@@ -111,7 +138,11 @@ def readers(self):
         mappings = getattr(self.Reader, 'mappings', [])
         loader = create_loader(self.model_path, self.Reader.attn_layer_patten, mappings)
         for i, param in loader.items():
-            reader = self.Reader(param, {}, False, self.model_config, policy=self.policy)
+            reader = self.Reader(param, {},
+                                 False,
+                                 self.model_config,
+                                 policy=self.policy,
+                                 use_quant_online=self.use_quant_online)
             yield i, reader
         torch.cuda.empty_cache()
 
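
Note on the online-quantization path added above: `quant_weight_fp8` calls `quant_blocked_fp8(weight, torch.float8_e4m3fn, block_size=128)` and stores the returned inverse scales under `{name}_scale_inv`, the block-wise FP8 layout used by DeepSeek-style checkpoints. Below is a minimal sketch of what such a call is assumed to compute; the helper name `quant_blocked_fp8_sketch` and the padding/clamping details are illustrative and not the actual implementation in `lmdeploy.lite.quantization.weight.quant_utils`.

import torch
import torch.nn.functional as F


def quant_blocked_fp8_sketch(weight: torch.Tensor,
                             dtype: torch.dtype = torch.float8_e4m3fn,
                             block_size: int = 128):
    """Quantize a 2-D weight into (block_size x block_size) FP8 tiles.

    Returns the FP8 tensor plus one inverse scale per tile, matching the
    `*_scale_inv` convention of blocked-fp8 checkpoints. Illustrative only.
    """
    fp8_max = torch.finfo(dtype).max  # 448.0 for float8_e4m3fn
    rows, cols = weight.shape
    # Pad both dims up to a multiple of block_size so every tile is complete.
    pad_r = (block_size - rows % block_size) % block_size
    pad_c = (block_size - cols % block_size) % block_size
    w = F.pad(weight.float(), (0, pad_c, 0, pad_r))
    # View as (row_tiles, block, col_tiles, block) and reduce per tile.
    w = w.view(w.shape[0] // block_size, block_size,
               w.shape[1] // block_size, block_size)
    amax = w.abs().amax(dim=(1, 3), keepdim=True).clamp(min=1e-12)
    scale = fp8_max / amax                           # multiply to quantize
    q = (w * scale).clamp(-fp8_max, fp8_max).to(dtype)
    q = q.view(rows + pad_r, cols + pad_c)[:rows, :cols]
    scale_inv = (1.0 / scale).squeeze(3).squeeze(1)  # (row_tiles, col_tiles)
    return q, scale_inv

Dequantization reverses the mapping tile by tile, `weight ≈ q.float() * scale_inv`, with each 128x128 block multiplied by its own inverse scale; this is why the reader keeps the scales alongside the FP8 weights, cast back to the original weight dtype via `scale.to(weight.dtype)`.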