Skip to content

Commit 49d8adb

Browse files
author
Weichao Luo
committed
fixup.
1 parent e2bb7b7 commit 49d8adb

File tree

2 files changed

+25
-31
lines changed

2 files changed

+25
-31
lines changed

lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_kv_move_manager.py

+11-13
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
thread_local_data = threading.local()
3030

3131
KV_MOVE_MAX_NUM = 16
32-
KV_MOVE_MAX_RESTART_CNT = 3
32+
KV_MOVE_MAX_START_CNT = 3
3333

3434

3535
@dataclass
@@ -288,16 +288,12 @@ def __init__(self, args, info_queue: mp.Queue, mem_queues: List[mp.Queue]):
288288
# 需要每个卡有一个锁来规划每次只能有一个tran obj 操作对应显卡上的传输任务。
289289
self.device_locks = [threading.Lock() for _ in range(self.node_world_size)]
290290

291-
self.kv_trans_processes = []
292-
self.kv_trans_task_in_queues = []
293-
self.kv_trans_task_out_queues = []
294-
self.kv_trans_process_restart_cnt = []
291+
self.kv_trans_processes = [None] * self.node_world_size
292+
self.kv_trans_task_in_queues = [None] * self.node_world_size
293+
self.kv_trans_task_out_queues = [None] * self.node_world_size
294+
self.kv_trans_process_start_cnt = [0] * self.node_world_size
295295

296296
for device_id in range(self.node_world_size):
297-
self.kv_trans_task_in_queues.append(mp.Queue())
298-
self.kv_trans_task_out_queues.append(mp.Queue())
299-
self.kv_trans_process_restart_cnt.append(0)
300-
self.kv_trans_processes.append(None)
301297
assert self.start_trans_process(device_id)
302298

303299
return
@@ -525,9 +521,9 @@ def remove_trans_obj_by_deviceid(self, device_id):
525521
self.remove_dead_trans_obj(node_id)
526522

527523
def start_trans_process(self, device_id: int):
528-
task_in_queue = self.kv_trans_task_in_queues[device_id]
529-
task_out_queue = self.kv_trans_task_out_queues[device_id]
530-
self.kv_trans_process_restart_cnt[device_id] += 1
524+
task_in_queue = mp.Queue()
525+
task_out_queue = mp.Queue()
526+
self.kv_trans_process_start_cnt[device_id] += 1
531527

532528
if self.kv_trans_processes[device_id]:
533529
# force kill
@@ -554,14 +550,16 @@ def start_trans_process(self, device_id: int):
554550
assert task_out_queue.get(timeout=60) == "get_mem_managers_ok"
555551

556552
self.kv_trans_processes[device_id] = kv_trans_process
553+
self.kv_trans_task_in_queues[device_id] = task_in_queue
554+
self.kv_trans_task_out_queues[device_id] = task_out_queue
557555

558556
return True
559557
except Exception as e:
560558
logger.warning(f"Failed start kv trans process for device {device_id}: {e}")
561559
return False
562560

563561
def is_kv_trans_process_alive(self, device_id):
564-
return self.kv_trans_process_restart_cnt[device_id] <= KV_MOVE_MAX_RESTART_CNT
562+
return self.kv_trans_process_start_cnt[device_id] <= KV_MOVE_MAX_START_CNT
565563

566564
def check_trans_process(self, raise_exception=True):
567565
at_least_one_alive = False

lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py

+14-18
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from lightllm.utils.envs_utils import get_unique_server_name
3030

3131
KV_MOVE_MAX_NUM = 16
32-
KV_MOVE_MAX_RESTART_CNT = 3
32+
KV_MOVE_MAX_START_CNT = 3
3333

3434
logger = init_logger(__name__)
3535

@@ -348,20 +348,13 @@ def __init__(self, args, info_queue: mp.Queue, mem_queues: List[mp.Queue]):
348348

349349
from .prefill_trans_process import start_prefill_trans_process
350350

351-
self.kv_trans_ports = []
352-
self.kv_trans_processes = []
353-
self.kv_trans_task_in_queues = []
354-
self.kv_trans_task_out_queues = []
355-
self.kv_trans_process_restart_cnt = []
351+
self.kv_trans_ports = [None] * self.node_world_size
352+
self.kv_trans_processes = [None] * self.node_world_size
353+
self.kv_trans_task_in_queues = [None] * self.node_world_size
354+
self.kv_trans_task_out_queues = [None] * self.node_world_size
355+
self.kv_trans_process_start_cnt = [0] * self.node_world_size
356356

357357
for device_id in range(self.node_world_size):
358-
self.kv_trans_task_in_queues.append(mp.Queue())
359-
self.kv_trans_task_out_queues.append(mp.Queue())
360-
self.kv_trans_ports.append(
361-
find_available_port(self.args.pd_p_allowed_port_min, self.args.pd_p_allowed_port_max)
362-
)
363-
self.kv_trans_process_restart_cnt.append(0)
364-
self.kv_trans_processes.append(None)
365358
assert self.start_trans_process(device_id)
366359

367360
return
@@ -385,10 +378,10 @@ def handle_release_task_loop(self):
385378
return
386379

387380
def start_trans_process(self, device_id: int):
388-
task_in_queue = self.kv_trans_task_in_queues[device_id]
389-
task_out_queue = self.kv_trans_task_out_queues[device_id]
390-
kv_trans_port = self.kv_trans_ports[device_id]
391-
self.kv_trans_process_restart_cnt[device_id] += 1
381+
task_in_queue = mp.Queue()
382+
task_out_queue = mp.Queue()
383+
kv_trans_port = find_available_port(self.args.pd_p_allowed_port_min, self.args.pd_p_allowed_port_max)
384+
self.kv_trans_process_start_cnt[device_id] += 1
392385

393386
if self.kv_trans_processes[device_id]:
394387
# force kill
@@ -417,6 +410,9 @@ def start_trans_process(self, device_id: int):
417410
assert task_out_queue.get(timeout=60) == "get_mem_managers_ok"
418411

419412
self.kv_trans_processes[device_id] = kv_trans_process
413+
self.kv_trans_task_in_queues[device_id] = task_in_queue
414+
self.kv_trans_task_out_queues[device_id] = task_out_queue
415+
self.kv_trans_ports[device_id] = kv_trans_port
420416

421417
return True
422418
except Exception as e:
@@ -454,7 +450,7 @@ def check_trans_process_loop(self):
454450
raise e
455451

456452
def is_kv_trans_process_alive(self, device_id):
457-
return self.kv_trans_process_restart_cnt[device_id] <= KV_MOVE_MAX_RESTART_CNT
453+
return self.kv_trans_process_start_cnt[device_id] <= KV_MOVE_MAX_START_CNT
458454

459455
def get_next_device_index(self):
460456

0 commit comments

Comments
 (0)