|
| 1 | +import os |
| 2 | +from typing import Dict, Tuple, Union, Optional |
| 3 | + |
| 4 | +from torch.nn import Module |
| 5 | +from transformers import AutoModel |
| 6 | + |
| 7 | + |
def auto_configure_device_map(num_gpus: int) -> Dict[str, int]:
    """Build a module-to-GPU device map for ChatGLM2 across ``num_gpus`` GPUs.

    The model is treated as 30 equally sized "slots":

    * the word embeddings take 1 slot,
    * the final layernorm together with the output head take 1 slot,
    * the 28 transformer layers take 28 slots.

    The two non-layer slots are pinned to GPU 0; the transformer layers are
    then handed out in order, moving on to the next GPU whenever the current
    one has received its share of ``30 / num_gpus`` slots.

    Args:
        num_gpus: Number of GPUs to spread the model over; must be >= 1.

    Returns:
        A mapping from module name to GPU index.

    Raises:
        ValueError: If ``num_gpus`` is smaller than 1.
    """
    # Validate explicitly rather than with ``assert`` (stripped under ``-O``)
    # and rather than letting ``num_gpus == 0`` surface as ZeroDivisionError.
    if num_gpus < 1:
        raise ValueError(f"num_gpus must be a positive integer, got {num_gpus!r}")

    num_trans_layers = 28
    # Deliberately a float: the share per GPU does not have to be integral.
    per_gpu_layers = 30 / num_gpus

    # Bugfix: on Linux the weight and input passed to torch.embedding could
    # end up on different devices, causing a RuntimeError.
    # On Windows, model.device is set to transformer.word_embeddings.device.
    # On Linux, model.device is set to lm_head.device.
    # chat / stream_chat place input_ids on model.device, so if the word
    # embeddings live on a different device than model.device the call fails.
    # Therefore the word embeddings, the final layernorm and the output head
    # are all placed on the first GPU.
    # This file originates from
    # https://github.com/THUDM/ChatGLM-6B/blob/main/utils.py
    # with only minor changes here to support ChatGLM2.
    device_map = {
        'transformer.embedding.word_embeddings': 0,
        'transformer.encoder.final_layernorm': 0,
        'transformer.output_layer': 0,
        'transformer.rotary_pos_emb': 0,
        'lm_head': 0,
    }

    # GPU 0 already holds the two non-layer slots accounted for above.
    used = 2
    gpu_target = 0
    for i in range(num_trans_layers):
        if used >= per_gpu_layers:
            # Current GPU is full; continue on the next one. For any
            # num_gpus >= 1 the 28 layers fit before GPUs run out.
            gpu_target += 1
            used = 0
        device_map[f'transformer.encoder.layers.{i}'] = gpu_target
        used += 1

    return device_map
| 43 | + |
| 44 | + |
def load_model_on_gpus(checkpoint_path: Union[str, os.PathLike], num_gpus: int = 2,
                       device_map: Optional[Dict[str, int]] = None, **kwargs) -> Module:
    """Load a half-precision model from ``checkpoint_path`` onto one or more GPUs.

    Args:
        checkpoint_path: Checkpoint location forwarded to
            ``AutoModel.from_pretrained``.
        num_gpus: Number of GPUs to use when no ``device_map`` is supplied.
        device_map: Optional explicit module-name -> GPU-index mapping. When
            given, the model is always dispatched with it, whatever
            ``num_gpus`` says.
        **kwargs: Extra keyword arguments forwarded to
            ``AutoModel.from_pretrained``.

    Returns:
        The loaded model, either moved to the default CUDA device (fewer than
        two GPUs and no ``device_map``) or dispatched according to the device
        map.
    """
    single_device = num_gpus < 2 and device_map is None

    if not single_device:
        # Imported lazily so that accelerate is only needed for multi-GPU use.
        from accelerate import dispatch_model

    half_model = AutoModel.from_pretrained(
        checkpoint_path, trust_remote_code=True, **kwargs
    ).half()

    if single_device:
        return half_model.cuda()

    # Fall back to the automatically computed layout when none was supplied.
    placement = auto_configure_device_map(num_gpus) if device_map is None else device_map
    return dispatch_model(half_model, device_map=placement)
0 commit comments