
Commit 16eb6bf

Authored by hiworldwzj and Weichao Luo
Reduce the KV transfer processes to the number of TP ranks for PD (prefill/decode disaggregation). (#813)
Co-authored-by: Weichao Luo <[email protected]>
1 parent d45574e commit 16eb6bf

17 files changed: +1,991 −731 lines
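
This commit threads a dedicated PyNcclCommunicator through every KV-move method on the memory managers, replacing dist.send/dist.recv on the default torch.distributed process group, so PD mode needs only one KV transfer process per TP rank. Below is a minimal sketch of the new calling convention; how lightllm builds the communicator is not part of this diff, so the pre-built nccl_comm argument is an assumption, and send_kv_for_rank is a hypothetical wrapper named for illustration.

from typing import List

from lightllm.distributed.pynccl import PyNcclCommunicator


def send_kv_for_rank(
    mem_manager,                    # this TP rank's MemoryManager
    move_tasks: List,               # List[KVMoveTask], as in the diff
    mem_managers: List,             # all memory managers on this node
    dp_size_in_node: int,
    nccl_comm: PyNcclCommunicator,  # assumed: built once per TP rank by setup code
):
    # The communicator is now an explicit argument on every send/receive
    # method instead of an implicit global process group.
    mem_manager.send_to_decode_node(move_tasks, mem_managers, dp_size_in_node, nccl_comm)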

lightllm/common/deepseek2_mem_manager.py (+25 −8)

@@ -7,6 +7,7 @@
 from lightllm.utils.log_utils import init_logger
 from lightllm.common.kv_trans_kernel.kv_trans import kv_trans
 from lightllm.common.kv_trans_kernel.kv_trans_v2 import kv_trans_v2_for_d_node, kv_trans_v2_for_p_node
+from lightllm.distributed.pynccl import PyNcclCommunicator

 logger = init_logger(__name__)

@@ -35,7 +36,11 @@ def alloc_kv_move_buffer(self, max_req_total_len):
         return

     def send_to_decode_node(
-        self, move_tasks: List[KVMoveTask], mem_managers: List["Deepseek2MemoryManager"], dp_size_in_node: int
+        self,
+        move_tasks: List[KVMoveTask],
+        mem_managers: List["Deepseek2MemoryManager"],
+        dp_size_in_node: int,
+        nccl_comm: PyNcclCommunicator,
     ):
         assert dp_size_in_node == 1

@@ -49,7 +54,7 @@ def send_to_decode_node(
         cur_mem = mem_managers[cur_device_index]
         for layer_index in range(cur_mem.layer_num):
             move_buffer = cur_mem._get_kv_move_data(move_token_indexes, layer_index)
-            dist.send(move_buffer, dst=1)
+            nccl_comm.send(move_buffer, dst=1)
         return

     def _get_kv_move_data(self, token_indexes: List[int], layer_index: int):
@@ -61,7 +66,11 @@ def _get_kv_move_data(self, token_indexes: List[int], layer_index: int):
         return move_buffer

     def receive_from_prefill_node(
-        self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int
+        self,
+        move_tasks: List[KVMoveTask],
+        mem_managers: List["MemoryManager"],
+        dp_size_in_node: int,
+        nccl_comm: PyNcclCommunicator,
     ):
         assert dp_size_in_node == 1

@@ -76,7 +85,7 @@ def receive_from_prefill_node(
         move_size = self.kv_buffer.numel() // self.layer_num // self.size * token_num
         recive_buffer = self.kv_move_buffer.view(-1)[0:move_size].view(1, token_num, self.head_num, self.head_dim)
         for layer_index in range(self.layer_num):
-            dist.recv(recive_buffer, src=0)
+            nccl_comm.recv(recive_buffer, src=0)
             for i, mem in enumerate(mem_managers):
                 if i == cur_device_index:
                     mem._write_kv_move_data(move_token_indexes, recive_buffer, layer_index)
@@ -93,7 +102,11 @@ def _write_kv_move_data(self, token_indexes: torch.Tensor, buffer_tensor: torch.
         return

     def send_to_decode_node_p2p(
-        self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int
+        self,
+        move_tasks: List[KVMoveTask],
+        mem_managers: List["MemoryManager"],
+        dp_size_in_node: int,
+        nccl_comm: PyNcclCommunicator,
     ):
         """
         Implementation that uses a p2p triton kernel to copy and transfer the data.
@@ -120,7 +133,7 @@ def send_to_decode_node_p2p(
             move_buffer = self._get_kv_move_data_p2p(
                 move_token_indexes, token_dp_indexes, layer_index, self.kv_move_buffer, dp_size_in_node
             )
-            dist.send(move_buffer, dst=1)
+            nccl_comm.send(move_buffer, dst=1)
         return

     def _get_kv_move_data_p2p(
@@ -145,7 +158,11 @@ def _get_kv_move_data_p2p(
         return move_buffer

     def receive_from_prefill_node_p2p(
-        self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int
+        self,
+        move_tasks: List[KVMoveTask],
+        mem_managers: List["MemoryManager"],
+        dp_size_in_node: int,
+        nccl_comm: PyNcclCommunicator,
     ):
         if not hasattr(self, "mem_ptrs_dict"):
             self.mem_ptrs_dict = {}
@@ -170,7 +187,7 @@ def receive_from_prefill_node_p2p(
         move_size = self.kv_buffer.numel() // self.layer_num // self.size * token_num
         recive_buffer = self.kv_move_buffer.view(-1)[0:move_size].view(token_num, self.head_num, self.head_dim)
         for layer_index in range(self.layer_num):
-            dist.recv(recive_buffer, src=0)
+            nccl_comm.recv(recive_buffer, src=0)
             self._write_kv_move_data_p2p(
                 move_token_indexes, token_dp_indexes, recive_buffer, layer_index, dp_size_in_node
             )
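
Both the plain and the _p2p paths above follow the same layer-streaming protocol: the prefill side sends one packed buffer per layer to rank 1 of the transfer communicator, and the decode side receives each layer into a reusable staging buffer from rank 0 before scattering it into its local KV cache. A distilled sketch of that loop, where get_layer_buffer and write_layer are hypothetical stand-ins for the manager's _get_kv_move_data/_write_kv_move_data helpers:

import torch

from lightllm.distributed.pynccl import PyNcclCommunicator


def stream_layers_send(nccl_comm: PyNcclCommunicator, get_layer_buffer, layer_num: int):
    # Prefill side: one send per layer; the decode peer is rank 1.
    for layer_index in range(layer_num):
        move_buffer = get_layer_buffer(layer_index)  # packed KV for this layer
        nccl_comm.send(move_buffer, dst=1)


def stream_layers_recv(nccl_comm: PyNcclCommunicator, recv_buffer: torch.Tensor, write_layer, layer_num: int):
    # Decode side: receive into one staging buffer, then scatter per layer.
    for layer_index in range(layer_num):
        nccl_comm.recv(recv_buffer, src=0)  # the prefill peer is rank 0
        write_layer(layer_index, recv_buffer)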

lightllm/common/mem_manager.py (+26 −9)

@@ -10,6 +10,7 @@
 from lightllm.common.kv_trans_kernel.kv_trans import kv_trans
 from lightllm.utils.dist_utils import get_current_rank_in_node
 from lightllm.utils.envs_utils import get_unique_server_name, get_env_start_args
+from lightllm.distributed.pynccl import PyNcclCommunicator


 logger = init_logger(__name__)
@@ -91,7 +92,11 @@ def alloc_kv_move_buffer(self, max_req_total_len):
         return

     def send_to_decode_node(
-        self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int
+        self,
+        move_tasks: List[KVMoveTask],
+        mem_managers: List["MemoryManager"],
+        dp_size_in_node: int,
+        nccl_comm: PyNcclCommunicator,
     ):
         assert dp_size_in_node == 1

@@ -108,14 +113,14 @@ def send_to_decode_node(
             for layer_index in range(mem.layer_num):
                 move_buffer = mem._get_kv_move_data(move_token_indexes, layer_index)
                 if i == cur_device_index:
-                    dist.send(move_buffer, dst=1)
+                    nccl_comm.send(move_buffer, dst=1)
                 else:
                     move_size = move_buffer.numel()
                     new_move_buffer = cur_mem.kv_move_buffer.view(-1)[0:move_size].view(move_buffer.shape)
                     from torch.cuda import comm

                     comm.broadcast(move_buffer, out=[new_move_buffer])
-                    dist.send(new_move_buffer, dst=1)
+                    nccl_comm.send(new_move_buffer, dst=1)
         return

     def _get_kv_move_data(self, token_indexes: List[int], layer_index: int):
@@ -127,7 +132,11 @@ def _get_kv_move_data(self, token_indexes: List[int], layer_index: int):
         return move_buffer

     def receive_from_prefill_node(
-        self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int
+        self,
+        move_tasks: List[KVMoveTask],
+        mem_managers: List["MemoryManager"],
+        dp_size_in_node: int,
+        nccl_comm: PyNcclCommunicator,
     ):
         assert dp_size_in_node == 1

@@ -144,7 +153,7 @@ def receive_from_prefill_node(
         recive_buffer = self.kv_move_buffer.view(-1)[0:move_size].view(1, token_num, 2 * self.head_num, self.head_dim)
         for i, mem in enumerate(mem_managers):
             for layer_index in range(mem.layer_num):
-                dist.recv(recive_buffer, src=0)
+                nccl_comm.recv(recive_buffer, src=0)
                 if i == cur_device_index:
                     mem._write_kv_move_data(move_token_indexes, recive_buffer, layer_index)
                 else:
@@ -160,7 +169,11 @@ def _write_kv_move_data(self, token_indexes: torch.Tensor, buffer_tensor: torch.
         return

     def send_to_decode_node_p2p(
-        self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int
+        self,
+        move_tasks: List[KVMoveTask],
+        mem_managers: List["MemoryManager"],
+        dp_size_in_node: int,
+        nccl_comm: PyNcclCommunicator,
     ):
         """
         Implementation that uses a p2p triton kernel to copy and transfer the data.
@@ -178,7 +191,7 @@ def send_to_decode_node_p2p(
         for i, mem in enumerate(mem_managers):
             for layer_index in range(mem.layer_num):
                 move_buffer = mem._get_kv_move_data_p2p(move_token_indexes, layer_index, self.kv_move_buffer)
-                dist.send(move_buffer, dst=1)
+                nccl_comm.send(move_buffer, dst=1)
         return

     def _get_kv_move_data_p2p(self, token_indexes: torch.Tensor, layer_index: int, kv_move_buffer: torch.Tensor):
@@ -191,7 +204,11 @@ def _get_kv_move_data_p2p(self, token_indexes: torch.Tensor, layer_index: int, k
         return move_buffer

     def receive_from_prefill_node_p2p(
-        self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int
+        self,
+        move_tasks: List[KVMoveTask],
+        mem_managers: List["MemoryManager"],
+        dp_size_in_node: int,
+        nccl_comm: PyNcclCommunicator,
     ):
         assert dp_size_in_node == 1

@@ -209,7 +226,7 @@ def receive_from_prefill_node_p2p(
         recive_buffer = self.kv_move_buffer.view(-1)[0:move_size].view(token_num, 2 * self.head_num, self.head_dim)
         for i, mem in enumerate(mem_managers):
             for layer_index in range(mem.layer_num):
-                dist.recv(recive_buffer, src=0)
+                nccl_comm.recv(recive_buffer, src=0)
                 mem._write_kv_move_data_p2p(move_token_indexes, recive_buffer, layer_index)
         return
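
One detail worth noting in MemoryManager.send_to_decode_node above: when a layer's buffer lives on a different GPU than the sending rank, it is first staged into the sender's own kv_move_buffer with torch.cuda.comm.broadcast and only then handed to the communicator, presumably because the transfer communicator is bound to the sending rank's device. A distilled sketch of that branch; send_layer, staging, and same_device are names introduced here for illustration:

import torch
from torch.cuda import comm


def send_layer(nccl_comm, move_buffer: torch.Tensor, staging: torch.Tensor, same_device: bool):
    if same_device:
        # The buffer already lives on this rank's GPU: send it directly.
        nccl_comm.send(move_buffer, dst=1)
    else:
        # Stage a device-to-device copy into this rank's kv_move_buffer slice,
        # then send the local copy over the transfer communicator.
        move_size = move_buffer.numel()
        new_move_buffer = staging.view(-1)[0:move_size].view(move_buffer.shape)
        comm.broadcast(move_buffer, out=[new_move_buffer])
        nccl_comm.send(new_move_buffer, dst=1)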
