# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.distributed as dist
import torch.nn.functional as F

from ..embedding import EmbeddingBuilder, EmbeddingImpl


def get_masked_input_and_mask(input: torch.Tensor, start_index: int, end_index: int):
    """Map token ids onto this shard's local index range.

    Returns the ids shifted by ``start_index`` and clamped into
    ``[0, end_index - start_index - 1]``, together with a boolean mask that
    is True wherever the original id fell outside ``[start_index, end_index)``.
    """
    shifted = input - start_index
    local_ids = shifted.clamp(0, end_index - start_index - 1)
    out_of_shard = local_ids != shifted
    return local_ids, out_of_shard


class DefaultEmbeddingImpl(EmbeddingImpl):
    """Eager-torch embedding implementation."""

    def __init__(self, start_index: int, end_index: int):
        # half-open vocab range [start_index, end_index) owned by this rank
        self.start_index = start_index
        self.end_index = end_index

    def forward(self, x, weight: torch.Tensor, all_reduce: bool = False, group: dist.ProcessGroup = None):
        """Embed ``x`` with ``weight``; shard-aware when ``all_reduce`` is set."""
        if not all_reduce:
            return F.embedding(x, weight)
        # Tensor-parallel path: embed only the ids this rank owns, zero the
        # rows for out-of-shard ids, then sum partial results over the group.
        local_ids, out_of_shard = get_masked_input_and_mask(x, self.start_index, self.end_index)
        out = F.embedding(local_ids, weight)
        out.masked_fill_(out_of_shard.unsqueeze(-1), 0)
        dist.all_reduce(out, group=group)
        return out


class DefaultEmbeddingBuilder(EmbeddingBuilder):
    """Builder producing :class:`DefaultEmbeddingImpl`."""

    @staticmethod
    def build(start_index: int, end_index: int):
        """Build an impl for the vocab shard ``[start_index, end_index)``."""
        return DefaultEmbeddingImpl(start_index=start_index, end_index=end_index)
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABC, abstractmethod

import torch
import torch.distributed as dist


class EmbeddingImpl(ABC):
    """Abstract embedding op.

    Concrete backends implement :meth:`forward` to look up token embeddings,
    optionally all-reducing partial results across a process group.
    """

    @abstractmethod
    def forward(self, x, weight: torch.Tensor, all_reduce: bool = False, group: dist.ProcessGroup = None):
        """Embed token ids ``x`` with ``weight``; reduce over ``group`` when asked."""
        raise NotImplementedError


class EmbeddingBuilder(ABC):
    """Abstract factory for :class:`EmbeddingImpl` instances."""

    @staticmethod
    @abstractmethod
    def build(start_index: int, end_index: int):
        """Create an impl owning the vocab shard ``[start_index, end_index)``."""
        raise NotImplementedError
