Skip to content
1 change: 1 addition & 0 deletions lmdeploy/pytorch/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class OpType(Enum):
LinearBlockedF8 = auto()
FusedMoEBlockedF8 = auto()
NSAIndexFP8 = auto()
Embedding = auto()


class OpsBackend(ABC):
Expand Down
42 changes: 42 additions & 0 deletions lmdeploy/pytorch/backends/default/embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.distributed as dist
import torch.nn.functional as F

from ..embedding import EmbeddingBuilder, EmbeddingImpl


def get_masked_input_and_mask(input: torch.Tensor, start_index: int, end_index: int):
    """Map global token ids onto a local vocab shard and flag out-of-shard ids.

    Returns the ids shifted into ``[0, end_index - start_index)`` (clamped),
    together with a boolean mask that is True where the original id fell
    outside ``[start_index, end_index)``.
    """
    shard_size = end_index - start_index
    local_ids = input - start_index
    masked_input = torch.clamp(local_ids, 0, shard_size - 1)
    inv_vocab_mask = masked_input != local_ids
    return masked_input, inv_vocab_mask


class DefaultEmbeddingImpl(EmbeddingImpl):
    """Default embedding implementation API.

    Performs the lookup with ``F.embedding``. When ``all_reduce`` is set,
    ``weight`` holds only the vocabulary shard ``[start_index, end_index)``;
    out-of-shard tokens are zeroed and the per-rank partial embeddings are
    summed across the process group.
    """

    def __init__(self, start_index: int, end_index: int):
        # Global vocabulary range [start_index, end_index) owned by this rank.
        self.start_index = start_index
        self.end_index = end_index

    def forward(self, x, weight: torch.Tensor, all_reduce: bool = False, group: dist.ProcessGroup = None):
        """Embedding lookup, optionally all-reduced across the TP group."""
        if all_reduce:
            mask_input, inv_vocab_mask = get_masked_input_and_mask(x, self.start_index, self.end_index)
            out = F.embedding(mask_input, weight)
            # Zero rows of tokens outside this rank's shard so the all-reduce
            # sums exactly one non-zero contribution per token.
            out.masked_fill_(inv_vocab_mask.unsqueeze(-1), 0)
            dist.all_reduce(out, group=group)
        else:
            out = F.embedding(x, weight)

        return out


class DefaultEmbeddingBuilder(EmbeddingBuilder):
    """Builder for the default (PyTorch-native) embedding implementation."""

    @staticmethod
    def build(start_index: int, end_index: int):
        """Build a DefaultEmbeddingImpl owning vocab range [start_index, end_index)."""
        return DefaultEmbeddingImpl(start_index=start_index, end_index=end_index)
3 changes: 3 additions & 0 deletions lmdeploy/pytorch/backends/default/op_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ def get_layer_impl_builder(cls, layer_type: OpType):
elif layer_type == OpType.SoftmaxTopK:
from .moe import DefaultSoftmaxTopKBuilder
return DefaultSoftmaxTopKBuilder
elif layer_type == OpType.Embedding:
from .embedding import DefaultEmbeddingBuilder
return DefaultEmbeddingBuilder
else:
raise RuntimeError(f'{layer_type} not supported.')

Expand Down
24 changes: 24 additions & 0 deletions lmdeploy/pytorch/backends/embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABC, abstractmethod

import torch
import torch.distributed as dist


class EmbeddingImpl(ABC):
    """Embedding implementation API.

    Backend-specific embedding lookups implement this interface; a concrete
    implementation is obtained through an :class:`EmbeddingBuilder`.
    """

    @abstractmethod
    def forward(self, x, weight: torch.Tensor, all_reduce: bool = False, group: dist.ProcessGroup = None):
        """Lookup ``x`` in ``weight``; all-reduce over ``group`` if requested."""
        raise NotImplementedError


class EmbeddingBuilder(ABC):
    """Embedding implementation builder interface."""

    @staticmethod
    @abstractmethod
    def build(start_index: int, end_index: int):
        """Build an EmbeddingImpl owning vocab range [start_index, end_index)."""
        raise NotImplementedError
