Skip to content
87 changes: 72 additions & 15 deletions lmdeploy/lite/apis/calibrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,67 @@
'MistralForCausalLM': 'lm_head',
}

STR_TO_TORCH_DTYPE = {
'float16': torch.float16,
'bfloat16': torch.bfloat16,
'float32': torch.float32
}

TORCH_DTYPE_TO_STR = {
torch.float16: 'float16',
torch.bfloat16: 'bfloat16',
torch.float32: 'float32'
}

def _set_use_cache(model):
model.config.use_cache = False
if hasattr(model.config, 'text_config'):
model.config.text_config.use_cache = False
elif hasattr(model.config, 'llm_config'):
model.config.llm_config.use_cache = False


def _get_torch_dtype(config):
def _resolve_dtype(config):
dtype = getattr(config, 'torch_dtype', None)
if dtype is None:
dtype = getattr(config, 'dtype', None)
return dtype

dtype = _resolve_dtype(config)

if hasattr(config, 'text_config'):
sub_dtype = _resolve_dtype(config.text_config)
if sub_dtype is not None:
dtype = sub_dtype
elif hasattr(config, 'llm_config'):
sub_dtype = _resolve_dtype(config.llm_config)
if sub_dtype is not None:
dtype = sub_dtype

if dtype is None:
dtype = 'bfloat16'

if isinstance(dtype, torch.dtype):
return dtype
return STR_TO_TORCH_DTYPE[dtype]


Comment on lines +128 to +134
def _set_config_dtype(model, torch_dtype):
dtype = TORCH_DTYPE_TO_STR[torch_dtype]
configs = [model.config]

for name in ['text_config', 'llm_config', 'vision_config', 'ts_config']:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any special reason to check vision_config and ts_config?

sub_config = getattr(model.config, name, None)
if sub_config is not None:
configs.append(sub_config)

for config in configs:
if hasattr(config, 'dtype'):
config.dtype = dtype


def check_vl_llm(backend: str, config: dict) -> bool:
def check_vl_llm(config: dict) -> bool:
"""Check if the model is a vl model from model config."""
if 'auto_map' in config:
for _, v in config['auto_map'].items():
Expand Down Expand Up @@ -121,11 +180,11 @@ def check_vl_llm(backend: str, config: dict) -> bool:
return False


def get_task(backend: str, model_path: str):
def get_task(model_path: str, trust_remote_code: bool = False) -> str:
"""Get pipeline type and pipeline class from model config."""

_, config = get_model_arch(model_path)
if check_vl_llm(backend, config.to_dict()):
_, config = get_model_arch(model_path, trust_remote_code)
if check_vl_llm(config.to_dict()):
return 'vlm'

# default task
Expand Down Expand Up @@ -203,11 +262,11 @@ class name or the class type itself.


# TODO to be removed
def make_compatible_internvl_config(model_path):
def make_compatible_internvl_config(model_path, trust_remote_code: bool = False):
"""Patch model.config since after transformers v4.45.0, InternVL models
can't use `save_pretrained`"""
from lmdeploy.archs import get_model_arch
arch, _ = get_model_arch(model_path)
arch, _ = get_model_arch(model_path, trust_remote_code)
if arch == 'InternVLChatModel':
import transformers
from packaging import version
Expand Down Expand Up @@ -257,8 +316,8 @@ def load_model_and_tokenizer(model: str,
work_dir: str = './work_dir',
trust_remote_code: bool = False):
"""Load model and tokenizer."""
model_type = get_task(backend='turbomind', model_path=model)
make_compatible_internvl_config(model)
model_type = get_task(model, trust_remote_code)
make_compatible_internvl_config(model, trust_remote_code)

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=trust_remote_code)
Expand All @@ -276,14 +335,12 @@ def load_model_and_tokenizer(model: str,
model = vl_model.language_model
if hasattr(vl_model, 'llm'): # MiniCPMV, ...
model = vl_model.llm
model.config.use_cache = False
if hasattr(model.config, 'text_config'):
model.config.text_config.use_cache = False
elif hasattr(model.config, 'llm_config'):
model.config.llm_config.use_cache = False
if dtype == 'float16' or (dtype == 'auto' and original_config.torch_dtype == torch.float16):
_set_use_cache(model)
torch_dtype = _get_torch_dtype(original_config)
_set_config_dtype(model, torch_dtype)
if dtype == 'float16' or (dtype == 'auto' and torch_dtype == torch.float16):
model.half()
elif dtype == 'bfloat16' or (dtype == 'auto' and original_config.torch_dtype == torch.bfloat16):
elif dtype == 'bfloat16' or (dtype == 'auto' and torch_dtype == torch.bfloat16):
Comment on lines +338 to +343
assert torch.cuda.is_bf16_supported(
), 'your device does not support bfloat16 please set --dtype float16' # noqa
model.to(torch.bfloat16)
Expand Down
1 change: 0 additions & 1 deletion lmdeploy/lite/apis/smooth_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,6 @@ def smooth_quant(model: str,

patterns = []
skipped_modules = []
arch = model.config.architectures[0]
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't "MODELS.get(arch)" fail if arch = model.config.architectures[0] is removed?

rebuilder = MODELS.get(arch)
if rebuilder:
patterns = rebuilder.skipped_modules()
Expand Down
Loading