Reduce KV transfer processes to the number of TP ranks for PD #813

Merged: 20 commits, Apr 11, 2025
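The change threads an explicit PyNcclCommunicator through every KV-transfer entry point of the memory managers and routes the prefill-to-decode sends and receives over that communicator instead of the default torch.distributed process group, in line with the title's goal of reducing the number of KV-transfer processes to the TP size. Below is a minimal sketch of the resulting call pattern; the construction of the communicator is not shown in this diff, so that part is an assumption, and the helper names are hypothetical. The hard-coded dst=1 and src=0 suggest a two-rank communicator with the prefill side at rank 0 and the decode side at rank 1.

# --- sketch, not part of the diff: driving the new signatures ---
from lightllm.distributed.pynccl import PyNcclCommunicator  # import added by this PR

def prefill_push_kv(mem_manager, move_tasks, mem_managers, nccl_comm: PyNcclCommunicator):
    # Prefill side: per-layer move buffers go out via nccl_comm.send(..., dst=1),
    # replacing the old dist.send on the global process group.
    mem_manager.send_to_decode_node(
        move_tasks=move_tasks,
        mem_managers=mem_managers,
        dp_size_in_node=1,   # the non-p2p paths assert dp_size_in_node == 1
        nccl_comm=nccl_comm,
    )

def decode_pull_kv(mem_manager, move_tasks, mem_managers, nccl_comm: PyNcclCommunicator):
    # Decode side: posts the matching per-layer nccl_comm.recv(..., src=0).
    mem_manager.receive_from_prefill_node(
        move_tasks=move_tasks,
        mem_managers=mem_managers,
        dp_size_in_node=1,
        nccl_comm=nccl_comm,
    )
# --- end sketch ---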
33 changes: 25 additions & 8 deletions lightllm/common/deepseek2_mem_manager.py
@@ -7,6 +7,7 @@
from lightllm.utils.log_utils import init_logger
from lightllm.common.kv_trans_kernel.kv_trans import kv_trans
from lightllm.common.kv_trans_kernel.kv_trans_v2 import kv_trans_v2_for_d_node, kv_trans_v2_for_p_node
+from lightllm.distributed.pynccl import PyNcclCommunicator

logger = init_logger(__name__)

@@ -35,7 +36,11 @@ def alloc_kv_move_buffer(self, max_req_total_len):
return

def send_to_decode_node(
-self, move_tasks: List[KVMoveTask], mem_managers: List["Deepseek2MemoryManager"], dp_size_in_node: int
+self,
+move_tasks: List[KVMoveTask],
+mem_managers: List["Deepseek2MemoryManager"],
+dp_size_in_node: int,
+nccl_comm: PyNcclCommunicator,
):
assert dp_size_in_node == 1

@@ -49,7 +54,7 @@ def send_to_decode_node(
cur_mem = mem_managers[cur_device_index]
for layer_index in range(cur_mem.layer_num):
move_buffer = cur_mem._get_kv_move_data(move_token_indexes, layer_index)
-dist.send(move_buffer, dst=1)
+nccl_comm.send(move_buffer, dst=1)
return

def _get_kv_move_data(self, token_indexes: List[int], layer_index: int):
@@ -61,7 +66,11 @@ def _get_kv_move_data(self, token_indexes: List[int], layer_index: int):
return move_buffer

def receive_from_prefill_node(
-self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int
+self,
+move_tasks: List[KVMoveTask],
+mem_managers: List["MemoryManager"],
+dp_size_in_node: int,
+nccl_comm: PyNcclCommunicator,
):
assert dp_size_in_node == 1

@@ -76,7 +85,7 @@ def receive_from_prefill_node(
move_size = self.kv_buffer.numel() // self.layer_num // self.size * token_num
recive_buffer = self.kv_move_buffer.view(-1)[0:move_size].view(1, token_num, self.head_num, self.head_dim)
for layer_index in range(self.layer_num):
-dist.recv(recive_buffer, src=0)
+nccl_comm.recv(recive_buffer, src=0)
for i, mem in enumerate(mem_managers):
if i == cur_device_index:
mem._write_kv_move_data(move_token_indexes, recive_buffer, layer_index)
@@ -93,7 +102,11 @@ def _write_kv_move_data(self, token_indexes: torch.Tensor, buffer_tensor: torch.
return

def send_to_decode_node_p2p(
-self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int
+self,
+move_tasks: List[KVMoveTask],
+mem_managers: List["MemoryManager"],
+dp_size_in_node: int,
+nccl_comm: PyNcclCommunicator,
):
"""
Implementation that uses a p2p Triton kernel to copy and transfer the data.
@@ -120,7 +133,7 @@ def send_to_decode_node_p2p(
move_buffer = self._get_kv_move_data_p2p(
move_token_indexes, token_dp_indexes, layer_index, self.kv_move_buffer, dp_size_in_node
)
-dist.send(move_buffer, dst=1)
+nccl_comm.send(move_buffer, dst=1)
return

def _get_kv_move_data_p2p(
Expand All @@ -145,7 +158,11 @@ def _get_kv_move_data_p2p(
return move_buffer

def receive_from_prefill_node_p2p(
-self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int
+self,
+move_tasks: List[KVMoveTask],
+mem_managers: List["MemoryManager"],
+dp_size_in_node: int,
+nccl_comm: PyNcclCommunicator,
):
if not hasattr(self, "mem_ptrs_dict"):
self.mem_ptrs_dict = {}
@@ -170,7 +187,7 @@ def receive_from_prefill_node_p2p(
move_size = self.kv_buffer.numel() // self.layer_num // self.size * token_num
recive_buffer = self.kv_move_buffer.view(-1)[0:move_size].view(token_num, self.head_num, self.head_dim)
for layer_index in range(self.layer_num):
-dist.recv(recive_buffer, src=0)
+nccl_comm.recv(recive_buffer, src=0)
self._write_kv_move_data_p2p(
move_token_indexes, token_dp_indexes, recive_buffer, layer_index, dp_size_in_node
)
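For the DeepSeek2 manager, each receive call above sizes its staging view from the flat kv_buffer: move_size = kv_buffer.numel() // layer_num // size * token_num, i.e. the per-token KV footprint of one layer multiplied by the number of tokens being moved. A small worked sketch of that arithmetic follows; the concrete numbers are illustrative assumptions, since the real values come from the model config.

# --- sketch, not part of the diff: receive-buffer sizing in receive_from_prefill_node ---
layer_num, size, head_num, head_dim = 60, 65536, 1, 576     # illustrative only
token_num = 128                                              # tokens in this KV move task
kv_buffer_numel = layer_num * size * head_num * head_dim     # flat KV-cache element count
move_size = kv_buffer_numel // layer_num // size * token_num
assert move_size == token_num * head_num * head_dim          # one layer's slice for token_num tokens
# recive_buffer then reuses kv_move_buffer: view(-1)[0:move_size].view(1, token_num, head_num, head_dim)
# --- end sketch ---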
35 changes: 26 additions & 9 deletions lightllm/common/mem_manager.py
@@ -10,6 +10,7 @@
from lightllm.common.kv_trans_kernel.kv_trans import kv_trans
from lightllm.utils.dist_utils import get_current_rank_in_node
from lightllm.utils.envs_utils import get_unique_server_name, get_env_start_args
+from lightllm.distributed.pynccl import PyNcclCommunicator


logger = init_logger(__name__)
@@ -91,7 +92,11 @@ def alloc_kv_move_buffer(self, max_req_total_len):
return

def send_to_decode_node(
-self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int
+self,
+move_tasks: List[KVMoveTask],
+mem_managers: List["MemoryManager"],
+dp_size_in_node: int,
+nccl_comm: PyNcclCommunicator,
):
assert dp_size_in_node == 1

@@ -108,14 +113,14 @@ def send_to_decode_node(
for layer_index in range(mem.layer_num):
move_buffer = mem._get_kv_move_data(move_token_indexes, layer_index)
if i == cur_device_index:
-dist.send(move_buffer, dst=1)
+nccl_comm.send(move_buffer, dst=1)
else:
move_size = move_buffer.numel()
new_move_buffer = cur_mem.kv_move_buffer.view(-1)[0:move_size].view(move_buffer.shape)
from torch.cuda import comm

comm.broadcast(move_buffer, out=[new_move_buffer])
-dist.send(new_move_buffer, dst=1)
+nccl_comm.send(new_move_buffer, dst=1)
return

def _get_kv_move_data(self, token_indexes: List[int], layer_index: int):
@@ -127,7 +132,11 @@ def _get_kv_move_data(self, token_indexes: List[int], layer_index: int):
return move_buffer

def receive_from_prefill_node(
-self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int
+self,
+move_tasks: List[KVMoveTask],
+mem_managers: List["MemoryManager"],
+dp_size_in_node: int,
+nccl_comm: PyNcclCommunicator,
):
assert dp_size_in_node == 1

@@ -144,7 +153,7 @@ def receive_from_prefill_node(
recive_buffer = self.kv_move_buffer.view(-1)[0:move_size].view(1, token_num, 2 * self.head_num, self.head_dim)
for i, mem in enumerate(mem_managers):
for layer_index in range(mem.layer_num):
-dist.recv(recive_buffer, src=0)
+nccl_comm.recv(recive_buffer, src=0)
if i == cur_device_index:
mem._write_kv_move_data(move_token_indexes, recive_buffer, layer_index)
else:
@@ -160,7 +169,11 @@ def _write_kv_move_data(self, token_indexes: torch.Tensor, buffer_tensor: torch.
return

def send_to_decode_node_p2p(
-self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int
+self,
+move_tasks: List[KVMoveTask],
+mem_managers: List["MemoryManager"],
+dp_size_in_node: int,
+nccl_comm: PyNcclCommunicator,
):
"""
Implementation that uses a p2p Triton kernel to copy and transfer the data.
@@ -178,7 +191,7 @@ def send_to_decode_node_p2p(
for i, mem in enumerate(mem_managers):
for layer_index in range(mem.layer_num):
move_buffer = mem._get_kv_move_data_p2p(move_token_indexes, layer_index, self.kv_move_buffer)
-dist.send(move_buffer, dst=1)
+nccl_comm.send(move_buffer, dst=1)
return

def _get_kv_move_data_p2p(self, token_indexes: torch.Tensor, layer_index: int, kv_move_buffer: torch.Tensor):
Expand All @@ -191,7 +204,11 @@ def _get_kv_move_data_p2p(self, token_indexes: torch.Tensor, layer_index: int, k
return move_buffer

def receive_from_prefill_node_p2p(
-self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int
+self,
+move_tasks: List[KVMoveTask],
+mem_managers: List["MemoryManager"],
+dp_size_in_node: int,
+nccl_comm: PyNcclCommunicator,
):
assert dp_size_in_node == 1

@@ -209,7 +226,7 @@ def receive_from_prefill_node_p2p(
recive_buffer = self.kv_move_buffer.view(-1)[0:move_size].view(token_num, 2 * self.head_num, self.head_dim)
for i, mem in enumerate(mem_managers):
for layer_index in range(mem.layer_num):
-dist.recv(recive_buffer, src=0)
+nccl_comm.recv(recive_buffer, src=0)
mem._write_kv_move_data_p2p(move_token_indexes, recive_buffer, layer_index)
return

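In the generic MemoryManager, the non-p2p send_to_decode_node forwards every peer manager's layer data through the current device: a buffer that lives on another GPU is first copied into this rank's kv_move_buffer via torch.cuda.comm.broadcast and only then shipped over the communicator, so a single transfer process per TP rank can serve all local memory managers. A hedged sketch of that staging step follows; shapes, dtypes, and device ids are illustrative assumptions, and it needs at least two visible GPUs.

# --- sketch, not part of the diff: cross-device staging before nccl_comm.send ---
import torch
from torch.cuda import comm

move_buffer = torch.randn(1, 16, 2 * 8, 128, device="cuda:1")   # a peer rank's gathered layer data (illustrative)
kv_move_buffer = torch.empty(1 << 22, dtype=move_buffer.dtype, device="cuda:0")  # this rank's staging buffer
move_size = move_buffer.numel()
new_move_buffer = kv_move_buffer.view(-1)[0:move_size].view(move_buffer.shape)
comm.broadcast(move_buffer, out=[new_move_buffer])               # GPU-to-GPU copy onto the current device
# nccl_comm.send(new_move_buffer, dst=1)                         # then only this rank's link transmits it
# --- end sketch ---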