16 changes: 9 additions & 7 deletions lmdeploy/pytorch/models/deepseek_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
import lmdeploy.pytorch.distributed as dist
from lmdeploy.pytorch.distributed import get_dist_manager, get_ep_world_rank, get_tp_world_rank
from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager, get_step_ctx_manager
from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, RMSNorm, RopeType, SiluAndMul, build_rotary_embedding,
build_rotary_params)
from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, ParallelEmbedding, RMSNorm, RopeType, SiluAndMul,
build_rotary_embedding, build_rotary_params)
from lmdeploy.pytorch.nn.eplb import EPLBDispatchInfo, EPLBManager
from lmdeploy.pytorch.nn.linear import (build_colwise_linear, build_down_linear, build_gateup_linear, build_o_proj,
build_rowwise_linear)
Expand Down Expand Up @@ -965,11 +965,13 @@ def __init__(self, config: Any, dtype: torch.dtype = None, device: torch.device
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size,
config.hidden_size,
self.padding_idx,
dtype=dtype,
device=device)
self.embed_tokens = ParallelEmbedding(config.vocab_size,
config.hidden_size,
self.padding_idx,
dtype=dtype,
device=device,
is_tp=True)

if get_dist_manager().current_context().dist_config.enable_eplb:
ep_size_, _ = get_ep_world_rank()
EPLBManager.init_global_eplb_metadata(ep_size_, config.n_routed_experts, config.num_hidden_layers)
Expand Down
1 change: 1 addition & 0 deletions lmdeploy/pytorch/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# https://github.com/vllm-project/vllm/blob/main/vllm/attention/
from .activation import GeluAndMul, SiluAndMul # noqa: F401
from .attention import Attention, FlashAttention # noqa: F401
from .embedding import ParallelEmbedding # noqa: F401
from .norm import LayerNorm, RMSNorm # noqa: F401
from .rotary_embedding import ApplyRotaryEmb # noqa: F401
from .rotary_embedding import RopeType # noqa: F401
Expand Down
101 changes: 101 additions & 0 deletions lmdeploy/pytorch/nn/embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
Copy link

Copilot AI Jan 4, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Module 'torch' is imported with both 'import' and 'import from'.
Module 'lmdeploy.pytorch.check_env.torch' is imported with both 'import' and 'import from'.

Copilot uses AI. Check for mistakes.
from torch import nn

from lmdeploy.pytorch.backends import OpType, get_backend
from lmdeploy.pytorch.distributed import get_dist_group, get_dist_manager, get_tp_world_rank
from lmdeploy.pytorch.weight_loader.model_weight_loader import default_weight_loader

DEFAULT_VOCAB_PADDING_SIZE = 64


