@@ -59,6 +59,8 @@ def __init__(
59
59
self .transfer_lock = asyncio .Lock () # the lock for transfer to next module in multi node mode.
60
60
self .disable_abort = args .nnodes > 1 and args .dp == 1 # mulitnode dp=1 mode, disable abort
61
61
self .is_multinode_tp = args .dp == 1 and args .nnodes > 1
62
+ self .is_multinode_tp_master = args .dp == 1 and args .nnodes > 1 and args .node_rank == 0
63
+ self .is_multinode_tp_slave = args .dp == 1 and args .nnodes > 1 and args .node_rank > 0
62
64
if self .is_multinode_tp :
63
65
if args .node_rank == 0 :
64
66
self .multinode_req_manager = []
@@ -192,7 +194,7 @@ def alloc_req_id(self, sampling_params, is_health_req: bool = False):
192
194
if is_health_req :
193
195
return sampling_params .group_request_id
194
196
if self .pd_mode == NodeRole .NORMAL :
195
- if not ( self .nnodes > 1 and self . args . dp == 1 ) :
197
+ if not self .is_multinode_tp :
196
198
group_request_id = self .id_gen .generate_id ()
197
199
else :
198
200
if self .node_rank == 0 :
@@ -222,7 +224,7 @@ async def generate(
222
224
223
225
try :
224
226
original_multimodal_params = None
225
- if self .nnodes > 1 and self . node_rank == 0 and self . args . dp == 1 :
227
+ if self .is_multinode_tp_master :
226
228
original_multimodal_params = copy .deepcopy (multimodal_params )
227
229
228
230
if self .pd_mode .is_P_or_NORMAL ():
@@ -366,8 +368,10 @@ async def transfer_to_next_module_or_node(
366
368
original_multimodal_params : MultimodalParams ,
367
369
group_req_objs : Optional [GroupReqObjs ] = None ,
368
370
):
369
- # 多节点纯tp 运行模式下,保证请求能保持相同的顺序转发到其他节点和当前节点next module.
370
- if self .nnodes > 1 and self .node_rank == 0 and self .args .dp == 1 :
371
+ # 多节点纯tp 运行模式下,master 节点需要将请求按照可控的顺序转发给slave节点,
372
+ # 同时转发给salve节点的时候,要保证master节点按照转发的顺序转发给next_module
373
+ # 所以需要锁的控制。
374
+ if self .is_multinode_tp_master :
371
375
async with self .transfer_lock :
372
376
for sender in self .multinode_req_manager :
373
377
sender .send_pyobj (
@@ -376,8 +380,10 @@ async def transfer_to_next_module_or_node(
376
380
)
377
381
await self .transfer_to_next_module (group_req_objs )
378
382
return
379
-
380
- if self .nnodes > 1 and self .node_rank > 0 and self .args .dp == 1 :
383
+ # 多节点纯tp 的slave节点,需要按照接受到请求的顺序转发,这需要锁和排队机制来保证。
384
+ # self.request_order_queue 实现了一种简单的排队取出机制,这样master 和 slave
385
+ # 节点的请求到达各自节点的router的顺序才是一致的,才能完成同步同态调度。
386
+ if self .is_multinode_tp_slave :
381
387
while True :
382
388
if self .request_order_queue and self .request_order_queue [0 ] != group_req_objs .group_req_id :
383
389
await asyncio .sleep (0.002 )
@@ -578,8 +584,10 @@ async def handle_loop(self):
578
584
if self .pd_mode .is_P_or_D ():
579
585
self .forwarding_queue = AsyncQueue ()
580
586
asyncio .create_task (self .pd_handle_loop ())
581
-
582
- if self .args .node_rank > 0 :
587
+
588
+ # 多节点tp模式下的slave节点,需要开启一个协程task用来接收
589
+ # master 转发过来的请求对象。
590
+ if self .is_multinode_tp_slave :
583
591
asyncio .create_task (self .loop_for_request ())
584
592
585
593
while True :
0 commit comments