 from lightllm.common.kv_trans_kernel.kv_trans import kv_trans
 from lightllm.utils.dist_utils import get_current_rank_in_node
 from lightllm.utils.envs_utils import get_unique_server_name, get_env_start_args
+from lightllm.distributed.pynccl import PyNcclCommunicator


 logger = init_logger(__name__)
@@ -91,7 +92,11 @@ def alloc_kv_move_buffer(self, max_req_total_len):
         return

     def send_to_decode_node(
-        self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int
+        self,
+        move_tasks: List[KVMoveTask],
+        mem_managers: List["MemoryManager"],
+        dp_size_in_node: int,
+        nccl_comm: PyNcclCommunicator,
     ):
         assert dp_size_in_node == 1
@@ -108,14 +113,14 @@ def send_to_decode_node(
             for layer_index in range(mem.layer_num):
                 move_buffer = mem._get_kv_move_data(move_token_indexes, layer_index)
                 if i == cur_device_index:
-                    dist.send(move_buffer, dst=1)
+                    nccl_comm.send(move_buffer, dst=1)
                 else:
                     move_size = move_buffer.numel()
                     new_move_buffer = cur_mem.kv_move_buffer.view(-1)[0:move_size].view(move_buffer.shape)
                     from torch.cuda import comm

                     comm.broadcast(move_buffer, out=[new_move_buffer])
-                    dist.send(new_move_buffer, dst=1)
+                    nccl_comm.send(new_move_buffer, dst=1)
         return

     def _get_kv_move_data(self, token_indexes: List[int], layer_index: int):
@@ -127,7 +132,11 @@ def _get_kv_move_data(self, token_indexes: List[int], layer_index: int):
         return move_buffer

     def receive_from_prefill_node(
-        self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int
+        self,
+        move_tasks: List[KVMoveTask],
+        mem_managers: List["MemoryManager"],
+        dp_size_in_node: int,
+        nccl_comm: PyNcclCommunicator,
     ):
         assert dp_size_in_node == 1
@@ -144,7 +153,7 @@ def receive_from_prefill_node(
         recive_buffer = self.kv_move_buffer.view(-1)[0:move_size].view(1, token_num, 2 * self.head_num, self.head_dim)
         for i, mem in enumerate(mem_managers):
             for layer_index in range(mem.layer_num):
-                dist.recv(recive_buffer, src=0)
+                nccl_comm.recv(recive_buffer, src=0)
                 if i == cur_device_index:
                     mem._write_kv_move_data(move_token_indexes, recive_buffer, layer_index)
                 else:
@@ -160,7 +169,11 @@ def _write_kv_move_data(self, token_indexes: torch.Tensor, buffer_tensor: torch.
         return

     def send_to_decode_node_p2p(
-        self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int
+        self,
+        move_tasks: List[KVMoveTask],
+        mem_managers: List["MemoryManager"],
+        dp_size_in_node: int,
+        nccl_comm: PyNcclCommunicator,
     ):
         """
         Implementation that uses the p2p triton kernel to copy and transfer the data.
@@ -178,7 +191,7 @@ def send_to_decode_node_p2p(
         for i, mem in enumerate(mem_managers):
             for layer_index in range(mem.layer_num):
                 move_buffer = mem._get_kv_move_data_p2p(move_token_indexes, layer_index, self.kv_move_buffer)
-                dist.send(move_buffer, dst=1)
+                nccl_comm.send(move_buffer, dst=1)
         return

     def _get_kv_move_data_p2p(self, token_indexes: torch.Tensor, layer_index: int, kv_move_buffer: torch.Tensor):
@@ -191,7 +204,11 @@ def _get_kv_move_data_p2p(self, token_indexes: torch.Tensor, layer_index: int, k
         return move_buffer

     def receive_from_prefill_node_p2p(
-        self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int
+        self,
+        move_tasks: List[KVMoveTask],
+        mem_managers: List["MemoryManager"],
+        dp_size_in_node: int,
+        nccl_comm: PyNcclCommunicator,
     ):
         assert dp_size_in_node == 1
@@ -209,7 +226,7 @@ def receive_from_prefill_node_p2p(
         recive_buffer = self.kv_move_buffer.view(-1)[0:move_size].view(token_num, 2 * self.head_num, self.head_dim)
         for i, mem in enumerate(mem_managers):
             for layer_index in range(mem.layer_num):
-                dist.recv(recive_buffer, src=0)
+                nccl_comm.recv(recive_buffer, src=0)
                 mem._write_kv_move_data_p2p(move_token_indexes, recive_buffer, layer_index)
         return
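Taken together, these hunks thread an explicit `PyNcclCommunicator` through every KV transfer path instead of relying on the default `torch.distributed` process group; the hard-coded `dst=1` / `src=0` imply a two-rank group with the prefill node as rank 0 and the decode node as rank 1. Below is a minimal caller-side sketch of the new interface. Only the method signatures and the `send`/`recv` calls come from this diff; `build_nccl_comm`, `transfer_kv_p2p`, and their arguments are hypothetical placeholders, since communicator construction is not part of this change.

```python
from typing import List

from lightllm.distributed.pynccl import PyNcclCommunicator


def build_nccl_comm() -> PyNcclCommunicator:
    # Hypothetical helper: how the two-rank communicator (prefill = rank 0,
    # decode = rank 1) is constructed is not shown in this diff.
    raise NotImplementedError


def transfer_kv_p2p(is_prefill: bool, mem_manager, move_tasks: List, mem_managers: List):
    # Sketch of a caller driving the p2p transfer path; both sides must walk
    # the same (mem_manager, layer) order so every send pairs with a recv.
    nccl_comm = build_nccl_comm()
    if is_prefill:
        # Rank 0 pushes each layer's gathered KV slice to the decode node.
        mem_manager.send_to_decode_node_p2p(move_tasks, mem_managers, dp_size_in_node=1, nccl_comm=nccl_comm)
    else:
        # Rank 1 receives each slice and scatters it into the local KV cache.
        mem_manager.receive_from_prefill_node_p2p(move_tasks, mem_managers, dp_size_in_node=1, nccl_comm=nccl_comm)
```

A plausible motivation for passing the communicator explicitly is that each KV-move worker can then own a dedicated NCCL channel for transfers rather than contending with other traffic on the global default group.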