def pad_vocab_size(vocab_size: int, pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int:
    """Round ``vocab_size`` up to the nearest multiple of ``pad_to``."""
    return ((vocab_size + pad_to - 1) // pad_to) * pad_to


class ParallelEmbedding(nn.Module):
    """Embedding layer with optional tensor-parallel sharding of the vocabulary.

    When ``is_tp`` is True and the tensor-parallel world size is greater than
    1, the vocabulary is padded to a multiple of ``padding_size`` and split
    row-wise across TP ranks; each rank materializes only its shard
    ``[start_index, end_index)`` of the weight, and the lookup is delegated to
    a backend implementation that zeroes out-of-shard tokens and all-reduces
    the partial results. Otherwise the full weight is kept on every rank and a
    plain lookup is performed.

    Args:
        vocab_size (int): Size of the (unpadded) vocabulary.
        hidden_size (int): Embedding dimension.
        padding_idx (int): Index of the padding token in the global vocabulary
            (may be negative, counted from the end), or None.
        dtype (torch.dtype): Weight dtype; defaults to float16.
        device (torch.device): Weight device; defaults to 'cuda'.
        is_tp (bool): Whether to shard the vocabulary across TP ranks.
        padding_size (int): Granularity the vocab size is padded to under TP.
        layer_type (str): Layer group used to pick the TP config and
            communication group.
    """

    def __init__(self,
                 vocab_size: int,
                 hidden_size: int,
                 padding_idx: int,
                 dtype: torch.dtype = None,
                 device: torch.device = None,
                 is_tp: bool = False,
                 padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
                 layer_type: str = 'attn'):
        super().__init__()
        self.dist_ctx = get_dist_manager().current_context()

        self.is_tp = is_tp
        self.vocab_size = vocab_size
        self.padding_size = padding_size
        if padding_idx is not None:
            # Normalize a negative index and validate against the global vocab.
            if padding_idx < 0:
                padding_idx = vocab_size + padding_idx
            assert padding_idx >= 0 and padding_idx < vocab_size
        self.padding_idx = padding_idx

        dist_cfg = get_dist_manager().current_config()
        _, self.rank = get_tp_world_rank(layer_type)
        self.tp, _ = dist_cfg.get_tp_by_layer(layer_type)

        dist_group = get_dist_group(layer_type=layer_type)
        self.tp_group = dist_group.gpu_group

        if is_tp and self.tp > 1:
            # Pad so the vocabulary splits evenly, then keep the per-rank size.
            self.vocab_size_padded = pad_vocab_size(self.vocab_size, self.padding_size)
            assert self.vocab_size_padded % self.tp == 0, \
                f'vocab_size_padded({self.vocab_size_padded}) must be divisible by tp({self.tp})'
            self.vocab_size_padded = self.vocab_size_padded // self.tp
            self.start_index = self.rank * self.vocab_size_padded
            self.end_index = (self.rank + 1) * self.vocab_size_padded
        else:
            # Not sharded: every rank owns the full vocabulary. Do not derive
            # the range from self.rank here — the rank may be non-zero, which
            # would yield a wrong [start, end) window.
            self.vocab_size_padded = self.vocab_size
            self.start_index = 0
            self.end_index = self.vocab_size

        self.register_parameter('weight', self.create_weight(self.vocab_size_padded, hidden_size, dtype, device))
        self.weight.weight_loader = self.weight_loader

        backend = get_backend()
        builder = backend.get_layer_impl_builder(OpType.Embedding)
        self.impl = builder.build(self.start_index, self.end_index)

        # All-reduce is only needed when the weight is actually sharded.
        self.all_reduce = self.is_tp and self.tp > 1

    @staticmethod
    def create_weight(vocab_size: int, hidden_size: int, dtype: torch.dtype = None, device: torch.device = None):
        """Create a zero-initialized, non-trainable embedding weight."""
        if dtype is None:
            dtype = torch.float16
        if device is None:
            device = 'cuda'
        weight = torch.nn.Parameter(torch.zeros((vocab_size, hidden_size), dtype=dtype, device=device),
                                    requires_grad=False)
        return weight

    def _weight_loader_tp_rowwise(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor):
        """Load this rank's row shard of the full embedding weight."""
        loaded_weight = loaded_weight.to(param.device)

        shard_size = self.vocab_size_padded
        if self.end_index > loaded_weight.shape[0]:
            # The last shard may extend past the real vocab due to padding;
            # copy only the rows that exist, padded rows stay zero.
            shard_size = loaded_weight.shape[0] - self.start_index

        loaded_weight = loaded_weight.narrow(0, self.start_index, shard_size)
        param[:loaded_weight.shape[0]].data.copy_(loaded_weight)

    def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor):
        """Load the embedding weight and zero the padding row if owned locally."""
        if not self.all_reduce:
            default_weight_loader(param, loaded_weight)
            if self.padding_idx is not None:
                self.weight[self.padding_idx] = 0
        else:
            self._weight_loader_tp_rowwise(param, loaded_weight)
            if (self.padding_idx is not None and self.padding_idx >= self.start_index
                    and self.padding_idx < self.end_index):
                self.weight[self.padding_idx - self.start_index] = 0

    def forward(self, x: torch.Tensor):
        """Embedding lookup through the backend implementation."""
        return self.impl.forward(x, self.weight, all_reduce=self.all_reduce, group=self.tp_group)
125 changes: 125 additions & 0 deletions tests/pytorch/nn/test_embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import os
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not a kernel, the ut should not be placed here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I’m not quite sure where to place the unit test files—could you give me a suggestion?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Create a new folder under pytorch, may be pytorch/nn. Or just forget about the unit test, we have daily ete test.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's done

import time

import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch import nn
Copy link

Copilot AI Jan 4, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Module 'torch' is imported with both 'import' and 'import from'.
Module 'lmdeploy.pytorch.check_env.torch' is imported with both 'import' and 'import from'.

Suggested change
from torch import nn

Copilot uses AI. Check for mistakes.

from lmdeploy.pytorch.distributed import DefaultContext
from lmdeploy.pytorch.nn import ParallelEmbedding


def parallel_emb(rank: int, world_size: int, vocab_size: int, feat_size: int, padding_idx: int, dtype: torch.dtype,
                 x: torch.Tensor, weight: torch.Tensor, result_queue: mp.Queue):
    """Worker: build a TP-sharded ParallelEmbedding on one rank and run a lookup."""
    dist.init_process_group('nccl', rank=rank, world_size=world_size)
    gpu_group = dist.new_group(ranks=list(range(world_size)), backend='nccl')

    # Patch the default distributed context so ParallelEmbedding sees this
    # process as rank `rank` of an attn-TP group of size `world_size`.
    DefaultContext.attn_tp_group.rank = rank
    DefaultContext.dist_config.attn_tp = world_size
    DefaultContext.attn_tp_group.gpu_group = gpu_group

    device = torch.device(type='cuda', index=rank)
    model = ParallelEmbedding(vocab_size=vocab_size,
                              hidden_size=feat_size,
                              padding_idx=padding_idx,
                              dtype=dtype,
                              is_tp=True,
                              device=device)

    model.weight_loader(model.weight, weight.to(device))

    token_ids = x.to(device)
    with torch.inference_mode():
        out = model(token_ids)

    # Only rank 0 ships the (already all-reduced) output back to the parent.
    if rank == 0:
        result_queue.put(mp.reductions.reduce_tensor(out))

    if dist.is_initialized():
        dist.destroy_process_group()


class TestEmbedding:
    """Multi-process correctness test for tensor-parallel ParallelEmbedding."""

    @pytest.fixture
    def vocab_size(self, request):
        return request.param

    @pytest.fixture
    def feat_size(self, request):
        return request.param

    @pytest.fixture
    def padding_idx(self, request):
        return request.param

    @pytest.fixture
    def dtype(self, request):
        return request.param

    @pytest.fixture
    def tp(self, request):
        return request.param

    @pytest.fixture
    def seqlen(self, request):
        return request.param

    @pytest.fixture
    def weight(self, vocab_size, feat_size, dtype):
        return torch.rand(vocab_size, feat_size, dtype=dtype)

    @pytest.fixture
    def x(self, seqlen, vocab_size):
        return torch.randint(low=0, high=vocab_size, size=(seqlen, ), dtype=torch.int32)

    @pytest.fixture
    def gt(self, x, vocab_size, feat_size, padding_idx, dtype, weight):
        """Reference output from a plain nn.Embedding on a single GPU."""
        device = torch.device(type='cuda', index=0)
        token_emb = nn.Embedding(vocab_size, feat_size, padding_idx=padding_idx, dtype=dtype, device=device)
        token_emb.weight.data.copy_(weight)
        token_emb._fill_padding_idx_with_zero()
        return token_emb(x.to(device))

    @pytest.mark.parametrize('vocab_size', [65576, 65533, 3333], indirect=True)
    @pytest.mark.parametrize('feat_size', [4096, 768], indirect=True)
    @pytest.mark.parametrize('padding_idx', [None], indirect=True)
    @pytest.mark.parametrize('seqlen', [1024, 1011, 128], indirect=True)
    @pytest.mark.parametrize('tp', [2], indirect=True)
    @pytest.mark.parametrize('dtype', [torch.bfloat16], indirect=True)
    def test_embedding(self, vocab_size, feat_size, padding_idx, seqlen, tp, dtype, x, weight, gt):
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '29500'
        os.environ['NCCL_SOCKET_IFNAME'] = 'lo'

        world_size = tp
        mp.set_start_method('spawn', force=True)
        result_queue = mp.Queue()

        workers = []
        for rank in range(world_size):
            worker = mp.Process(target=parallel_emb,
                                args=(rank, world_size, vocab_size, feat_size, padding_idx, dtype, x, weight,
                                      result_queue))
            workers.append(worker)
            worker.start()
            # Stagger start-up to avoid rendezvous races.
            time.sleep(0.5)

        rebuild, tensor_args = result_queue.get()
        out = rebuild(*tensor_args)

        for worker in workers:
            worker.join(timeout=10)
            if worker.is_alive():
                worker.terminate()
                worker.join(timeout=5)
                if worker.is_alive():
                    worker.kill()

        torch.testing.assert_close(out, gt)