(ApplyRotaryEmb, Attention, ParallelEmbedding, RMSNorm, RopeType, SiluAndMul, + build_rotary_embedding, build_rotary_params) from lmdeploy.pytorch.nn.eplb import EPLBDispatchInfo, EPLBManager from lmdeploy.pytorch.nn.linear import (build_colwise_linear, build_down_linear, build_gateup_linear, build_o_proj, build_rowwise_linear) @@ -965,11 +965,13 @@ def __init__(self, config: Any, dtype: torch.dtype = None, device: torch.device self.config = config self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size - self.embed_tokens = nn.Embedding(config.vocab_size, - config.hidden_size, - self.padding_idx, - dtype=dtype, - device=device) + self.embed_tokens = ParallelEmbedding(config.vocab_size, + config.hidden_size, + self.padding_idx, + dtype=dtype, + device=device, + is_tp=True) + if get_dist_manager().current_context().dist_config.enable_eplb: ep_size_, _ = get_ep_world_rank() EPLBManager.init_global_eplb_metadata(ep_size_, config.n_routed_experts, config.num_hidden_layers) diff --git a/lmdeploy/pytorch/nn/__init__.py b/lmdeploy/pytorch/nn/__init__.py index 96cbee873a..2c89fac7c4 100644 --- a/lmdeploy/pytorch/nn/__init__.py +++ b/lmdeploy/pytorch/nn/__init__.py @@ -3,6 +3,7 @@ # https://github.com/vllm-project/vllm/blob/main/vllm/attention/ from .activation import GeluAndMul, SiluAndMul # noqa: F401 from .attention import Attention, FlashAttention # noqa: F401 +from .embedding import ParallelEmbedding # noqa: F401 from .norm import LayerNorm, RMSNorm # noqa: F401 from .rotary_embedding import ApplyRotaryEmb # noqa: F401 from .rotary_embedding import RopeType # noqa: F401 diff --git a/lmdeploy/pytorch/nn/embedding.py b/lmdeploy/pytorch/nn/embedding.py new file mode 100644 index 0000000000..cd9f5b8086 --- /dev/null +++ b/lmdeploy/pytorch/nn/embedding.py @@ -0,0 +1,101 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from torch import nn

from lmdeploy.pytorch.backends import OpType, get_backend
from lmdeploy.pytorch.distributed import get_dist_group, get_dist_manager, get_tp_world_rank
from lmdeploy.pytorch.weight_loader.model_weight_loader import default_weight_loader

DEFAULT_VOCAB_PADDING_SIZE = 64


def pad_vocab_size(vocab_size: int, pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int:
    """Round ``vocab_size`` up to the nearest multiple of ``pad_to``."""
    return ((vocab_size + pad_to - 1) // pad_to) * pad_to


class ParallelEmbedding(nn.Module):
    """Embedding layer with optional rowwise (vocab) tensor parallelism.

    When ``is_tp`` and the tp world size is > 1, the vocabulary is padded to
    a multiple of ``padding_size`` and split evenly across ranks; each rank
    embeds only ids inside its own shard and the partial results are summed
    with an all-reduce in :meth:`forward`.

    Args:
        vocab_size: size of the full (unpadded) vocabulary.
        hidden_size: embedding dimension.
        padding_idx: optional id whose embedding row is forced to zero; a
            negative value is interpreted relative to ``vocab_size``.
        dtype: weight dtype (default float16).
        device: weight device (default 'cuda').
        is_tp: enable vocab parallelism.
        padding_size: multiple the vocab is padded to before sharding.
        layer_type: tp layer group used to read the tp configuration.
    """

    def __init__(self,
                 vocab_size: int,
                 hidden_size: int,
                 padding_idx: int,
                 dtype: torch.dtype = None,
                 device: torch.device = None,
                 is_tp: bool = False,
                 padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
                 layer_type: str = 'attn'):
        # nn.Module bookkeeping must be initialized before attributes are set
        super().__init__()
        self.dist_ctx = get_dist_manager().current_context()

        self.is_tp = is_tp
        self.vocab_size = vocab_size
        self.padding_size = padding_size
        if padding_idx is not None:
            # normalize a negative padding_idx, same convention as nn.Embedding
            if padding_idx < 0:
                padding_idx = vocab_size + padding_idx
            assert padding_idx >= 0 and padding_idx < vocab_size
        self.padding_idx = padding_idx

        dist_cfg = get_dist_manager().current_config()
        _, self.rank = get_tp_world_rank(layer_type)
        self.tp, _ = dist_cfg.get_tp_by_layer(layer_type)

        dist_group = get_dist_group(layer_type=layer_type)
        self.tp_group = dist_group.gpu_group

        if is_tp and self.tp > 1:
            # pad so the vocab splits evenly, then keep the per-rank shard size
            self.vocab_size_padded = pad_vocab_size(self.vocab_size, self.padding_size)
            assert self.vocab_size_padded % self.tp == 0, \
                f'vocab_size_padded({self.vocab_size_padded}) must be divisible by tp({self.tp})'
            self.vocab_size_padded = self.vocab_size_padded // self.tp
        else:
            self.vocab_size_padded = self.vocab_size

        # [start_index, end_index) is this rank's slice of the padded vocab
        self.start_index = self.rank * self.vocab_size_padded
        self.end_index = (self.rank + 1) * self.vocab_size_padded
        self.register_parameter('weight', self.create_weight(self.vocab_size_padded, hidden_size, dtype, device))
        self.weight.weight_loader = self.weight_loader

        backend = get_backend()
        builder = backend.get_layer_impl_builder(OpType.Embedding)
        self.impl = builder.build(self.start_index, self.end_index)

        self.all_reduce = self.is_tp and self.tp > 1

    @staticmethod
    def create_weight(vocab_size: int, hidden_size: int, dtype: torch.dtype = None, device: torch.device = None):
        """Create the (vocab_size, hidden_size) weight parameter, zero-initialized.

        Zero init matters: padded rows that no checkpoint row maps to must
        stay zero so they contribute nothing after the all-reduce.
        """
        if dtype is None:
            dtype = torch.float16
        if device is None:
            device = 'cuda'
        weight = torch.nn.Parameter(torch.zeros((vocab_size, hidden_size), dtype=dtype, device=device),
                                    requires_grad=False)
        return weight

    def _weight_loader_tp_rowwise(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor):
        """Copy this rank's vocab shard out of the full checkpoint weight.

        The shard is clamped to the real vocab rows: the last rank(s) may own
        fewer real rows than ``vocab_size_padded`` — or none at all, when the
        shard lies entirely inside the padding — in which case the remaining
        rows keep their zero initialization.  (Without the clamp, a shard
        beyond ``loaded_weight.shape[0]`` would give a negative narrow size.)
        """
        loaded_weight = loaded_weight.to(param.device)

        shard_size = min(self.vocab_size_padded, loaded_weight.shape[0] - self.start_index)
        if shard_size <= 0:
            # whole shard is vocab padding; nothing to load
            return
        shard = loaded_weight.narrow(0, self.start_index, shard_size)
        param.data[:shard_size].copy_(shard)

    def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor):
        """Load ``loaded_weight`` into ``param``; zero the padding_idx row afterwards."""
        if not self.all_reduce:
            default_weight_loader(param, loaded_weight)
            if self.padding_idx is not None:
                self.weight[self.padding_idx] = 0
        else:
            self._weight_loader_tp_rowwise(param, loaded_weight)
            if (self.padding_idx is not None and self.padding_idx >= self.start_index
                    and self.padding_idx < self.end_index):
                # padding_idx falls inside this rank's shard; zero the local row
                self.weight[self.padding_idx - self.start_index] = 0

    def forward(self, x: torch.Tensor):
        """Embed token ids ``x``; partial results are all-reduced when tp > 1."""
        return self.impl.forward(x, self.weight, all_reduce=self.all_reduce, group=self.tp_group)
import os
import time

import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch import nn

from lmdeploy.pytorch.distributed import DefaultContext
from lmdeploy.pytorch.nn import ParallelEmbedding


def parallel_emb(rank: int, world_size: int, vocab_size: int, feat_size: int, padding_idx: int, dtype: torch.dtype,
                 x: torch.Tensor, weight: torch.Tensor, result_queue: mp.Queue):
    """Worker: run a tp-sharded ParallelEmbedding on ``rank``; rank 0 ships its output back."""
    dist.init_process_group('nccl', rank=rank, world_size=world_size)
    gpu_group = dist.new_group(ranks=list(range(world_size)), backend='nccl')

    # patch the default dist context so ParallelEmbedding picks up this tp setup
    DefaultContext.attn_tp_group.rank = rank
    DefaultContext.dist_config.attn_tp = world_size
    DefaultContext.attn_tp_group.gpu_group = gpu_group

    device = torch.device(type='cuda', index=rank)
    model = ParallelEmbedding(vocab_size=vocab_size,
                              hidden_size=feat_size,
                              padding_idx=padding_idx,
                              dtype=dtype,
                              is_tp=True,
                              device=device)

    model.weight_loader(model.weight, weight.to(device))

    with torch.inference_mode():
        out = model(x.to(device))

    if rank == 0:
        # share the CUDA result with the parent via IPC
        result_queue.put(mp.reductions.reduce_tensor(out))

    if dist.is_initialized():
        dist.destroy_process_group()


class TestEmbedding:

    # indirect fixtures: each simply forwards its parametrized value
    @pytest.fixture
    def vocab_size(self, request):
        yield request.param

    @pytest.fixture
    def feat_size(self, request):
        yield request.param

    @pytest.fixture
    def padding_idx(self, request):
        yield request.param

    @pytest.fixture
    def dtype(self, request):
        yield request.param

    @pytest.fixture
    def tp(self, request):
        yield request.param

    @pytest.fixture
    def seqlen(self, request):
        yield request.param

    @pytest.fixture
    def weight(self, vocab_size, feat_size, dtype):
        yield torch.rand(vocab_size, feat_size, dtype=dtype)

    @pytest.fixture
    def x(self, seqlen, vocab_size):
        yield torch.randint(low=0, high=vocab_size, size=(seqlen, ), dtype=torch.int32)

    @pytest.fixture
    def gt(self, x, vocab_size, feat_size, padding_idx, dtype, weight):
        """Reference output from a plain nn.Embedding on cuda:0."""
        token_emb = nn.Embedding(vocab_size,
                                 feat_size,
                                 padding_idx=padding_idx,
                                 dtype=dtype,
                                 device=torch.device(type='cuda', index=0))
        token_emb.weight.data.copy_(weight)
        token_emb._fill_padding_idx_with_zero()
        input = x.to(torch.device(type='cuda', index=0))
        yield token_emb(input)

    @pytest.mark.parametrize('vocab_size', [65576, 65533, 3333], indirect=True)
    @pytest.mark.parametrize('feat_size', [4096, 768], indirect=True)
    @pytest.mark.parametrize('padding_idx', [None], indirect=True)
    @pytest.mark.parametrize('seqlen', [1024, 1011, 128], indirect=True)
    @pytest.mark.parametrize('tp', [2], indirect=True)
    @pytest.mark.parametrize('dtype', [torch.bfloat16], indirect=True)
    def test_embedding(self, vocab_size, feat_size, padding_idx, seqlen, tp, dtype, x, weight, gt):
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '29500'
        os.environ['NCCL_SOCKET_IFNAME'] = 'lo'

        world_size = tp
        processes = []
        mp.set_start_method('spawn', force=True)
        result_queue = mp.Queue()

        for rank in range(world_size):
            p = mp.Process(target=parallel_emb,
                           args=(rank, world_size, vocab_size, feat_size, padding_idx, dtype, x, weight, result_queue))
            processes.append(p)
            p.start()
            # stagger startup; NCCL rendezvous is racy when ranks spawn together
            time.sleep(0.5)

        # bounded get: if a worker dies before producing, fail instead of hanging forever
        func, args = result_queue.get(timeout=120)
        # rebuild from the IPC handle and clone immediately — the underlying CUDA
        # buffer is owned by the worker process and may be freed once it exits
        out = func(*args).clone()

        for p in processes:
            p.join(timeout=10)
            if p.is_alive():
                p.terminate()
                p.join(timeout=5)
                if p.is_alive():
                    p.kill()

        torch.testing.assert_close(out, gt)