-
Notifications
You must be signed in to change notification settings - Fork 657
lmdeploy support parallel embedding #4192
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
6291fb8
c9e447b
bc5ba6e
73916c0
2bcced7
b08b10d
74f816f
4effa7b
4700a67
fd7a1a8
ef445b6
da8ffb3
1bdb0ff
a3d7828
5b7e9d6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,42 @@ | ||
| # Copyright (c) OpenMMLab. All rights reserved. | ||
| import torch | ||
| import torch.distributed as dist | ||
| import torch.nn.functional as F | ||
|
|
||
| from ..embedding import EmbeddingBuilder, EmbeddingImpl | ||
|
|
||
|
|
||
def get_masked_input_and_mask(input: torch.Tensor, start_index: int, end_index: int):
    """Map global token ids onto the local vocab shard ``[start_index, end_index)``.

    Ids outside the shard are clamped to a valid local row and flagged in the
    returned mask so the caller can zero their embeddings afterwards.

    Returns:
        tuple: ``(local_ids, oov_mask)`` where ``local_ids`` are shard-relative
        indices and ``oov_mask`` is True for tokens not owned by this shard.
    """
    shard_size = end_index - start_index
    local_ids = input - start_index
    clamped_ids = local_ids.clamp(0, shard_size - 1)
    # Clamping moved any out-of-shard id, so inequality marks OOV tokens.
    oov_mask = clamped_ids != local_ids
    return clamped_ids, oov_mask
|
|
||
|
|
||
class DefaultEmbeddingImpl(EmbeddingImpl):
    """Default pure-torch embedding lookup, with optional vocab-parallel mode.

    In the parallel path each rank holds only rows ``[start_index, end_index)``
    of the full table; out-of-shard tokens are zeroed locally and the partial
    results are summed across ranks with an all-reduce.
    """

    def __init__(self, start_index: int, end_index: int):
        # Global vocab range owned by this rank's weight shard.
        self.start_index = start_index
        self.end_index = end_index

    def forward(self, x, weight: torch.Tensor, all_reduce: bool = False, group: dist.ProcessGroup = None):
        """forward."""
        if not all_reduce:
            return F.embedding(x, weight)
        local_ids, oov_mask = get_masked_input_and_mask(x, self.start_index, self.end_index)
        embeds = F.embedding(local_ids, weight)
        # Zero rows of tokens outside this shard so the all-reduce sums
        # exactly one non-zero contribution per token across ranks.
        embeds.masked_fill_(oov_mask.unsqueeze(-1), 0)
        dist.all_reduce(embeds, group=group)
        return embeds
|
|
||
|
|
||
class DefaultEmbeddingBuilder(EmbeddingBuilder):
    """Builder for the default (pure-torch) embedding implementation."""

    @staticmethod
    def build(start_index: int, end_index: int):
        """Return a DefaultEmbeddingImpl owning vocab rows [start_index, end_index)."""
        return DefaultEmbeddingImpl(start_index, end_index)
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,24 @@ | ||||||
| # Copyright (c) OpenMMLab. All rights reserved. | ||||||
| from abc import ABC, abstractmethod | ||||||
|
|
||||||
| import torch | ||||||
| import torch.distributed as dist | ||||||
|
|
||||||
|
|
||||||
| class EmbeddingImpl(ABC): | ||||||
| """Embedding implementation api.""" | ||||||
|
||||||
| """Embedding implementation api.""" | |
| """Embedding implementation API.""" |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,101 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Copyright (c) OpenMMLab. All rights reserved. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from torch import nn | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from lmdeploy.pytorch.backends import OpType, get_backend | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from lmdeploy.pytorch.distributed import get_dist_group, get_dist_manager, get_tp_world_rank | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from lmdeploy.pytorch.weight_loader.model_weight_loader import default_weight_loader | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Granularity (in rows) to which the vocabulary is padded so it can be
# evenly partitioned across tensor-parallel ranks.
DEFAULT_VOCAB_PADDING_SIZE = 64


