29
29
from lightllm .utils .envs_utils import get_unique_server_name
30
30
31
31
KV_MOVE_MAX_NUM = 16
32
- KV_MOVE_MAX_RESTART_CNT = 3
32
+ KV_MOVE_MAX_START_CNT = 3
33
33
34
34
logger = init_logger (__name__ )
35
35
@@ -348,20 +348,13 @@ def __init__(self, args, info_queue: mp.Queue, mem_queues: List[mp.Queue]):
348
348
349
349
from .prefill_trans_process import start_prefill_trans_process
350
350
351
- self .kv_trans_ports = []
352
- self .kv_trans_processes = []
353
- self .kv_trans_task_in_queues = []
354
- self .kv_trans_task_out_queues = []
355
- self .kv_trans_process_restart_cnt = []
351
+ self .kv_trans_ports = [None ] * self . node_world_size
352
+ self .kv_trans_processes = [None ] * self . node_world_size
353
+ self .kv_trans_task_in_queues = [None ] * self . node_world_size
354
+ self .kv_trans_task_out_queues = [None ] * self . node_world_size
355
+ self .kv_trans_process_start_cnt = [0 ] * self . node_world_size
356
356
357
357
for device_id in range (self .node_world_size ):
358
- self .kv_trans_task_in_queues .append (mp .Queue ())
359
- self .kv_trans_task_out_queues .append (mp .Queue ())
360
- self .kv_trans_ports .append (
361
- find_available_port (self .args .pd_p_allowed_port_min , self .args .pd_p_allowed_port_max )
362
- )
363
- self .kv_trans_process_restart_cnt .append (0 )
364
- self .kv_trans_processes .append (None )
365
358
assert self .start_trans_process (device_id )
366
359
367
360
return
@@ -385,10 +378,10 @@ def handle_release_task_loop(self):
385
378
return
386
379
387
380
def start_trans_process (self , device_id : int ):
388
- task_in_queue = self . kv_trans_task_in_queues [ device_id ]
389
- task_out_queue = self . kv_trans_task_out_queues [ device_id ]
390
- kv_trans_port = self .kv_trans_ports [ device_id ]
391
- self .kv_trans_process_restart_cnt [device_id ] += 1
381
+ task_in_queue = mp . Queue ()
382
+ task_out_queue = mp . Queue ()
383
+ kv_trans_port = find_available_port ( self .args . pd_p_allowed_port_min , self . args . pd_p_allowed_port_max )
384
+ self .kv_trans_process_start_cnt [device_id ] += 1
392
385
393
386
if self .kv_trans_processes [device_id ]:
394
387
# force kill
@@ -417,6 +410,9 @@ def start_trans_process(self, device_id: int):
417
410
assert task_out_queue .get (timeout = 60 ) == "get_mem_managers_ok"
418
411
419
412
self .kv_trans_processes [device_id ] = kv_trans_process
413
+ self .kv_trans_task_in_queues [device_id ] = task_in_queue
414
+ self .kv_trans_task_out_queues [device_id ] = task_out_queue
415
+ self .kv_trans_ports [device_id ] = kv_trans_port
420
416
421
417
return True
422
418
except Exception as e :
@@ -454,7 +450,7 @@ def check_trans_process_loop(self):
454
450
raise e
455
451
456
452
def is_kv_trans_process_alive (self , device_id ):
457
- return self .kv_trans_process_restart_cnt [device_id ] <= KV_MOVE_MAX_RESTART_CNT
453
+ return self .kv_trans_process_start_cnt [device_id ] <= KV_MOVE_MAX_START_CNT
458
454
459
455
def get_next_device_index (self ):
460
456
0 commit comments