def pad_vocab_size(vocab_size: int, pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int:
    """Round ``vocab_size`` up to the nearest multiple of ``pad_to``."""
    # Ceiling division via negation, then scale back up to the multiple.
    return -(-vocab_size // pad_to) * pad_to
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| class ParallelEmbedding(nn.Module): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+17
to
+18
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| class ParallelEmbedding(nn.Module): | |
| class ParallelEmbedding(nn.Module): | |
| """Embedding layer with optional tensor-parallel sharding over the vocabulary. | |
| This module implements an embedding lookup that can operate either as a | |
| standard (non-parallel) embedding or as a tensor-parallel (TP) embedding | |
| where the vocabulary dimension is partitioned row-wise across TP ranks. | |
| When ``is_tp`` is ``False``, the behavior is similar to a standard | |
| :class:`torch.nn.Embedding` in that the full vocabulary is stored on each | |
| rank and looked up directly. When ``is_tp`` is ``True`` and the tensor | |
| parallel world size for the given ``layer_type`` is greater than 1: | |
| * The vocabulary size is padded up to a multiple of ``padding_size`` so | |
| that it can be evenly partitioned across tensor-parallel ranks. | |
| * Each rank owns a contiguous shard of the padded vocabulary range | |
| ``[start_index, end_index)``, and only that shard is materialized in | |
| the local ``weight`` parameter. | |
| * Embedding lookups are dispatched through a backend-specific | |
| implementation (selected via :func:`get_backend`) that is aware of the | |
| sharding and can perform any necessary communication. If required by the | |
| backend, the outputs can be all-reduced across TP ranks using | |
| ``tp_group`` when ``all_reduce`` is ``True``. | |
| The ``weight`` parameter is created as a non-trainable | |
| :class:`torch.nn.Parameter` (``requires_grad=False``) and is expected to be | |
| populated by the model weight loading pipeline. When tensor parallelism is | |
| enabled, only the local shard of the full embedding matrix is loaded. | |
| Parameters: | |
| vocab_size (int): Size of the vocabulary before any padding or | |
| sharding is applied. | |
| hidden_size (int): Size of each embedding vector. | |
| padding_idx (int): Index of the padding token in the *global* | |
| vocabulary. If not ``None`` and the index falls within the local | |
| shard range ``[start_index, end_index)``, the corresponding row in | |
| ``weight`` is zeroed after loading. | |
| dtype (torch.dtype, optional): Data type of the embedding weights. | |
| Defaults to ``torch.float16`` when not specified. | |
| device (torch.device, optional): Device on which to allocate the | |
| embedding weights. Defaults to ``"cuda"`` when not specified. | |
| is_tp (bool, optional): Whether to enable tensor-parallel sharding of | |
| the embedding weights. If ``False``, the full vocabulary is stored | |
| on each rank and standard (non-sharded) loading is used. If | |
| ``True`` and the tensor-parallel size is greater than 1, the | |
| vocabulary dimension is partitioned across ranks and a TP-aware | |
| weight loader is used. | |
| padding_size (int, optional): Granularity to which ``vocab_size`` is | |
| padded when tensor parallelism is enabled. The padded vocabulary | |
| size must be divisible by the tensor-parallel world size. | |
| layer_type (str, optional): Logical layer type used to obtain the | |
| tensor-parallel configuration and communication group (for example, | |
| ``"mlp"`` or other named layer groups supported by the distributed | |
| manager). | |
| Notes: | |
| * ``vocab_size_padded`` stores either the full vocabulary size | |
| (non-TP) or the per-rank shard size (TP). | |
| * ``start_index`` and ``end_index`` denote the global vocabulary range | |
| owned by the current rank when tensor parallelism is enabled. | |
| * The actual embedding lookup behavior is delegated to the backend | |
| implementation returned by :func:`get_backend`, which may implement | |
| optimized fused kernels or communication patterns. | |
| """ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Different layer_type have different behaviour when dp>1. As you want to gather inputs in tp groups, I think the default value should be 'attn'.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's done
Copilot
AI
Jan 4, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When is_tp=False or tp <= 1, the start_index and end_index calculation will use self.rank, which may not be 0. This could lead to incorrect behavior when tensor parallelism is not enabled. Consider setting start_index=0 and end_index=vocab_size when tensor parallelism is not active.
grimoire marked this conversation as resolved.
Show resolved
Hide resolved
| Original file line number | Diff line number | Diff line change | ||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,125 @@ | ||||||||||||||||||
| import os | ||||||||||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is not a kernel, the ut should not be placed here.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I’m not quite sure where to place the unit test files—could you give me a suggestion?
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Create a new folder under pytorch, may be
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's done |
||||||||||||||||||
| import time | ||||||||||||||||||
|
|
||||||||||||||||||
| import pytest | ||||||||||||||||||
| import torch | ||||||||||||||||||
| import torch.distributed as dist | ||||||||||||||||||
| import torch.multiprocessing as mp | ||||||||||||||||||
| from torch import nn | ||||||||||||||||||
|
||||||||||||||||||
| from torch import nn |
Copilot
AI
Jan 4, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The variable 'input' shadows the built-in Python function 'input'. Consider renaming it to 'inputs', 'input_tensor', or 'x_cuda' to avoid shadowing built-ins.
| input = x.to(torch.device(type='cuda', index=rank)) | |
| with torch.inference_mode(): | |
| out = model(input) | |
| input_tensor = x.to(torch.device(type='cuda', index=rank)) | |
| with torch.inference_mode(): | |
| out = model(input_tensor) |
Copilot
AI
Jan 4, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The variable 'input' shadows the built-in Python function 'input'. Consider renaming it to 'inputs', 'input_tensor', or 'x_cuda' to avoid shadowing built-ins.
Copilot
AI
Jan 4, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The test only covers the case where padding_idx is None. To ensure the padding_idx handling logic is correct, the test should also include cases with valid padding_idx values (e.g., 0, 1, -1, vocab_size-1) to verify that padding tokens are correctly zeroed out.
Copilot
AI
Jan 4, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The test only covers the tensor parallel case (is_tp=True). To ensure the ParallelEmbedding module works correctly in non-TP mode, add test cases with is_tp=False to verify that the module behaves like a standard embedding layer when tensor parallelism is disabled.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The comment says "Embedding implementation api" but it should be "Embedding implementation API" (API should be uppercase as it's an acronym).