From 452a0ac8bb50d1b8f21ad7027caccf334862b0c8 Mon Sep 17 00:00:00 2001 From: Weichao Luo Date: Fri, 7 Mar 2025 19:55:21 +0800 Subject: [PATCH 01/20] single kv transfer process for pd. --- lightllm/common/deepseek2_mem_manager.py | 21 +- lightllm/common/mem_manager.py | 23 +- lightllm/distributed/pynccl.py | 332 ++++++++++++++ lightllm/distributed/pynccl_wrapper.py | 405 ++++++++++++++++++ lightllm/server/pd_io_struct.py | 13 + .../decode_kv_move_manager.py | 111 ++--- .../decode_node_impl/decode_trans_process.py | 113 ++--- .../prefill_kv_move_manager.py | 82 ++-- .../prefill_trans_process.py | 131 +++--- 9 files changed, 1011 insertions(+), 220 deletions(-) create mode 100644 lightllm/distributed/pynccl.py create mode 100644 lightllm/distributed/pynccl_wrapper.py diff --git a/lightllm/common/deepseek2_mem_manager.py b/lightllm/common/deepseek2_mem_manager.py index 7ae70d46f..94dec293e 100644 --- a/lightllm/common/deepseek2_mem_manager.py +++ b/lightllm/common/deepseek2_mem_manager.py @@ -7,6 +7,7 @@ from lightllm.utils.log_utils import init_logger from lightllm.common.kv_trans_kernel.kv_trans import kv_trans from lightllm.common.kv_trans_kernel.kv_trans_v2 import kv_trans_v2_for_d_node, kv_trans_v2_for_p_node +from lightllm.distributed.pynccl import PyNcclCommunicator logger = init_logger(__name__) @@ -35,7 +36,8 @@ def alloc_kv_move_buffer(self, max_req_total_len): return def send_to_decode_node( - self, move_tasks: List[KVMoveTask], mem_managers: List["Deepseek2MemoryManager"], dp_size_in_node: int + self, move_tasks: List[KVMoveTask], mem_managers: List["Deepseek2MemoryManager"], dp_size_in_node: int, + nccl_comm: PyNcclCommunicator ): assert dp_size_in_node == 1 @@ -49,7 +51,7 @@ def send_to_decode_node( cur_mem = mem_managers[cur_device_index] for layer_index in range(cur_mem.layer_num): move_buffer = cur_mem._get_kv_move_data(move_token_indexes, layer_index) - dist.send(move_buffer, dst=1) + nccl_comm.send(move_buffer, dst=1) return def _get_kv_move_data(self, token_indexes: List[int], layer_index: int): @@ -61,7 +63,8 @@ def _get_kv_move_data(self, token_indexes: List[int], layer_index: int): return move_buffer def receive_from_prefill_node( - self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int + self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int, + nccl_comm: PyNcclCommunicator ): assert dp_size_in_node == 1 @@ -76,7 +79,7 @@ def receive_from_prefill_node( move_size = self.kv_buffer.numel() // self.layer_num // self.size * token_num recive_buffer = self.kv_move_buffer.view(-1)[0:move_size].view(1, token_num, self.head_num, self.head_dim) for layer_index in range(self.layer_num): - dist.recv(recive_buffer, src=0) + nccl_comm.recv(recive_buffer, src=0) for i, mem in enumerate(mem_managers): if i == cur_device_index: mem._write_kv_move_data(move_token_indexes, recive_buffer, layer_index) @@ -93,7 +96,8 @@ def _write_kv_move_data(self, token_indexes: torch.Tensor, buffer_tensor: torch. 
return def send_to_decode_node_p2p( - self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int + self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int, + nccl_comm: PyNcclCommunicator ): """ 使用 p2p triton kernel 进行数据复制和传输的实现方式。 @@ -120,7 +124,7 @@ def send_to_decode_node_p2p( move_buffer = self._get_kv_move_data_p2p( move_token_indexes, token_dp_indexes, layer_index, self.kv_move_buffer, dp_size_in_node ) - dist.send(move_buffer, dst=1) + nccl_comm.send(move_buffer, dst=1) return def _get_kv_move_data_p2p( @@ -145,7 +149,8 @@ def _get_kv_move_data_p2p( return move_buffer def receive_from_prefill_node_p2p( - self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int + self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int, + nccl_comm: PyNcclCommunicator ): if not hasattr(self, "mem_ptrs_dict"): self.mem_ptrs_dict = {} @@ -170,7 +175,7 @@ def receive_from_prefill_node_p2p( move_size = self.kv_buffer.numel() // self.layer_num // self.size * token_num recive_buffer = self.kv_move_buffer.view(-1)[0:move_size].view(token_num, self.head_num, self.head_dim) for layer_index in range(self.layer_num): - dist.recv(recive_buffer, src=0) + nccl_comm.recv(recive_buffer, src=0) self._write_kv_move_data_p2p( move_token_indexes, token_dp_indexes, recive_buffer, layer_index, dp_size_in_node ) diff --git a/lightllm/common/mem_manager.py b/lightllm/common/mem_manager.py index 5e701effa..8b3e1b73b 100755 --- a/lightllm/common/mem_manager.py +++ b/lightllm/common/mem_manager.py @@ -10,6 +10,7 @@ from lightllm.common.kv_trans_kernel.kv_trans import kv_trans from lightllm.utils.dist_utils import get_current_rank_in_node from lightllm.utils.envs_utils import get_unique_server_name, get_env_start_args +from lightllm.distributed.pynccl import PyNcclCommunicator logger = init_logger(__name__) @@ -91,7 +92,8 @@ def alloc_kv_move_buffer(self, max_req_total_len): return def send_to_decode_node( - self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int + self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int, + nccl_comm: PyNcclCommunicator ): assert dp_size_in_node == 1 @@ -108,14 +110,14 @@ def send_to_decode_node( for layer_index in range(mem.layer_num): move_buffer = mem._get_kv_move_data(move_token_indexes, layer_index) if i == cur_device_index: - dist.send(move_buffer, dst=1) + nccl_comm.send(move_buffer, dst=1) else: move_size = move_buffer.numel() new_move_buffer = cur_mem.kv_move_buffer.view(-1)[0:move_size].view(move_buffer.shape) from torch.cuda import comm comm.broadcast(move_buffer, out=[new_move_buffer]) - dist.send(new_move_buffer, dst=1) + nccl_comm.send(new_move_buffer, dst=1) return def _get_kv_move_data(self, token_indexes: List[int], layer_index: int): @@ -127,7 +129,8 @@ def _get_kv_move_data(self, token_indexes: List[int], layer_index: int): return move_buffer def receive_from_prefill_node( - self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int + self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int, + nccl_comm: PyNcclCommunicator, ): assert dp_size_in_node == 1 @@ -144,7 +147,7 @@ def receive_from_prefill_node( recive_buffer = self.kv_move_buffer.view(-1)[0:move_size].view(1, token_num, 2 * self.head_num, self.head_dim) for i, mem in enumerate(mem_managers): for layer_index in 
range(mem.layer_num): - dist.recv(recive_buffer, src=0) + nccl_comm.recv(recive_buffer, src=0) if i == cur_device_index: mem._write_kv_move_data(move_token_indexes, recive_buffer, layer_index) else: @@ -160,7 +163,8 @@ def _write_kv_move_data(self, token_indexes: torch.Tensor, buffer_tensor: torch. return def send_to_decode_node_p2p( - self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int + self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int, + nccl_comm: PyNcclCommunicator ): """ 使用 p2p triton kernel 进行数据复制和传输的实现方式。 @@ -178,7 +182,7 @@ def send_to_decode_node_p2p( for i, mem in enumerate(mem_managers): for layer_index in range(mem.layer_num): move_buffer = mem._get_kv_move_data_p2p(move_token_indexes, layer_index, self.kv_move_buffer) - dist.send(move_buffer, dst=1) + nccl_comm.send(move_buffer, dst=1) return def _get_kv_move_data_p2p(self, token_indexes: torch.Tensor, layer_index: int, kv_move_buffer: torch.Tensor): @@ -191,7 +195,8 @@ def _get_kv_move_data_p2p(self, token_indexes: torch.Tensor, layer_index: int, k return move_buffer def receive_from_prefill_node_p2p( - self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int + self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int, + nccl_comm: PyNcclCommunicator, ): assert dp_size_in_node == 1 @@ -209,7 +214,7 @@ def receive_from_prefill_node_p2p( recive_buffer = self.kv_move_buffer.view(-1)[0:move_size].view(token_num, 2 * self.head_num, self.head_dim) for i, mem in enumerate(mem_managers): for layer_index in range(mem.layer_num): - dist.recv(recive_buffer, src=0) + nccl_comm.recv(recive_buffer, src=0) mem._write_kv_move_data_p2p(move_token_indexes, recive_buffer, layer_index) return diff --git a/lightllm/distributed/pynccl.py b/lightllm/distributed/pynccl.py new file mode 100644 index 000000000..9a01dd116 --- /dev/null +++ b/lightllm/distributed/pynccl.py @@ -0,0 +1,332 @@ +# Adapted from +# https://github.com/vllm-project/vllm/blob/main/vllm/distributed/device_communicators/pynccl.py +# of the vllm-project/vllm GitHub repository. +# +# Copyright 2023 ModelTC Team +# Copyright 2023 vLLM Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
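+#
+# A usage sketch of the two classes defined below (endpoint values are
+# illustrative; the real ones are carried by PDTransJoinInfo between the
+# prefill and decode trans processes):
+#
+#     store = TCPStore(host_name="10.0.0.1", port=29600,
+#                      is_master=is_prefill_side, use_libuv=is_prefill_side)
+#     group = StatelessP2PProcessGroup.create(
+#         src_id=prefill_node_id, dest_id=decode_node_id,
+#         is_server=is_prefill_side, store=store)
+#     comm = PyNcclCommunicator(group, device_id)  # rank 0 = prefill side
+#     comm.send(t, dst=1)  # pairs with comm.recv(t, src=0) on the peer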
+ +# SPDX-License-Identifier: Apache-2.0 + +import dataclasses +from datetime import timedelta +import pickle +import time +from typing import Optional, Union, Dict, Deque, Tuple, Any +from collections import deque +import logging + +# ===================== import region ===================== +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup, ReduceOp, TCPStore + +from lightllm.distributed.pynccl_wrapper import ( + NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum, + ncclRedOpTypeEnum, ncclUniqueId) + +logger = logging.getLogger(__name__) + +_current_stream = None + +def current_stream() -> torch.cuda.Stream: + global _current_stream + if _current_stream is None: + _current_stream = torch.cuda.current_stream() + return _current_stream + +@dataclasses.dataclass +class StatelessP2PProcessGroup: + """A dataclass to hold a metadata store, and the rank, world_size of the + group. Only use it to communicate metadata between processes. + For data-plane communication, create NCCL-related objects. + """ + + dest_id: int + src_id: int + is_server: bool + + rank: int = 0 + world_size: int = 2 + store: TCPStore = None + data_expiration_seconds: int = 3600 # 1 hour + # dst rank -> counter + send_dst_counter: int = 0 + # src rank -> counter + recv_src_counter: int = 0 + entries: Deque[Tuple[str, float]] = dataclasses.field(default_factory=deque) + + def __post_init__(self): + self.rank = 0 if self.is_server else 1 + self.world_size = 2 + self.send_dst_counter = 0 + self.recv_src_counter = 0 + + def send_obj(self, obj: Any): + """Send an object to a destination rank.""" + self.expire_data() + key = f"send_to/{self.dest_id}/{self.send_dst_counter}" + self.store.set(key, pickle.dumps(obj)) + self.send_dst_counter += 1 + self.entries.append((key, time.time())) + + def expire_data(self): + """Expire data that is older than `data_expiration_seconds` seconds.""" + while self.entries: + # check the oldest entry + key, timestamp = self.entries[0] + if time.time() - timestamp > self.data_expiration_seconds: + self.store.delete_key(key) + self.entries.popleft() + else: + break + + def recv_obj(self) -> Any: + """Receive an object from a source rank.""" + obj = pickle.loads( + self.store.get( + f"send_to/{self.dest_id}/{self.recv_src_counter}")) + self.recv_src_counter += 1 + return obj + + @staticmethod + def create( + src_id: int, + dest_id: int, + is_server: bool, + store: torch._C._distributed_c10d.Store + ) -> "StatelessP2PProcessGroup": + """A replacement for `torch.distributed.init_process_group` that does not + pollute the global state. + + If we have process A and process B called `torch.distributed.init_process_group` + to form a group, and then we want to form another group with process A, B, C, + D, it is not possible in PyTorch, because process A and process B have already + formed a group, and process C and process D cannot join that group. This + function is a workaround for this issue. + + `torch.distributed.init_process_group` is a global call, while this function + is a stateless call. It will return a `StatelessProcessGroup` object that can be + used for exchanging metadata. With this function, process A and process B + can call `StatelessProcessGroup.create` to form a group, and then process A, B, + C, and D can call `StatelessProcessGroup.create` to form another group. 
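+
+        In this patch each (prefill, decode) node pair forms its own
+        two-rank group keyed by the node ids, so the single long-lived
+        KV trans process can hold one NCCL communicator per peer.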
+ """ # noqa + return StatelessP2PProcessGroup(src_id=src_id, dest_id=dest_id, is_server=is_server, store=store) + + +class PyNcclCommunicator: + + def __init__( + self, + group: Union[ProcessGroup, StatelessP2PProcessGroup], + device: Union[int, str, torch.device], + library_path: Optional[str] = None, + ): + """ + Args: + group: the process group to work on. If None, it will use the + default process group. + device: the device to bind the PyNcclCommunicator to. If None, + it will be bind to f"cuda:{local_rank}". + library_path: the path to the NCCL library. If None, it will + use the default library path. + It is the caller's responsibility to make sure each communicator + is bind to a unique device. + """ + if not isinstance(group, StatelessP2PProcessGroup): + assert dist.is_initialized() + assert dist.get_backend(group) != dist.Backend.NCCL, ( + "PyNcclCommunicator should be attached to a non-NCCL group.") + # note: this rank is the rank in the group + self.rank = dist.get_rank(group) + self.world_size = dist.get_world_size(group) + else: + self.rank = group.rank + self.world_size = group.world_size + + self.group = group + + # if world_size == 1, no need to create communicator + if self.world_size == 1: + self.available = False + self.disabled = True + return + try: + self.nccl = NCCLLibrary(library_path) + except Exception: + # disable because of missing NCCL library + # e.g. in a non-GPU environment + self.available = False + self.disabled = True + return + + self.available = True + self.disabled = False + + logger.info("LightLLM is using nccl==%s", self.nccl.ncclGetVersion()) + + if self.rank == 0: + # get the unique id from NCCL + self.unique_id = self.nccl.ncclGetUniqueId() + else: + # construct an empty unique id + self.unique_id = ncclUniqueId() + + if not isinstance(group, StatelessP2PProcessGroup): + tensor = torch.ByteTensor(list(self.unique_id.internal)) + ranks = dist.get_process_group_ranks(group) + # arg `src` in `broadcast` is the global rank + dist.broadcast(tensor, src=ranks[0], group=group) + byte_list = tensor.tolist() + for i, byte in enumerate(byte_list): + self.unique_id.internal[i] = byte + else: + if group.rank == 0: + group.send_obj(self.unique_id) + else: + self.unique_id = group.recv_obj() + if isinstance(device, int): + device = torch.device(f"cuda:{device}") + elif isinstance(device, str): + device = torch.device(device) + # now `device` is a `torch.device` object + assert isinstance(device, torch.device) + self.device = device + # nccl communicator and stream will use this device + # `torch.cuda.device` is a context manager that changes the + # current cuda device to the specified one + with torch.cuda.device(device): + self.comm: ncclComm_t = self.nccl.ncclCommInitRank( + self.world_size, self.unique_id, self.rank) + + stream = current_stream() + # A small all_reduce for warmup. 
+ data = torch.zeros(1, device=device) + self.all_reduce(data) + stream.synchronize() + del data + + def destroy(self): + self.nccl.ncclCommDestroy(self.comm) + + def all_reduce(self, + in_tensor: torch.Tensor, + op: ReduceOp = ReduceOp.SUM, + stream=None) -> torch.Tensor: + if self.disabled: + return None + # nccl communicator created on a specific device + # will only work on tensors on the same device + # otherwise it will cause "illegal memory access" + assert in_tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {in_tensor.device}") + + out_tensor = torch.empty_like(in_tensor) + + if stream is None: + stream = current_stream() + self.nccl.ncclAllReduce(buffer_type(in_tensor.data_ptr()), + buffer_type(out_tensor.data_ptr()), + in_tensor.numel(), + ncclDataTypeEnum.from_torch(in_tensor.dtype), + ncclRedOpTypeEnum.from_torch(op), self.comm, + cudaStream_t(stream.cuda_stream)) + return out_tensor + + def all_gather(self, + output_tensor: torch.Tensor, + input_tensor: torch.Tensor, + stream=None): + if self.disabled: + return + # nccl communicator created on a specific device + # will only work on tensors on the same device + # otherwise it will cause "illegal memory access" + assert input_tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {input_tensor.device}") + if stream is None: + stream = current_stream() + self.nccl.ncclAllGather( + buffer_type(input_tensor.data_ptr()), + buffer_type(output_tensor.data_ptr()), input_tensor.numel(), + ncclDataTypeEnum.from_torch(input_tensor.dtype), self.comm, + cudaStream_t(stream.cuda_stream)) + + def reduce_scatter(self, + output_tensor: torch.Tensor, + input_tensor: torch.Tensor, + op: ReduceOp = ReduceOp.SUM, + stream=None): + if self.disabled: + return + # nccl communicator created on a specific device + # will only work on tensors on the same device + # otherwise it will cause "illegal memory access" + assert input_tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {input_tensor.device}") + if stream is None: + stream = current_stream() + self.nccl.ncclReduceScatter( + buffer_type(input_tensor.data_ptr()), + buffer_type(output_tensor.data_ptr()), output_tensor.numel(), + ncclDataTypeEnum.from_torch(input_tensor.dtype), + ncclRedOpTypeEnum.from_torch(op), self.comm, + cudaStream_t(stream.cuda_stream)) + + def send(self, tensor: torch.Tensor, dst: int, stream=None): + if self.disabled: + return + assert tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {tensor.device}") + if stream is None: + stream = current_stream() + self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(), + ncclDataTypeEnum.from_torch(tensor.dtype), dst, + self.comm, cudaStream_t(stream.cuda_stream)) + + def recv(self, tensor: torch.Tensor, src: int, stream=None): + if self.disabled: + return + assert tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {tensor.device}") + if stream is None: + stream = current_stream() + self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(), + ncclDataTypeEnum.from_torch(tensor.dtype), src, + self.comm, cudaStream_t(stream.cuda_stream)) + + def broadcast(self, tensor: torch.Tensor, src: int, stream=None): + if self.disabled: + return + 
assert tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {tensor.device}") + if stream is None: + stream = current_stream() + if src == self.rank: + sendbuff = buffer_type(tensor.data_ptr()) + # NCCL requires the sender also to have a receive buffer + recvbuff = buffer_type(tensor.data_ptr()) + else: + sendbuff = buffer_type() + recvbuff = buffer_type(tensor.data_ptr()) + self.nccl.ncclBroadcast(sendbuff, recvbuff, tensor.numel(), + ncclDataTypeEnum.from_torch(tensor.dtype), src, + self.comm, cudaStream_t(stream.cuda_stream)) + diff --git a/lightllm/distributed/pynccl_wrapper.py b/lightllm/distributed/pynccl_wrapper.py new file mode 100644 index 000000000..d35ec8e3a --- /dev/null +++ b/lightllm/distributed/pynccl_wrapper.py @@ -0,0 +1,405 @@ +# Adapted from +# https://github.com/vllm-project/vllm/blob/main/vllm/distributed/device_communicators/pynccl_wrapper.py +# of the vllm-project/vllm GitHub repository. +# +# Copyright 2023 ModelTC Team +# Copyright 2023 vLLM Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# SPDX-License-Identifier: Apache-2.0 + +# This file is a pure Python wrapper for the NCCL library. +# The main purpose is to use NCCL combined with CUDA graph. +# Before writing this script, we tried the following approach: +# 1. We tried to use `cupy`, it calls NCCL correctly, but `cupy` itself +# often gets stuck when initializing the NCCL communicator. +# 2. We tried to use `torch.distributed`, but `torch.distributed.all_reduce` +# contains many other potential cuda APIs, that are not allowed during +# capturing the CUDA graph. For further details, please check +# https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/ . +# +# Another rejected idea is to write a C/C++ binding for NCCL. It is usually +# doable, but we often encounter issues related with nccl versions, and need +# to switch between different versions of NCCL. See +# https://github.com/NVIDIA/nccl/issues/1234 for more details. +# A C/C++ binding is not flexible enough to handle this. It requires +# recompilation of the code every time we want to switch between different +# versions. This current implementation, with a **pure** Python wrapper, is +# more flexible. We can easily switch between different versions of NCCL by +# changing the environment variable `VLLM_NCCL_SO_PATH`, or the `so_file` +# variable in the code. + +import ctypes +import platform +from dataclasses import dataclass +from typing import Any, Dict, List, Optional + +import torch +from torch.distributed import ReduceOp + +import logging + +logger = logging.getLogger(__name__) + + +def find_nccl_library() -> str: + """ + We either use the library file specified by the `VLLM_NCCL_SO_PATH` + environment variable, or we find the library file brought by PyTorch. + After importing `torch`, `libnccl.so.2` or `librccl.so.1` can be + found by `ctypes` automatically. 
+ """ + so_file = None + + # manually load the nccl library + if so_file: + logger.info( + "Found nccl from environment variable VLLM_NCCL_SO_PATH=%s", + so_file) + else: + if torch.version.cuda is not None: + so_file = "libnccl.so.2" + elif torch.version.hip is not None: + so_file = "librccl.so.1" + else: + raise ValueError("NCCL only supports CUDA and ROCm backends.") + logger.info("Found nccl from library %s", so_file) + return so_file + + +# === export types and functions from nccl to Python === +# for the original nccl definition, please check +# https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in + +ncclResult_t = ctypes.c_int +ncclComm_t = ctypes.c_void_p + + +class ncclUniqueId(ctypes.Structure): + _fields_ = [("internal", ctypes.c_byte * 128)] + + +cudaStream_t = ctypes.c_void_p +buffer_type = ctypes.c_void_p + +ncclDataType_t = ctypes.c_int + + +class ncclDataTypeEnum: + ncclInt8 = 0 + ncclChar = 0 + ncclUint8 = 1 + ncclInt32 = 2 + ncclInt = 2 + ncclUint32 = 3 + ncclInt64 = 4 + ncclUint64 = 5 + ncclFloat16 = 6 + ncclHalf = 6 + ncclFloat32 = 7 + ncclFloat = 7 + ncclFloat64 = 8 + ncclDouble = 8 + ncclBfloat16 = 9 + ncclNumTypes = 10 + + @classmethod + def from_torch(cls, dtype: torch.dtype) -> int: + if dtype == torch.int8: + return cls.ncclInt8 + if dtype == torch.uint8: + return cls.ncclUint8 + if dtype == torch.int32: + return cls.ncclInt32 + if dtype == torch.int64: + return cls.ncclInt64 + if dtype == torch.float16: + return cls.ncclFloat16 + if dtype == torch.float32: + return cls.ncclFloat32 + if dtype == torch.float64: + return cls.ncclFloat64 + if dtype == torch.bfloat16: + return cls.ncclBfloat16 + raise ValueError(f"Unsupported dtype: {dtype}") + + +ncclRedOp_t = ctypes.c_int + + +class ncclRedOpTypeEnum: + ncclSum = 0 + ncclProd = 1 + ncclMax = 2 + ncclMin = 3 + ncclAvg = 4 + ncclNumOps = 5 + + @classmethod + def from_torch(cls, op: ReduceOp) -> int: + if op == ReduceOp.SUM: + return cls.ncclSum + if op == ReduceOp.PRODUCT: + return cls.ncclProd + if op == ReduceOp.MAX: + return cls.ncclMax + if op == ReduceOp.MIN: + return cls.ncclMin + if op == ReduceOp.AVG: + return cls.ncclAvg + raise ValueError(f"Unsupported op: {op}") + + +@dataclass +class Function: + name: str + restype: Any + argtypes: List[Any] + + +class NCCLLibrary: + exported_functions = [ + # const char* ncclGetErrorString(ncclResult_t result) + Function("ncclGetErrorString", ctypes.c_char_p, [ncclResult_t]), + # ncclResult_t ncclGetVersion(int *version); + Function("ncclGetVersion", ncclResult_t, + [ctypes.POINTER(ctypes.c_int)]), + # ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId); + Function("ncclGetUniqueId", ncclResult_t, + [ctypes.POINTER(ncclUniqueId)]), + # ncclResult_t ncclCommInitRank( + # ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank); + # note that ncclComm_t is a pointer type, so the first argument + # is a pointer to a pointer + Function("ncclCommInitRank", ncclResult_t, [ + ctypes.POINTER(ncclComm_t), ctypes.c_int, ncclUniqueId, + ctypes.c_int + ]), + # ncclResult_t ncclAllReduce( + # const void* sendbuff, void* recvbuff, size_t count, + # ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, + # cudaStream_t stream); + # note that cudaStream_t is a pointer type, so the last argument + # is a pointer + Function("ncclAllReduce", ncclResult_t, [ + buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, + ncclRedOp_t, ncclComm_t, cudaStream_t + ]), + + # ncclResult_t ncclAllGather( + # const void* sendbuff, void* recvbuff, size_t count, + # ncclDataType_t 
datatype, ncclComm_t comm, + # cudaStream_t stream); + # note that cudaStream_t is a pointer type, so the last argument + # is a pointer + Function("ncclAllGather", ncclResult_t, [ + buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, + ncclComm_t, cudaStream_t + ]), + + # ncclResult_t ncclReduceScatter( + # const void* sendbuff, void* recvbuff, size_t count, + # ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, + # cudaStream_t stream); + # note that cudaStream_t is a pointer type, so the last argument + # is a pointer + Function("ncclReduceScatter", ncclResult_t, [ + buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, + ncclRedOp_t, ncclComm_t, cudaStream_t + ]), + + # ncclResult_t ncclSend( + # const void* sendbuff, size_t count, ncclDataType_t datatype, + # int dest, ncclComm_t comm, cudaStream_t stream); + Function("ncclSend", ncclResult_t, [ + buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int, + ncclComm_t, cudaStream_t + ]), + + # ncclResult_t ncclRecv( + # void* recvbuff, size_t count, ncclDataType_t datatype, + # int src, ncclComm_t comm, cudaStream_t stream); + Function("ncclRecv", ncclResult_t, [ + buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int, + ncclComm_t, cudaStream_t + ]), + + # ncclResult_t ncclBroadcast( + # const void* sendbuff, void* recvbuff, size_t count, + # ncclDataType_t datatype, int root, ncclComm_t comm, + # cudaStream_t stream); + Function("ncclBroadcast", ncclResult_t, [ + buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, + ctypes.c_int, ncclComm_t, cudaStream_t + ]), + + # be cautious! this is a collective call, it will block until all + # processes in the communicator have called this function. + # because Python object destruction can happen in random order, + # it is better not to call it at all. + # ncclResult_t ncclCommDestroy(ncclComm_t comm); + Function("ncclCommDestroy", ncclResult_t, [ncclComm_t]), + ] + + # class attribute to store the mapping from the path to the library + # to avoid loading the same library multiple times + path_to_library_cache: Dict[str, Any] = {} + + # class attribute to store the mapping from library path + # to the corresponding dictionary + path_to_dict_mapping: Dict[str, Dict[str, Any]] = {} + + def __init__(self, so_file: Optional[str] = None): + + so_file = so_file or find_nccl_library() + + try: + if so_file not in NCCLLibrary.path_to_dict_mapping: + lib = ctypes.CDLL(so_file) + NCCLLibrary.path_to_library_cache[so_file] = lib + self.lib = NCCLLibrary.path_to_library_cache[so_file] + except Exception as e: + logger.error( + "Failed to load NCCL library from %s. " + "It is expected if you are not running on NVIDIA/AMD GPUs." + "Otherwise, the nccl library might not exist, be corrupted " + "or it does not support the current platform %s. 
" + "If you already have the library, please set the " + "environment variable VLLM_NCCL_SO_PATH" + " to point to the correct nccl library path.", so_file, + platform.platform()) + raise e + + if so_file not in NCCLLibrary.path_to_dict_mapping: + _funcs: Dict[str, Any] = {} + for func in NCCLLibrary.exported_functions: + f = getattr(self.lib, func.name) + f.restype = func.restype + f.argtypes = func.argtypes + _funcs[func.name] = f + NCCLLibrary.path_to_dict_mapping[so_file] = _funcs + self._funcs = NCCLLibrary.path_to_dict_mapping[so_file] + + def ncclGetErrorString(self, result: ncclResult_t) -> str: + return self._funcs["ncclGetErrorString"](result).decode("utf-8") + + def NCCL_CHECK(self, result: ncclResult_t) -> None: + if result != 0: + error_str = self.ncclGetErrorString(result) + raise RuntimeError(f"NCCL error: {error_str}") + + def ncclGetVersion(self) -> str: + version = ctypes.c_int() + self.NCCL_CHECK(self._funcs["ncclGetVersion"](ctypes.byref(version))) + version_str = str(version.value) + # something like 21903 --> "2.19.3" + major = version_str[0].lstrip("0") + minor = version_str[1:3].lstrip("0") + patch = version_str[3:].lstrip("0") + return f"{major}.{minor}.{patch}" + + def ncclGetUniqueId(self) -> ncclUniqueId: + unique_id = ncclUniqueId() + self.NCCL_CHECK(self._funcs["ncclGetUniqueId"]( + ctypes.byref(unique_id))) + return unique_id + + def ncclCommInitRank(self, world_size: int, unique_id: ncclUniqueId, + rank: int) -> ncclComm_t: + comm = ncclComm_t() + self.NCCL_CHECK(self._funcs["ncclCommInitRank"](ctypes.byref(comm), + world_size, unique_id, + rank)) + return comm + + def ncclAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type, + count: int, datatype: int, op: int, comm: ncclComm_t, + stream: cudaStream_t) -> None: + # `datatype` actually should be `ncclDataType_t` + # and `op` should be `ncclRedOp_t` + # both are aliases of `ctypes.c_int` + # when we pass int to a function, it will be converted to `ctypes.c_int` + # by ctypes automatically + self.NCCL_CHECK(self._funcs["ncclAllReduce"](sendbuff, recvbuff, count, + datatype, op, comm, + stream)) + + def ncclReduceScatter(self, sendbuff: buffer_type, recvbuff: buffer_type, + count: int, datatype: int, op: int, comm: ncclComm_t, + stream: cudaStream_t) -> None: + # `datatype` actually should be `ncclDataType_t` + # and `op` should be `ncclRedOp_t` + # both are aliases of `ctypes.c_int` + # when we pass int to a function, it will be converted to `ctypes.c_int` + # by ctypes automatically + self.NCCL_CHECK(self._funcs["ncclReduceScatter"](sendbuff, recvbuff, + count, datatype, op, + comm, stream)) + + def ncclAllGather(self, sendbuff: buffer_type, recvbuff: buffer_type, + count: int, datatype: int, comm: ncclComm_t, + stream: cudaStream_t) -> None: + # `datatype` actually should be `ncclDataType_t` + # which is an aliases of `ctypes.c_int` + # when we pass int to a function, it will be converted to `ctypes.c_int` + # by ctypes automatically + self.NCCL_CHECK(self._funcs["ncclAllGather"](sendbuff, recvbuff, count, + datatype, comm, stream)) + + def ncclSend(self, sendbuff: buffer_type, count: int, datatype: int, + dest: int, comm: ncclComm_t, stream: cudaStream_t) -> None: + self.NCCL_CHECK(self._funcs["ncclSend"](sendbuff, count, datatype, + dest, comm, stream)) + + def ncclRecv(self, recvbuff: buffer_type, count: int, datatype: int, + src: int, comm: ncclComm_t, stream: cudaStream_t) -> None: + self.NCCL_CHECK(self._funcs["ncclRecv"](recvbuff, count, datatype, src, + comm, stream)) + + def 
ncclBroadcast(self, sendbuff: buffer_type, recvbuff: buffer_type, + count: int, datatype: int, root: int, comm: ncclComm_t, + stream: cudaStream_t) -> None: + self.NCCL_CHECK(self._funcs["ncclBroadcast"](sendbuff, recvbuff, count, + datatype, root, comm, + stream)) + + def ncclCommDestroy(self, comm: ncclComm_t) -> None: + self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm)) + + +__all__ = [ + "NCCLLibrary", "ncclDataTypeEnum", "ncclRedOpTypeEnum", "ncclUniqueId", + "ncclComm_t", "cudaStream_t", "buffer_type" +] + + + +def test_ncclGetUniqueId(): + lib = NCCLLibrary() + unique_id = lib.ncclGetUniqueId() + print(unique_id.internal) + # `list(unique_id.internal)` is something like this: + # [34, -16, 23, 83, 109, -19, 59, 95, 2, 0, -86, 55, 10, -128, 0, 29, 0, + # 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + # 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + # 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + # 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + # 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + # as long as the function doesn't raise an exception, we're good + assert unique_id is not None + +if __name__ == '__main__': + import torch; + torch.cuda.set_device(0) + test_ncclGetUniqueId() diff --git a/lightllm/server/pd_io_struct.py b/lightllm/server/pd_io_struct.py index d5d22c8ea..222cd5887 100644 --- a/lightllm/server/pd_io_struct.py +++ b/lightllm/server/pd_io_struct.py @@ -74,6 +74,19 @@ class DecodeNodeInfo: rpyc_port: str max_new_tokens: int +@dataclass +class PDTransJoinInfo: + decode_id: int + decode_device_id: int + prefill_id: int + prefill_device_id: int + prefill_ip: str + prefill_port: int + +@dataclass +class PDTransLeaveInfo: + decode_id: int + prefill_id: int @dataclass class KVMoveTask: diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_kv_move_manager.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_kv_move_manager.py index 30096e3e5..0c46b6dd5 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_kv_move_manager.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_kv_move_manager.py @@ -16,7 +16,7 @@ from .decode_infer_rpyc import PDDecodeInferRpcServer from ..task_queue import TaskQueue import torch.multiprocessing as mp -from lightllm.server.pd_io_struct import KVMoveTask, UpKVStatus +from lightllm.server.pd_io_struct import KVMoveTask, UpKVStatus, PDTransJoinInfo, PDTransLeaveInfo from lightllm.utils.retry_utils import retry import numpy as np from rpyc import AsyncResult @@ -33,12 +33,11 @@ @dataclass class TransProcessObj: - prefill_node_id: str = None - process: mp.Process = None + prefill_node_id: int = None task_in_queue: mp.Queue = None task_out_queue: mp.Queue = None - nccl_ip: str = None - nccl_port: str = None + prefill_ip: str = None + prefill_port: int = None device_index: int = None manager: "DecodeKVMoveManager" = None has_error: bool = False @@ -48,26 +47,31 @@ class TransProcessObj: put_to_radix_thread: threading.Thread = None latest_check_time: float = None - def create(self, prefill_node_id: str, nccl_ip: str, nccl_port: int, manager: "DecodeKVMoveManager"): - from .decode_trans_process import start_decode_trans_process + def create( + self, prefill_node_id: str, prefill_ip: str, prefill_port: int, manager: "DecodeKVMoveManager" 
+ ): - task_in_queue = mp.Queue() - task_out_queue = mp.Queue() device_index = manager.get_next_device_index() - proc = start_decode_trans_process( - manager.args, device_index, nccl_ip, nccl_port, task_in_queue, task_out_queue, manager.mem_queues - ) - assert task_out_queue.get(timeout=30) == "proc_start" - manager._put_mem_manager_to_mem_queue() - assert task_out_queue.get(timeout=60) == "get_mem_managers_ok" + decode_node_id = manager.args.pd_node_id + task_in_queue = manager.kv_trans_task_in_queue + task_out_queue = manager.kv_trans_task_out_queue + + task_in_queue.put(PDTransJoinInfo( + prefill_id=prefill_node_id, + prefill_device_id=-1, + prefill_ip=prefill_ip, + prefill_port=prefill_port, + decode_id=decode_node_id, + decode_device_id=device_index, + )) assert task_out_queue.get(timeout=60) == "nccl_ok" self.prefill_node_id = prefill_node_id - self.process = proc + self.decode_node_id = decode_node_id self.task_in_queue = task_in_queue self.task_out_queue = task_out_queue - self.nccl_ip = nccl_ip - self.nccl_port = nccl_port + self.prefill_ip = prefill_ip + self.prefill_port = prefill_port self.device_index = device_index self.manager = manager @@ -86,20 +90,6 @@ def create(self, prefill_node_id: str, nccl_ip: str, nccl_port: int, manager: "D self.put_to_radix_thread.start() return - def check_trans_process(self, raise_exception=True): - process = psutil.Process(self.process.pid) - if not (process.is_running() and process.status() != psutil.STATUS_ZOMBIE): - self.set_has_error() - if raise_exception: - raise Exception(f"trans process: {self.process.pid} is dead") - return - - def timer_to_check_status(self, raise_exception=True): - if time.time() - self.latest_check_time >= 2.0: - self.latest_check_time = time.time() - self.check_trans_process(raise_exception=raise_exception) - return - def _transfer_kv(self, move_tasks: List[KVMoveTask]): with self.manager.device_locks[self.device_index]: self.task_in_queue.put(move_tasks.copy(), timeout=10) @@ -130,8 +120,6 @@ def kv_move_loop(self): logger.info(f"{func_name} get task {task.to_decode_log_info()}") try: - self.timer_to_check_status(raise_exception=True) - if not kv_trans_use_p2p(): with self.manager.kv_trans_lock: self._transfer_kv(move_tasks) @@ -148,6 +136,10 @@ def kv_move_loop(self): self.manager.put_to_fail_release_task_queue(move_tasks) logger.error(f"{func_name} prefill id {self.prefill_node_id} device_index {self.device_index} thread quit") + self.task_in_queue.put(PDTransLeaveInfo( + decode_id=self.decode_node_id, + prefill_id=self.prefill_node_id + )) return def put_to_radix_loop(self): @@ -163,8 +155,6 @@ def put_to_radix_loop(self): try: # random to check stats - self.timer_to_check_status(raise_exception=True) - self.manager._put_kv_received_to_radix_cache(move_tasks.copy()) for task in move_tasks.copy(): logger.info( @@ -239,12 +229,6 @@ def __del__(self): logger.error(f"trans obj deled, prefill node id {self.prefill_node_id} device_index {self.device_index}") - # 强制关闭连接和杀掉传输进程 - if self.process is not None: - logger.warning(f"trans kv process {self.process.pid} is killed") - os.kill(self.process.pid, signal.SIGKILL) - pass - class DecodeKVMoveManager(rpyc.Service): def __init__(self, args, info_queue: mp.Queue, mem_queues: List[mp.Queue]): @@ -284,6 +268,18 @@ def __init__(self, args, info_queue: mp.Queue, mem_queues: List[mp.Queue]): self.kv_trans_lock = threading.Lock() # 需要每个卡有一个锁来规划每次只能有一个tran obj 操作对应显卡上的传输任务。 self.device_locks = [threading.Lock() for _ in range(self.node_world_size)] + + # start a single kv 
trans process
+        self.kv_trans_task_in_queue = mp.Queue()
+        self.kv_trans_task_out_queue = mp.Queue()
+        from .decode_trans_process import start_decode_trans_process
+        self.kv_trans_process = start_decode_trans_process(
+            self.args, self.kv_trans_task_in_queue, self.kv_trans_task_out_queue, self.mem_queues)
+
+        assert self.kv_trans_task_out_queue.get(timeout=30) == "proc_start"
+        self._put_mem_manager_to_mem_queue()
+        assert self.kv_trans_task_out_queue.get(timeout=60) == "get_mem_managers_ok"
+
         return
 
     def put_to_fail_release_task_queue(self, task: Union[KVMoveTask, List[KVMoveTask]]):
@@ -392,17 +388,17 @@ def exposed_check_alive(self):
         # 用于 prefill node check 通信连接的状态。
         return
 
-    def exposed_build_trans_process(self, prefill_node_id, nccl_ip, nccl_port, prefill_node_max_kv_trans_num):
-        prefill_node_id, nccl_ip, nccl_port, prefill_node_max_kv_trans_num = list(
-            map(obtain, [prefill_node_id, nccl_ip, nccl_port, prefill_node_max_kv_trans_num])
+    def exposed_build_trans_process(self, prefill_node_id, prefill_ip, prefill_port, prefill_node_max_kv_trans_num):
+        prefill_node_id, prefill_ip, prefill_port, prefill_node_max_kv_trans_num = list(
+            map(obtain, [prefill_node_id, prefill_ip, prefill_port, prefill_node_max_kv_trans_num])
         )
         thread_local_data.prefill_node_id = prefill_node_id
 
-        logger.info(f"build trans infos {prefill_node_id} {nccl_ip} {nccl_port}")
+        logger.info(f"build trans infos {prefill_node_id} {prefill_ip} {prefill_port}")
         # 如果有历史残留,一并移除
         self.remove_trans_obj(prefill_node_id)
         tran_obj = TransProcessObj()
-        tran_obj.create(prefill_node_id, nccl_ip, nccl_port, self)
+        tran_obj.create(prefill_node_id, prefill_ip, prefill_port, self)
         self.node_id_to_trans_obj[prefill_node_id] = tran_obj
         return min(prefill_node_max_kv_trans_num, self.args.max_total_token_num)
 
@@ -499,10 +495,25 @@ def remove_trans_obj(self, prefill_node_id):
             trans_obj.set_has_error()
         return
 
+    def check_trans_process(self, raise_exception=True):
+        process = psutil.Process(self.kv_trans_process.pid)
+        if not (process.is_running() and process.status() != psutil.STATUS_ZOMBIE):
+            if raise_exception:
+                raise Exception(f"trans process: {self.kv_trans_process.pid} is dead")
+        return
+
     def timer_loop(self):
-        while True:
-            self._unfrozen_time_out_reqs_tokens()
-            time.sleep(3.5)
+        try:
+            last_check_time = time.time()
+            while True:
+                self._unfrozen_time_out_reqs_tokens()
+                time.sleep(3.5)
+                if time.time() - last_check_time > 10.0:
+                    self.check_trans_process()
+                    last_check_time = time.time()
+        except (BaseException, RuntimeError) as e:
+            logger.exception(str(e))
+            raise e
 
 
 def _init_env(args, info_queue: mp.Queue, mem_queues: List[mp.Queue], event: mp.Event):
diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py
index b70bf8efe..010074b10 100644
--- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py
+++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py
@@ -3,91 +3,100 @@
 import sys
 import inspect
 import torch.multiprocessing as mp
-from typing import List, Dict
+from torch.distributed import TCPStore
+from typing import List, Dict, Union
 from lightllm.utils.log_utils import init_logger
 from lightllm.common.mem_manager import MemoryManager
-from lightllm.server.pd_io_struct import KVMoveTask
+from 
lightllm.server.pd_io_struct import KVMoveTask, PDTransJoinInfo, PDTransLeaveInfo from lightllm.utils.device_utils import kv_trans_use_p2p from lightllm.utils.graceful_utils import graceful_registry +from lightllm.distributed.pynccl import PyNcclCommunicator, StatelessP2PProcessGroup logger = init_logger(__name__) +def _handle_kvmove_task(move_tasks: List[KVMoveTask], task_out_queue: mp.Queue, + mem_managers: List[MemoryManager], prefill_to_comm: Dict[int, PyNcclCommunicator], + dp_size_in_node: int): + total_move_kv_len = sum([task.move_kv_len for task in move_tasks]) + try: + prefill_id = move_tasks[0].prefill_node_id + device_index = prefill_to_comm[prefill_id].device.index + start = time.time() + if total_move_kv_len != 0: + cur_mem = mem_managers[device_index] + logger.info(f"trans start: {move_tasks[0].to_decode_log_info()}") + if kv_trans_use_p2p(): + cur_mem.receive_from_prefill_node_p2p(move_tasks, mem_managers, dp_size_in_node, prefill_to_comm[prefill_id]) + else: + cur_mem.receive_from_prefill_node(move_tasks, mem_managers, dp_size_in_node, prefill_to_comm[prefill_id]) + logger.info(f"trans finished: {move_tasks[0].to_decode_log_info()} move len: {total_move_kv_len}") + torch.cuda.synchronize() + logger.info(f"trans cost time: {(time.time() - start)}, {move_tasks[0].to_decode_log_info()}") + task_out_queue.put("ok") + except BaseException as e: + logger.exception(str(e)) + task_out_queue.put("fail") + raise e + +def _handle_prefill_join(node_info: PDTransJoinInfo, task_out_queue: mp.Queue, prefill_to_comm: Dict[int, PyNcclCommunicator]): + try: + store_client = TCPStore(host_name=node_info.prefill_ip, port=node_info.prefill_port, is_master=False, use_libuv=False) + group = StatelessP2PProcessGroup.create( + src_id=node_info.prefill_id, + dest_id=node_info.decode_id, + is_server=False, + store=store_client) + comm = PyNcclCommunicator(group, node_info.decode_device_id) + prefill_to_comm[node_info.prefill_id] = comm + logger.info(f"{node_info} kv trans connected") + task_out_queue.put('nccl_ok') + except Exception as e: + logger.warning(f"error while connect to prefill node: {e}") + def _init_env( args, - device_index: int, - nccl_ip, - nccl_port, task_in_queue: mp.Queue, task_out_queue: mp.Queue, - mem_queues: List[mp.Queue], -): - import os - - # os.environ["NCCL_DEBUG"] = "INFO" - os.environ["NCCL_MAX_NCHANNELS"] = "2" - os.environ["NCCL_NSOCKS_PER_CHANNEL"] = "1" - os.environ["NCCL_SOCKET_NTHREADS"] = "1" - torch.backends.cudnn.enabled = False + mem_queues: List[mp.Queue]): dp_size_in_node = max(1, args.dp // args.nnodes) node_world_size = args.tp // args.nnodes try: - # 注册graceful 退出的处理 graceful_registry(inspect.currentframe().f_code.co_name) - torch.cuda.set_device(device_index) - task_out_queue.put("proc_start") mem_managers: List[MemoryManager] = [mem_queue.get(timeout=60) for mem_queue in mem_queues] assert len(mem_managers) == node_world_size task_out_queue.put("get_mem_managers_ok") - import torch.distributed as dist - from datetime import timedelta - - dist.init_process_group( - "nccl", init_method=f"tcp://{nccl_ip}:{nccl_port}", rank=1, world_size=2, timeout=timedelta(seconds=60) - ) - task_out_queue.put("nccl_ok") + prefill_to_comm: Dict[int, PyNcclCommunicator] = {} while True: - move_tasks: List[KVMoveTask] = task_in_queue.get() - total_move_kv_len = sum([task.move_kv_len for task in move_tasks]) - try: - start = time.time() - if total_move_kv_len != 0: - cur_mem = mem_managers[device_index] - logger.info(f"trans start: {move_tasks[0].to_decode_log_info()}") - if 
kv_trans_use_p2p():
-                        cur_mem.receive_from_prefill_node_p2p(move_tasks, mem_managers, dp_size_in_node)
-                    else:
-                        cur_mem.receive_from_prefill_node(move_tasks, mem_managers, dp_size_in_node)
-                    logger.info(f"trans finished: {move_tasks[0].to_decode_log_info()} move len: {total_move_kv_len}")
-                torch.cuda.synchronize()
-                logger.info(f"trans cost time: {(time.time() - start)}, {move_tasks[0].to_decode_log_info()}")
-                task_out_queue.put("ok")
-            except BaseException as e:
-                logger.exception(str(e))
-                task_out_queue.put("fail")
-                raise e
-    except BaseException as e:
-        logger.exception(str(e))
-        sys.exit(-1)
-    return
+            task: Union[List, PDTransJoinInfo, PDTransLeaveInfo] = task_in_queue.get()
+            if isinstance(task, List):
+                _handle_kvmove_task(task, task_out_queue, mem_managers, prefill_to_comm, dp_size_in_node)
+            elif isinstance(task, PDTransJoinInfo):
+                _handle_prefill_join(task, task_out_queue, prefill_to_comm)
+            elif isinstance(task, PDTransLeaveInfo):
+                prefill_to_comm[task.prefill_id].destroy()
+                logger.info(f"destroy {task.prefill_id} nccl communicator.")
+            else:
+                logger.warning(f'unexpected task type: {task}')
+
+    except Exception as e:
+        logger.error(f"Fatal error happened in kv trans process: {e}")
+        raise
 
 
 def start_decode_trans_process(
     args,
-    device_index: int,
-    nccl_ip,
-    nccl_port,
     task_in_queue: mp.Queue,
     task_out_queue: mp.Queue,
     mem_queues: List[mp.Queue],
 ):
     proc = mp.Process(
-        target=_init_env, args=(args, device_index, nccl_ip, nccl_port, task_in_queue, task_out_queue, mem_queues)
+        target=_init_env, args=(args, task_in_queue, task_out_queue, mem_queues)
     )
     proc.start()
     assert proc.is_alive()
-    logger.info(f"decode trans kv process start, nccl_ip: {nccl_ip}, nccl_port: {nccl_port}")
+    logger.info("decode trans kv process start!")
     return proc
diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py
index 27b0fbb19..fbff30d20 100644
--- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py
+++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py
@@ -17,7 +17,7 @@
 from .prefill_infer_rpyc import PDPrefillInferRpcServer
 from lightllm.common.mem_manager import MemoryManager
 import torch.multiprocessing as mp
-from lightllm.server.pd_io_struct import KVMoveTask
+from lightllm.server.pd_io_struct import KVMoveTask, PDTransJoinInfo, PDTransLeaveInfo
 from lightllm.utils.net_utils import find_available_port
 from lightllm.utils.retry_utils import retry
 from rpyc.utils.classic import obtain
@@ -35,13 +35,10 @@
 
 @dataclass
 class TransProcessObj:
-    decode_node_id: str = None
+    decode_node_id: int = None
     rpyc_conn: object = None  # rpyc_con 的连接对象
-    process: mp.Process = None
     task_in_queue: mp.Queue = None
     task_out_queue: mp.Queue = None
-    nccl_ip: str = None
-    nccl_port: str = None
     device_index: str = None  # 使用的gpu序号
     manager: "PrefillKVMoveManager" = None
     has_error: bool = False
@@ -52,42 +49,38 @@ class TransProcessObj:
     latest_check_time: float = None
 
     def create(
-        self, decode_node_id: str, decode_node_ip: str, decode_node_rpyc_port: int, manager: "PrefillKVMoveManager"
+        self, decode_node_id: int, decode_node_ip: str, decode_node_rpyc_port: int, manager: "PrefillKVMoveManager"
     ):
         con = rpyc.connect(
             host=decode_node_ip, port=decode_node_rpyc_port, config={"allow_pickle": 
True}, keepalive=True ) - nccl_ip = manager.host_ip - nccl_port = find_available_port(manager.args.pd_p_allowed_port_min, manager.args.pd_p_allowed_port_max) - if nccl_port is None: - raise Exception("no pd nccl port can be used") - - from .prefill_trans_process import start_prefill_trans_process device_index = manager.get_next_device_index() # allocate the GPU used by the trans process - task_in_queue = mp.Queue() - task_out_queue = mp.Queue() - proc = start_prefill_trans_process( - manager.args, device_index, nccl_ip, nccl_port, task_in_queue, task_out_queue, manager.mem_queues - ) - assert task_out_queue.get(timeout=30) == "proc_start" - manager._put_mem_manager_to_mem_queue() - assert task_out_queue.get(timeout=60) == "get_mem_managers_ok" prefill_node_id = manager.args.pd_node_id + task_in_queue = manager.kv_trans_task_in_queue + task_out_queue = manager.kv_trans_task_out_queue + + task_in_queue.put(PDTransJoinInfo( + prefill_id=prefill_node_id, + prefill_device_id=device_index, + prefill_ip=manager.host_ip, + prefill_port=manager.kv_trans_port, + decode_id=decode_node_id, + decode_device_id=-1 + )) + # async call: have the decode node set up the process for nccl communication with the prefill node max_kv_trans_token_num = obtain( - con.root.build_trans_process(prefill_node_id, nccl_ip, nccl_port, manager.args.max_total_token_num) + con.root.build_trans_process(prefill_node_id, manager.host_ip, manager.kv_trans_port, manager.args.max_total_token_num) ) self.max_kv_trans_token_num = max_kv_trans_token_num assert task_out_queue.get(timeout=60) == "nccl_ok" self.decode_node_id = decode_node_id + self.prefill_node_id = prefill_node_id self.rpyc_conn = con - self.process = proc self.task_in_queue = task_in_queue self.task_out_queue = task_out_queue - self.nccl_port = nccl_port - self.nccl_ip = nccl_ip self.device_index = device_index self.manager = manager self.latest_check_time = time.time() @@ -114,13 +107,6 @@ def _get_request_tasks(self, datas: List[KVMoveTask]): break return ans_list - def check_trans_process(self, raise_exception=True): - process = psutil.Process(self.process.pid) - if not (process.is_running() and process.status() != psutil.STATUS_ZOMBIE): - self.set_has_error() - if raise_exception: - raise Exception(f"trans process: {self.process.pid} is dead") - return def check_connect(self, raise_exception=True): try: @@ -134,7 +120,6 @@ def check_connect(self, raise_exception=True): def timer_check_status(self, raise_exception=True): if time.time() - self.latest_check_time >= 2.0: self.latest_check_time = time.time() - self.check_trans_process(raise_exception=raise_exception) self.check_connect(raise_exception=raise_exception) if self.has_error: self.manager.remove_trans_obj(self.decode_node_id) @@ -249,6 +234,8 @@ def kv_trans_handle_loop(self): self.manager.put_to_release_task_queue(move_tasks) logger.error(f"trans kv thread, decode id {self.decode_node_id} device_index {self.device_index} thread quit") + self.task_in_queue.put(PDTransLeaveInfo( + decode_id=self.decode_node_id, prefill_id=self.prefill_node_id)) return def wait_thread_quit(self): @@ -302,12 +289,6 @@ def __del__(self): logger.error(f"trans obj deleted, decode node id {self.decode_node_id} device_index {self.device_index}") - # force close the connection and kill the transfer process - if self.process is not None: - logger.warning(f"prefill trans process {self.process.pid} is killed") - os.kill(self.process.pid, signal.SIGKILL) - pass - class PrefillKVMoveManager: def __init__(self, args, info_queue: mp.Queue, mem_queues: List[mp.Queue]): @@ -344,6 +325,19 @@ def __init__(self, args, info_queue: mp.Queue, mem_queues: List[mp.Queue]):
self.release_task_queue = TaskQueue(lambda datas: datas[0:KV_MOVE_MAX_NUM], fail_func=None) self.release_tasks_thread = threading.Thread(target=self.handle_release_task_loop, daemon=True) self.release_tasks_thread.start() + + # start a single kv trans process + self.kv_trans_task_in_queue = mp.Queue() + self.kv_trans_task_out_queue = mp.Queue() + from .prefill_trans_process import start_decode_trans_process + self.kv_trans_port = find_available_port(self.args.pd_p_allowed_port_min, self.args.pd_p_allowed_port_max) + self.kv_trans_process = start_decode_trans_process( + self.args, self.host_ip, self.kv_trans_port, self.kv_trans_task_in_queue, self.kv_trans_task_out_queue, self.mem_queues) + + assert self.kv_trans_task_out_queue.get(timeout=30) == "proc_start" + self._put_mem_manager_to_mem_queue() + assert self.kv_trans_task_out_queue.get(timeout=60) == "get_mem_managers_ok" + return def put_to_release_task_queue(self, task: Union[KVMoveTask, List[KVMoveTask]]): @@ -364,6 +358,13 @@ def handle_release_task_loop(self): self._remove_req_refs_from_prompt_cache(handle_list) return + def check_trans_process(self, raise_exception=True): + process = psutil.Process(self.kv_trans_process.pid) + if not (process.is_running() and process.status() != psutil.STATUS_ZOMBIE): + if raise_exception: + raise Exception(f"trans process: {self.kv_trans_process.pid} is dead") + return + def get_next_device_index(self): counts = [0 for _ in range(self.node_world_size)] for obj in self.node_id_to_trans_obj.values(): @@ -403,6 +404,7 @@ def remove_dead_trans_obj(self): def task_dispatcher_loop(self): try: + last_check_time = time.time() # fetch tasks and dispatch them to the handling queues of the corresponding GPUs while True: move_task: KVMoveTask = self.info_queue.get() @@ -415,6 +417,10 @@ def task_dispatcher_loop(self): finally: trans_obj = None + if time.time() - last_check_time > 10.0: + self.check_trans_process() + last_check_time = time.time() + except (BaseException, RuntimeError) as e: logger.exception(str(e)) raise e diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py index 1973aabac..b9b9b7242 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py @@ -3,97 +3,102 @@ import sys import inspect import torch.multiprocessing as mp -from typing import List, Dict +from torch.distributed import TCPStore +from typing import List, Dict, Union from lightllm.utils.log_utils import init_logger from lightllm.common.mem_manager import MemoryManager -from lightllm.server.pd_io_struct import KVMoveTask +from lightllm.server.pd_io_struct import KVMoveTask, PDTransJoinInfo, PDTransLeaveInfo from lightllm.utils.device_utils import kv_trans_use_p2p from lightllm.utils.graceful_utils import graceful_registry +from lightllm.distributed.pynccl import StatelessP2PProcessGroup, PyNcclCommunicator logger = init_logger(__name__) -# device_index indicates which GPU this transfer process uses for data transfer -# when the model runs multi-GPU inference, the kv to transfer must first be moved to the GPU -# specified by device_index before sending, because torch nccl can only operate on data on a single GPU +def _handle_kvmove_task(move_tasks: List[KVMoveTask], task_out_queue: mp.Queue, + mem_managers: List[MemoryManager], decode_to_comm: Dict[int, PyNcclCommunicator], + dp_size_in_node: int): + total_move_kv_len = sum([task.move_kv_len for task in 
move_tasks]) + try: + decode_id = move_tasks[0].decode_node.node_id + device_index = decode_to_comm[decode_id].device.index + torch.cuda.set_device(device_index) + start = time.time() + if total_move_kv_len != 0: + logger.info(f"trans start: {move_tasks[0].to_prefill_log_info()}") + cur_mem = mem_managers[device_index] + if kv_trans_use_p2p(): + cur_mem.send_to_decode_node_p2p(move_tasks, mem_managers, dp_size_in_node, decode_to_comm[decode_id]) + else: + cur_mem.send_to_decode_node(move_tasks, mem_managers, dp_size_in_node, decode_to_comm[decode_id]) + logger.info(f"trans finished: {move_tasks[0].to_prefill_log_info()} move len: {total_move_kv_len}") + torch.cuda.synchronize() + logger.info( + f"trans cost time: {(time.time() - start)}," + f"move_total_kv_len: {total_move_kv_len}, {move_tasks[0].to_prefill_log_info()}" + ) + task_out_queue.put("ok") + except BaseException as e: + logger.exception(str(e)) + task_out_queue.put("fail") + +def _handle_decode_join(node_info: PDTransJoinInfo, task_out_queue: mp.Queue, decode_to_comm: Dict[str, PyNcclCommunicator], store: TCPStore): + try: + group = StatelessP2PProcessGroup.create(node_info.prefill_id, node_info.decode_id, True, store) + comm = PyNcclCommunicator(group, node_info.prefill_device_id) + decode_to_comm[node_info.decode_id] = comm + logger.info(f"{node_info} kv trans connected!") + task_out_queue.put("nccl_ok") + except Exception as e: + logger.warning(f"error while connecting to decode node: {e}") + def _init_env( args, - device_index: int, - nccl_ip, - nccl_port, + store_ip, + store_port, task_in_queue: mp.Queue, task_out_queue: mp.Queue, - mem_queues: List[mp.Queue], -): - import os - - # os.environ["NCCL_DEBUG"] = "INFO" - os.environ["NCCL_MAX_NCHANNELS"] = "2" - os.environ["NCCL_NSOCKS_PER_CHANNEL"] = "1" - os.environ["NCCL_SOCKET_NTHREADS"] = "1" - torch.backends.cudnn.enabled = False - - dp_size_in_node = max(1, args.dp // args.nnodes) - node_world_size = args.tp // args.nnodes - + mem_queues: List[mp.Queue],): try: - # register graceful exit handling graceful_registry(inspect.currentframe().f_code.co_name) - torch.cuda.set_device(device_index) - + master_store = TCPStore(host_name=store_ip, port=store_port, is_master=True, use_libuv=True) + dp_size_in_node = max(1, args.dp // args.nnodes) + node_world_size = args.tp // args.nnodes task_out_queue.put("proc_start") mem_managers: List[MemoryManager] = [mem_queue.get(timeout=60) for mem_queue in mem_queues] assert len(mem_managers) == node_world_size task_out_queue.put("get_mem_managers_ok") - import torch.distributed as dist - from datetime import timedelta + decode_to_comm: Dict[int, PyNcclCommunicator] = {} - dist.init_process_group( - "nccl", init_method=f"tcp://{nccl_ip}:{nccl_port}", rank=0, world_size=2, timeout=timedelta(seconds=60) - ) - task_out_queue.put("nccl_ok") while True: - move_tasks: List[KVMoveTask] = task_in_queue.get() - total_move_kv_len = sum([task.move_kv_len for task in move_tasks]) - try: - start = time.time() - if total_move_kv_len != 0: - logger.info(f"trans start: {move_tasks[0].to_prefill_log_info()}") - cur_mem = mem_managers[device_index] - if kv_trans_use_p2p(): - cur_mem.send_to_decode_node_p2p(move_tasks, mem_managers, dp_size_in_node) - else: - cur_mem.send_to_decode_node(move_tasks, mem_managers, dp_size_in_node) - logger.info(f"trans finished: {move_tasks[0].to_prefill_log_info()} move len: {total_move_kv_len}") - torch.cuda.synchronize() - logger.info( - f"trans cost time: {(time.time() - start)}," - f"move_total_kv_len: {total_move_kv_len}, 
{move_tasks[0].to_prefill_log_info()}" - ) - task_out_queue.put("ok") - except BaseException as e: - logger.exception(str(e)) - task_out_queue.put("fail") - raise e - except BaseException as e: - logger.exception(str(e)) - sys.exit(-1) - return + task: Union[List, PDTransJoinInfo, PDTransLeaveInfo] = task_in_queue.get() + if isinstance(task, List): + _handle_kvmove_task(task, task_out_queue, mem_managers, decode_to_comm, dp_size_in_node) + elif isinstance(task, PDTransJoinInfo): + _handle_decode_join(task, task_out_queue, decode_to_comm, master_store) + elif isinstance(task, PDTransLeaveInfo): + decode_to_comm[task.decode_id].destroy() + logger.info(f"destroy {task.decode_id} nccl communicator.") + else: + logger.warning(f'unexpected task type: {task}') + + except Exception as e: + logger.error(f"Fatal error happened in kv trans process: {e}") + pass -def start_prefill_trans_process( +def start_decode_trans_process( args, - device_index: int, - nccl_ip, - nccl_port, + store_ip, + store_port, task_in_queue: mp.Queue, task_out_queue: mp.Queue, mem_queues: List[mp.Queue], ): proc = mp.Process( - target=_init_env, args=(args, device_index, nccl_ip, nccl_port, task_in_queue, task_out_queue, mem_queues) + target=_init_env, args=(args, store_ip, store_port, task_in_queue, task_out_queue, mem_queues) ) proc.start() assert proc.is_alive() - logger.info(f"trans kv process start, nccl_ip: {nccl_ip}, nccl_port: {nccl_port}") - return proc + logger.info(f"trans kv process started!") + return proc \ No newline at end of file From 9d6e4fde20a32be412b24195324fae137ef16acf Mon Sep 17 00:00:00 2001 From: Weichao Luo Date: Mon, 10 Mar 2025 13:02:01 +0800 Subject: [PATCH 02/20] fix name. --- .../pd_mode/prefill_node_impl/prefill_kv_move_manager.py | 4 ++-- .../pd_mode/prefill_node_impl/prefill_trans_process.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py index fbff30d20..e0c342654 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py @@ -329,9 +329,9 @@ def __init__(self, args, info_queue: mp.Queue, mem_queues: List[mp.Queue]): # start a single kv trans process self.kv_trans_task_in_queue = mp.Queue() self.kv_trans_task_out_queue = mp.Queue() - from .prefill_trans_process import start_decode_trans_process + from .prefill_trans_process import start_prefill_trans_process self.kv_trans_port = find_available_port(self.args.pd_p_allowed_port_min, self.args.pd_p_allowed_port_max) - self.kv_trans_process = start_decode_trans_process( + self.kv_trans_process = start_prefill_trans_process( self.args, self.host_ip, self.kv_trans_port, self.kv_trans_task_in_queue, self.kv_trans_task_out_queue, self.mem_queues) assert self.kv_trans_task_out_queue.get(timeout=30) == "proc_start" diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py index b9b9b7242..b6fa0f032 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py
+++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py @@ -87,7 +87,7 @@ def _init_env( pass -def start_decode_trans_process( +def start_prefill_trans_process( args, store_ip, store_port, From 083b9d6b7b7193a3b66e2dfdf600ad9c646aaab7 Mon Sep 17 00:00:00 2001 From: Weichao Luo Date: Mon, 10 Mar 2025 13:59:13 +0800 Subject: [PATCH 03/20] fix style. --- lightllm/common/deepseek2_mem_manager.py | 28 ++- lightllm/common/mem_manager.py | 24 ++- lightllm/distributed/pynccl.py | 138 +++++++----- lightllm/distributed/pynccl_wrapper.py | 201 ++++++++++-------- lightllm/server/pd_io_struct.py | 3 + .../decode_kv_move_manager.py | 33 ++- .../decode_node_impl/decode_trans_process.py | 50 +++-- .../prefill_kv_move_manager.py | 37 ++-- .../prefill_trans_process.py | 30 ++- 9 files changed, 318 insertions(+), 226 deletions(-) diff --git a/lightllm/common/deepseek2_mem_manager.py b/lightllm/common/deepseek2_mem_manager.py index 94dec293e..ddf2478df 100644 --- a/lightllm/common/deepseek2_mem_manager.py +++ b/lightllm/common/deepseek2_mem_manager.py @@ -36,8 +36,11 @@ def alloc_kv_move_buffer(self, max_req_total_len): return def send_to_decode_node( - self, move_tasks: List[KVMoveTask], mem_managers: List["Deepseek2MemoryManager"], dp_size_in_node: int, - nccl_comm: PyNcclCommunicator + self, + move_tasks: List[KVMoveTask], + mem_managers: List["Deepseek2MemoryManager"], + dp_size_in_node: int, + nccl_comm: PyNcclCommunicator, ): assert dp_size_in_node == 1 @@ -63,8 +66,11 @@ def _get_kv_move_data(self, token_indexes: List[int], layer_index: int): return move_buffer def receive_from_prefill_node( - self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int, - nccl_comm: PyNcclCommunicator + self, + move_tasks: List[KVMoveTask], + mem_managers: List["MemoryManager"], + dp_size_in_node: int, + nccl_comm: PyNcclCommunicator, ): assert dp_size_in_node == 1 @@ -96,8 +102,11 @@ def _write_kv_move_data(self, token_indexes: torch.Tensor, buffer_tensor: torch. 
return def send_to_decode_node_p2p( - self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int, - nccl_comm: PyNcclCommunicator + self, + move_tasks: List[KVMoveTask], + mem_managers: List["MemoryManager"], + dp_size_in_node: int, + nccl_comm: PyNcclCommunicator, ): """ Implementation that uses the p2p triton kernel for data copy and transfer. """ @@ -149,8 +158,11 @@ def _get_kv_move_data_p2p( return move_buffer def receive_from_prefill_node_p2p( - self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int, - nccl_comm: PyNcclCommunicator + self, + move_tasks: List[KVMoveTask], + mem_managers: List["MemoryManager"], + dp_size_in_node: int, + nccl_comm: PyNcclCommunicator, ): if not hasattr(self, "mem_ptrs_dict"): self.mem_ptrs_dict = {} diff --git a/lightllm/common/mem_manager.py b/lightllm/common/mem_manager.py index 8b3e1b73b..aae7112ff 100755 --- a/lightllm/common/mem_manager.py +++ b/lightllm/common/mem_manager.py @@ -92,8 +92,11 @@ def alloc_kv_move_buffer(self, max_req_total_len): return def send_to_decode_node( - self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int, - nccl_comm: PyNcclCommunicator + self, + move_tasks: List[KVMoveTask], + mem_managers: List["MemoryManager"], + dp_size_in_node: int, + nccl_comm: PyNcclCommunicator, ): assert dp_size_in_node == 1 @@ -129,7 +132,10 @@ def _get_kv_move_data(self, token_indexes: List[int], layer_index: int): return move_buffer def receive_from_prefill_node( - self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int, + self, + move_tasks: List[KVMoveTask], + mem_managers: List["MemoryManager"], + dp_size_in_node: int, nccl_comm: PyNcclCommunicator, ): assert dp_size_in_node == 1 @@ -163,8 +169,11 @@ def _write_kv_move_data(self, token_indexes: torch.Tensor, buffer_tensor: torch.
return def send_to_decode_node_p2p( - self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int, - nccl_comm: PyNcclCommunicator + self, + move_tasks: List[KVMoveTask], + mem_managers: List["MemoryManager"], + dp_size_in_node: int, + nccl_comm: PyNcclCommunicator, ): """ Implementation that uses the p2p triton kernel for data copy and transfer. """ @@ -195,7 +204,10 @@ def _get_kv_move_data_p2p(self, token_indexes: torch.Tensor, layer_index: int, k return move_buffer def receive_from_prefill_node_p2p( - self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int, + self, + move_tasks: List[KVMoveTask], + mem_managers: List["MemoryManager"], + dp_size_in_node: int, nccl_comm: PyNcclCommunicator, ): assert dp_size_in_node == 1 diff --git a/lightllm/distributed/pynccl.py b/lightllm/distributed/pynccl.py index 9a01dd116..3637b04dd 100644 --- a/lightllm/distributed/pynccl.py +++ b/lightllm/distributed/pynccl.py @@ -33,19 +33,27 @@ from torch.distributed import ProcessGroup, ReduceOp, TCPStore from lightllm.distributed.pynccl_wrapper import ( - NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum, - ncclRedOpTypeEnum, ncclUniqueId) + NCCLLibrary, + buffer_type, + cudaStream_t, + ncclComm_t, + ncclDataTypeEnum, + ncclRedOpTypeEnum, + ncclUniqueId, ) logger = logging.getLogger(__name__) _current_stream = None + def current_stream() -> torch.cuda.Stream: global _current_stream if _current_stream is None: _current_stream = torch.cuda.current_stream() return _current_stream + @dataclasses.dataclass class StatelessP2PProcessGroup: """A dataclass to hold a metadata store, and the rank, world_size of the @@ -94,18 +102,13 @@ def expire_data(self): def recv_obj(self) -> Any: """Receive an object from a source rank.""" - obj = pickle.loads( - self.store.get( - f"send_to/{self.dest_id}/{self.recv_src_counter}")) + obj = pickle.loads(self.store.get(f"send_to/{self.dest_id}/{self.recv_src_counter}")) self.recv_src_counter += 1 return obj @staticmethod def create( - src_id: int, - dest_id: int, - is_server: bool, - store: torch._C._distributed_c10d.Store + src_id: int, dest_id: int, is_server: bool, store: torch._C._distributed_c10d.Store ) -> "StatelessP2PProcessGroup": """A replacement for `torch.distributed.init_process_group` that does not pollute the global state. @@ -121,12 +124,11 @@ def create( used for exchanging metadata. With this function, process A and process B can call `StatelessProcessGroup.create` to form a group, and then process A, B, C, and D can call `StatelessProcessGroup.create` to form another group. - """ # noqa + """ # noqa return StatelessP2PProcessGroup(src_id=src_id, dest_id=dest_id, is_server=is_server, store=store) class PyNcclCommunicator: - def __init__( self, group: Union[ProcessGroup, StatelessP2PProcessGroup], @@ -146,8 +148,9 @@ def __init__( """ if not isinstance(group, StatelessP2PProcessGroup): assert dist.is_initialized() - assert dist.get_backend(group) != dist.Backend.NCCL, ( - "PyNcclCommunicator should be attached to a non-NCCL group.") + assert ( + dist.get_backend(group) != dist.Backend.NCCL + ), "PyNcclCommunicator should be attached to a non-NCCL group."
# note: this rank is the rank in the group self.rank = dist.get_rank(group) self.world_size = dist.get_world_size(group) @@ -207,8 +210,7 @@ def __init__( # `torch.cuda.device` is a context manager that changes the # current cuda device to the specified one with torch.cuda.device(device): - self.comm: ncclComm_t = self.nccl.ncclCommInitRank( - self.world_size, self.unique_id, self.rank) + self.comm: ncclComm_t = self.nccl.ncclCommInitRank(self.world_size, self.unique_id, self.rank) stream = current_stream() # A small all_reduce for warmup. @@ -220,10 +222,7 @@ def __init__( def destroy(self): self.nccl.ncclCommDestroy(self.comm) - def all_reduce(self, - in_tensor: torch.Tensor, - op: ReduceOp = ReduceOp.SUM, - stream=None) -> torch.Tensor: + def all_reduce(self, in_tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM, stream=None) -> torch.Tensor: if self.disabled: return None # nccl communicator created on a specific device @@ -231,24 +230,25 @@ def all_reduce(self, # otherwise it will cause "illegal memory access" assert in_tensor.device == self.device, ( f"this nccl communicator is created to work on {self.device}, " - f"but the input tensor is on {in_tensor.device}") + f"but the input tensor is on {in_tensor.device}" + ) out_tensor = torch.empty_like(in_tensor) if stream is None: stream = current_stream() - self.nccl.ncclAllReduce(buffer_type(in_tensor.data_ptr()), - buffer_type(out_tensor.data_ptr()), - in_tensor.numel(), - ncclDataTypeEnum.from_torch(in_tensor.dtype), - ncclRedOpTypeEnum.from_torch(op), self.comm, - cudaStream_t(stream.cuda_stream)) + self.nccl.ncclAllReduce( + buffer_type(in_tensor.data_ptr()), + buffer_type(out_tensor.data_ptr()), + in_tensor.numel(), + ncclDataTypeEnum.from_torch(in_tensor.dtype), + ncclRedOpTypeEnum.from_torch(op), + self.comm, + cudaStream_t(stream.cuda_stream), + ) return out_tensor - def all_gather(self, - output_tensor: torch.Tensor, - input_tensor: torch.Tensor, - stream=None): + def all_gather(self, output_tensor: torch.Tensor, input_tensor: torch.Tensor, stream=None): if self.disabled: return # nccl communicator created on a specific device @@ -256,20 +256,22 @@ def all_gather(self, # otherwise it will cause "illegal memory access" assert input_tensor.device == self.device, ( f"this nccl communicator is created to work on {self.device}, " - f"but the input tensor is on {input_tensor.device}") + f"but the input tensor is on {input_tensor.device}" + ) if stream is None: stream = current_stream() self.nccl.ncclAllGather( buffer_type(input_tensor.data_ptr()), - buffer_type(output_tensor.data_ptr()), input_tensor.numel(), - ncclDataTypeEnum.from_torch(input_tensor.dtype), self.comm, - cudaStream_t(stream.cuda_stream)) - - def reduce_scatter(self, - output_tensor: torch.Tensor, - input_tensor: torch.Tensor, - op: ReduceOp = ReduceOp.SUM, - stream=None): + buffer_type(output_tensor.data_ptr()), + input_tensor.numel(), + ncclDataTypeEnum.from_torch(input_tensor.dtype), + self.comm, + cudaStream_t(stream.cuda_stream), + ) + + def reduce_scatter( + self, output_tensor: torch.Tensor, input_tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM, stream=None + ): if self.disabled: return # nccl communicator created on a specific device @@ -277,46 +279,63 @@ def reduce_scatter(self, # otherwise it will cause "illegal memory access" assert input_tensor.device == self.device, ( f"this nccl communicator is created to work on {self.device}, " - f"but the input tensor is on {input_tensor.device}") + f"but the input tensor is on {input_tensor.device}" + ) if stream 
is None: stream = current_stream() self.nccl.ncclReduceScatter( buffer_type(input_tensor.data_ptr()), - buffer_type(output_tensor.data_ptr()), output_tensor.numel(), + buffer_type(output_tensor.data_ptr()), + output_tensor.numel(), ncclDataTypeEnum.from_torch(input_tensor.dtype), - ncclRedOpTypeEnum.from_torch(op), self.comm, - cudaStream_t(stream.cuda_stream)) + ncclRedOpTypeEnum.from_torch(op), + self.comm, + cudaStream_t(stream.cuda_stream), + ) def send(self, tensor: torch.Tensor, dst: int, stream=None): if self.disabled: return assert tensor.device == self.device, ( f"this nccl communicator is created to work on {self.device}, " - f"but the input tensor is on {tensor.device}") + f"but the input tensor is on {tensor.device}" + ) if stream is None: stream = current_stream() - self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(), - ncclDataTypeEnum.from_torch(tensor.dtype), dst, - self.comm, cudaStream_t(stream.cuda_stream)) + self.nccl.ncclSend( + buffer_type(tensor.data_ptr()), + tensor.numel(), + ncclDataTypeEnum.from_torch(tensor.dtype), + dst, + self.comm, + cudaStream_t(stream.cuda_stream), + ) def recv(self, tensor: torch.Tensor, src: int, stream=None): if self.disabled: return assert tensor.device == self.device, ( f"this nccl communicator is created to work on {self.device}, " - f"but the input tensor is on {tensor.device}") + f"but the input tensor is on {tensor.device}" + ) if stream is None: stream = current_stream() - self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(), - ncclDataTypeEnum.from_torch(tensor.dtype), src, - self.comm, cudaStream_t(stream.cuda_stream)) + self.nccl.ncclRecv( + buffer_type(tensor.data_ptr()), + tensor.numel(), + ncclDataTypeEnum.from_torch(tensor.dtype), + src, + self.comm, + cudaStream_t(stream.cuda_stream), + ) def broadcast(self, tensor: torch.Tensor, src: int, stream=None): if self.disabled: return assert tensor.device == self.device, ( f"this nccl communicator is created to work on {self.device}, " - f"but the input tensor is on {tensor.device}") + f"but the input tensor is on {tensor.device}" + ) if stream is None: stream = current_stream() if src == self.rank: @@ -326,7 +345,12 @@ def broadcast(self, tensor: torch.Tensor, src: int, stream=None): else: sendbuff = buffer_type() recvbuff = buffer_type(tensor.data_ptr()) - self.nccl.ncclBroadcast(sendbuff, recvbuff, tensor.numel(), - ncclDataTypeEnum.from_torch(tensor.dtype), src, - self.comm, cudaStream_t(stream.cuda_stream)) - + self.nccl.ncclBroadcast( + sendbuff, + recvbuff, + tensor.numel(), + ncclDataTypeEnum.from_torch(tensor.dtype), + src, + self.comm, + cudaStream_t(stream.cuda_stream), + ) diff --git a/lightllm/distributed/pynccl_wrapper.py b/lightllm/distributed/pynccl_wrapper.py index d35ec8e3a..344689d96 100644 --- a/lightllm/distributed/pynccl_wrapper.py +++ b/lightllm/distributed/pynccl_wrapper.py @@ -64,9 +64,7 @@ def find_nccl_library() -> str: # manually load the nccl library if so_file: - logger.info( - "Found nccl from environment variable VLLM_NCCL_SO_PATH=%s", - so_file) + logger.info("Found nccl from environment variable VLLM_NCCL_SO_PATH=%s", so_file) else: if torch.version.cuda is not None: so_file = "libnccl.so.2" @@ -173,77 +171,74 @@ class NCCLLibrary: # const char* ncclGetErrorString(ncclResult_t result) Function("ncclGetErrorString", ctypes.c_char_p, [ncclResult_t]), # ncclResult_t ncclGetVersion(int *version); - Function("ncclGetVersion", ncclResult_t, - [ctypes.POINTER(ctypes.c_int)]), + Function("ncclGetVersion", ncclResult_t, 
[ctypes.POINTER(ctypes.c_int)]), # ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId); - Function("ncclGetUniqueId", ncclResult_t, - [ctypes.POINTER(ncclUniqueId)]), + Function("ncclGetUniqueId", ncclResult_t, [ctypes.POINTER(ncclUniqueId)]), # ncclResult_t ncclCommInitRank( # ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank); # note that ncclComm_t is a pointer type, so the first argument # is a pointer to a pointer - Function("ncclCommInitRank", ncclResult_t, [ - ctypes.POINTER(ncclComm_t), ctypes.c_int, ncclUniqueId, - ctypes.c_int - ]), + Function( + "ncclCommInitRank", ncclResult_t, [ctypes.POINTER(ncclComm_t), ctypes.c_int, ncclUniqueId, ctypes.c_int] + ), # ncclResult_t ncclAllReduce( # const void* sendbuff, void* recvbuff, size_t count, # ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, # cudaStream_t stream); # note that cudaStream_t is a pointer type, so the last argument # is a pointer - Function("ncclAllReduce", ncclResult_t, [ - buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, - ncclRedOp_t, ncclComm_t, cudaStream_t - ]), - + Function( + "ncclAllReduce", + ncclResult_t, + [buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, ncclRedOp_t, ncclComm_t, cudaStream_t], + ), # ncclResult_t ncclAllGather( # const void* sendbuff, void* recvbuff, size_t count, # ncclDataType_t datatype, ncclComm_t comm, # cudaStream_t stream); # note that cudaStream_t is a pointer type, so the last argument # is a pointer - Function("ncclAllGather", ncclResult_t, [ - buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, - ncclComm_t, cudaStream_t - ]), - + Function( + "ncclAllGather", + ncclResult_t, + [buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, ncclComm_t, cudaStream_t], + ), # ncclResult_t ncclReduceScatter( # const void* sendbuff, void* recvbuff, size_t count, # ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, # cudaStream_t stream); # note that cudaStream_t is a pointer type, so the last argument # is a pointer - Function("ncclReduceScatter", ncclResult_t, [ - buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, - ncclRedOp_t, ncclComm_t, cudaStream_t - ]), - + Function( + "ncclReduceScatter", + ncclResult_t, + [buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, ncclRedOp_t, ncclComm_t, cudaStream_t], + ), # ncclResult_t ncclSend( # const void* sendbuff, size_t count, ncclDataType_t datatype, # int dest, ncclComm_t comm, cudaStream_t stream); - Function("ncclSend", ncclResult_t, [ - buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int, - ncclComm_t, cudaStream_t - ]), - + Function( + "ncclSend", + ncclResult_t, + [buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int, ncclComm_t, cudaStream_t], + ), # ncclResult_t ncclRecv( # void* recvbuff, size_t count, ncclDataType_t datatype, # int src, ncclComm_t comm, cudaStream_t stream); - Function("ncclRecv", ncclResult_t, [ - buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int, - ncclComm_t, cudaStream_t - ]), - + Function( + "ncclRecv", + ncclResult_t, + [buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int, ncclComm_t, cudaStream_t], + ), # ncclResult_t ncclBroadcast( # const void* sendbuff, void* recvbuff, size_t count, # ncclDataType_t datatype, int root, ncclComm_t comm, # cudaStream_t stream); - Function("ncclBroadcast", ncclResult_t, [ - buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, - ctypes.c_int, ncclComm_t, cudaStream_t - ]), - + Function( + "ncclBroadcast", + ncclResult_t, + [buffer_type, buffer_type, ctypes.c_size_t, 
ncclDataType_t, ctypes.c_int, ncclComm_t, cudaStream_t], + ), # be cautious! this is a collective call, it will block until all # processes in the communicator have called this function. # because Python object destruction can happen in random order, @@ -277,8 +272,10 @@ def __init__(self, so_file: Optional[str] = None): "or it does not support the current platform %s. " "If you already have the library, please set the " "environment variable VLLM_NCCL_SO_PATH" - " to point to the correct nccl library path.", so_file, - platform.platform()) + " to point to the correct nccl library path.", + so_file, + platform.platform(), + ) raise e if so_file not in NCCLLibrary.path_to_dict_mapping: @@ -311,80 +308,100 @@ def ncclGetVersion(self) -> str: def ncclGetUniqueId(self) -> ncclUniqueId: unique_id = ncclUniqueId() - self.NCCL_CHECK(self._funcs["ncclGetUniqueId"]( - ctypes.byref(unique_id))) + self.NCCL_CHECK(self._funcs["ncclGetUniqueId"](ctypes.byref(unique_id))) return unique_id - def ncclCommInitRank(self, world_size: int, unique_id: ncclUniqueId, - rank: int) -> ncclComm_t: + def ncclCommInitRank(self, world_size: int, unique_id: ncclUniqueId, rank: int) -> ncclComm_t: comm = ncclComm_t() - self.NCCL_CHECK(self._funcs["ncclCommInitRank"](ctypes.byref(comm), - world_size, unique_id, - rank)) + self.NCCL_CHECK(self._funcs["ncclCommInitRank"](ctypes.byref(comm), world_size, unique_id, rank)) return comm - def ncclAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type, - count: int, datatype: int, op: int, comm: ncclComm_t, - stream: cudaStream_t) -> None: + def ncclAllReduce( + self, + sendbuff: buffer_type, + recvbuff: buffer_type, + count: int, + datatype: int, + op: int, + comm: ncclComm_t, + stream: cudaStream_t, + ) -> None: # `datatype` actually should be `ncclDataType_t` # and `op` should be `ncclRedOp_t` # both are aliases of `ctypes.c_int` # when we pass int to a function, it will be converted to `ctypes.c_int` # by ctypes automatically - self.NCCL_CHECK(self._funcs["ncclAllReduce"](sendbuff, recvbuff, count, - datatype, op, comm, - stream)) - - def ncclReduceScatter(self, sendbuff: buffer_type, recvbuff: buffer_type, - count: int, datatype: int, op: int, comm: ncclComm_t, - stream: cudaStream_t) -> None: + self.NCCL_CHECK(self._funcs["ncclAllReduce"](sendbuff, recvbuff, count, datatype, op, comm, stream)) + + def ncclReduceScatter( + self, + sendbuff: buffer_type, + recvbuff: buffer_type, + count: int, + datatype: int, + op: int, + comm: ncclComm_t, + stream: cudaStream_t, + ) -> None: # `datatype` actually should be `ncclDataType_t` # and `op` should be `ncclRedOp_t` # both are aliases of `ctypes.c_int` # when we pass int to a function, it will be converted to `ctypes.c_int` # by ctypes automatically - self.NCCL_CHECK(self._funcs["ncclReduceScatter"](sendbuff, recvbuff, - count, datatype, op, - comm, stream)) - - def ncclAllGather(self, sendbuff: buffer_type, recvbuff: buffer_type, - count: int, datatype: int, comm: ncclComm_t, - stream: cudaStream_t) -> None: + self.NCCL_CHECK(self._funcs["ncclReduceScatter"](sendbuff, recvbuff, count, datatype, op, comm, stream)) + + def ncclAllGather( + self, + sendbuff: buffer_type, + recvbuff: buffer_type, + count: int, + datatype: int, + comm: ncclComm_t, + stream: cudaStream_t, + ) -> None: # `datatype` actually should be `ncclDataType_t` # which is an aliases of `ctypes.c_int` # when we pass int to a function, it will be converted to `ctypes.c_int` # by ctypes automatically - self.NCCL_CHECK(self._funcs["ncclAllGather"](sendbuff, 
recvbuff, count, - datatype, comm, stream)) - - def ncclSend(self, sendbuff: buffer_type, count: int, datatype: int, - dest: int, comm: ncclComm_t, stream: cudaStream_t) -> None: - self.NCCL_CHECK(self._funcs["ncclSend"](sendbuff, count, datatype, - dest, comm, stream)) - - def ncclRecv(self, recvbuff: buffer_type, count: int, datatype: int, - src: int, comm: ncclComm_t, stream: cudaStream_t) -> None: - self.NCCL_CHECK(self._funcs["ncclRecv"](recvbuff, count, datatype, src, - comm, stream)) - - def ncclBroadcast(self, sendbuff: buffer_type, recvbuff: buffer_type, - count: int, datatype: int, root: int, comm: ncclComm_t, - stream: cudaStream_t) -> None: - self.NCCL_CHECK(self._funcs["ncclBroadcast"](sendbuff, recvbuff, count, - datatype, root, comm, - stream)) + self.NCCL_CHECK(self._funcs["ncclAllGather"](sendbuff, recvbuff, count, datatype, comm, stream)) + + def ncclSend( + self, sendbuff: buffer_type, count: int, datatype: int, dest: int, comm: ncclComm_t, stream: cudaStream_t + ) -> None: + self.NCCL_CHECK(self._funcs["ncclSend"](sendbuff, count, datatype, dest, comm, stream)) + + def ncclRecv( + self, recvbuff: buffer_type, count: int, datatype: int, src: int, comm: ncclComm_t, stream: cudaStream_t + ) -> None: + self.NCCL_CHECK(self._funcs["ncclRecv"](recvbuff, count, datatype, src, comm, stream)) + + def ncclBroadcast( + self, + sendbuff: buffer_type, + recvbuff: buffer_type, + count: int, + datatype: int, + root: int, + comm: ncclComm_t, + stream: cudaStream_t, + ) -> None: + self.NCCL_CHECK(self._funcs["ncclBroadcast"](sendbuff, recvbuff, count, datatype, root, comm, stream)) def ncclCommDestroy(self, comm: ncclComm_t) -> None: self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm)) __all__ = [ - "NCCLLibrary", "ncclDataTypeEnum", "ncclRedOpTypeEnum", "ncclUniqueId", - "ncclComm_t", "cudaStream_t", "buffer_type" + "NCCLLibrary", + "ncclDataTypeEnum", + "ncclRedOpTypeEnum", + "ncclUniqueId", + "ncclComm_t", + "cudaStream_t", + "buffer_type", ] - def test_ncclGetUniqueId(): lib = NCCLLibrary() unique_id = lib.ncclGetUniqueId() @@ -399,7 +416,9 @@ def test_ncclGetUniqueId(): # as long as the function doesn't raise an exception, we're good assert unique_id is not None -if __name__ == '__main__': - import torch; + +if __name__ == "__main__": + import torch + torch.cuda.set_device(0) test_ncclGetUniqueId() diff --git a/lightllm/server/pd_io_struct.py b/lightllm/server/pd_io_struct.py index 222cd5887..d63a38e99 100644 --- a/lightllm/server/pd_io_struct.py +++ b/lightllm/server/pd_io_struct.py @@ -74,6 +74,7 @@ class DecodeNodeInfo: rpyc_port: str max_new_tokens: int + @dataclass class PDTransJoinInfo: decode_id: int @@ -83,11 +84,13 @@ class PDTransJoinInfo: prefill_ip: str prefill_port: int + @dataclass class PDTransLeaveInfo: decode_id: int prefill_id: int + @dataclass class KVMoveTask: group_request_id: int diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_kv_move_manager.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_kv_move_manager.py index 0c46b6dd5..8af8952fa 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_kv_move_manager.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_kv_move_manager.py @@ -47,23 +47,23 @@ class TransProcessObj: put_to_radix_thread: threading.Thread = None latest_check_time: float = None - def create( - self, prefill_node_id: str, 
prefill_ip: str, prefill_port: int, manager: "DecodeKVMoveManager" - ): + def create(self, prefill_node_id: str, prefill_ip: str, prefill_port: int, manager: "DecodeKVMoveManager"): device_index = manager.get_next_device_index() decode_node_id = manager.args.pd_node_id task_in_queue = manager.kv_trans_task_in_queue task_out_queue = manager.kv_trans_task_out_queue - task_in_queue.put(PDTransJoinInfo( - prefill_id=prefill_node_id, - prefill_device_id=-1, - prefill_ip=prefill_ip, - prefill_port=prefill_port, - decode_id=decode_node_id, - decode_device_id=device_index, - )) + task_in_queue.put( + PDTransJoinInfo( + prefill_id=prefill_node_id, + prefill_device_id=-1, + prefill_ip=prefill_ip, + prefill_port=prefill_port, + decode_id=decode_node_id, + decode_device_id=device_index, + ) + ) assert task_out_queue.get(timeout=60) == "nccl_ok" self.prefill_node_id = prefill_node_id @@ -136,10 +136,7 @@ def kv_move_loop(self): self.manager.put_to_fail_release_task_queue(move_tasks) logger.error(f"{func_name} prefill id {self.prefill_node_id} device_index {self.device_index} thread quit") - self.task_in_queue.put(PDTransLeaveInfo( - decode_id=self.decode_node_id, - prefill_id=self.prefill_node_id - )) + self.task_in_queue.put(PDTransLeaveInfo(decode_id=self.decode_node_id, prefill_id=self.prefill_node_id)) return def put_to_radix_loop(self): @@ -269,12 +266,14 @@ def __init__(self, args, info_queue: mp.Queue, mem_queues: List[mp.Queue]): # each GPU needs a lock so that only one trans obj at a time can run transfer tasks on that GPU. self.device_locks = [threading.Lock() for _ in range(self.node_world_size)] - # start a single kv trans process + # start a single kv trans process self.kv_trans_task_in_queue = mp.Queue() self.kv_trans_task_out_queue = mp.Queue() from .decode_trans_process import start_decode_trans_process + self.kv_trans_process = start_decode_trans_process( - self.args, self.kv_trans_task_in_queue, self.kv_trans_task_out_queue, self.mem_queues) + self.args, self.kv_trans_task_in_queue, self.kv_trans_task_out_queue, self.mem_queues + ) assert self.kv_trans_task_out_queue.get(timeout=30) == "proc_start" self._put_mem_manager_to_mem_queue() diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py index 010074b10..7f3f9c676 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py @@ -15,9 +15,13 @@ logger = init_logger(__name__) -def _handle_kvmove_task(move_tasks: List[KVMoveTask], task_out_queue: mp.Queue, - mem_managers: List[MemoryManager], prefill_to_comm: Dict[int, PyNcclCommunicator], - dp_size_in_node: int): +def _handle_kvmove_task( + move_tasks: List[KVMoveTask], + task_out_queue: mp.Queue, + mem_managers: List[MemoryManager], + prefill_to_comm: Dict[int, PyNcclCommunicator], + dp_size_in_node: int, +): total_move_kv_len = sum([task.move_kv_len for task in move_tasks]) try: prefill_id = move_tasks[0].prefill_node_id @@ -27,9 +31,13 @@ def _handle_kvmove_task(move_tasks: List[KVMoveTask], task_out_queue: mp.Queue, cur_mem = mem_managers[device_index] logger.info(f"trans start: {move_tasks[0].to_decode_log_info()}") if kv_trans_use_p2p(): - cur_mem.receive_from_prefill_node_p2p(move_tasks, mem_managers, dp_size_in_node, prefill_to_comm[prefill_id]) + 
cur_mem.receive_from_prefill_node_p2p( + move_tasks, mem_managers, dp_size_in_node, prefill_to_comm[prefill_id] + ) else: - cur_mem.receive_from_prefill_node(move_tasks, mem_managers, dp_size_in_node, prefill_to_comm[prefill_id]) + cur_mem.receive_from_prefill_node( + move_tasks, mem_managers, dp_size_in_node, prefill_to_comm[prefill_id] + ) logger.info(f"trans finished: {move_tasks[0].to_decode_log_info()} move len: {total_move_kv_len}") torch.cuda.synchronize() logger.info(f"trans cost time: {(time.time() - start)}, {move_tasks[0].to_decode_log_info()}") @@ -39,26 +47,26 @@ def _handle_kvmove_task(move_tasks: List[KVMoveTask], task_out_queue: mp.Queue, task_out_queue.put("fail") raise e -def _handle_prefill_join(node_info: PDTransJoinInfo, task_out_queue: mp.Queue, prefill_to_comm: Dict[int, PyNcclCommunicator]): + +def _handle_prefill_join( + node_info: PDTransJoinInfo, task_out_queue: mp.Queue, prefill_to_comm: Dict[int, PyNcclCommunicator] +): try: - store_client = TCPStore(host_name=node_info.prefill_ip, port=node_info.prefill_port, is_master=False, use_libuv=False) + store_client = TCPStore( + host_name=node_info.prefill_ip, port=node_info.prefill_port, is_master=False, use_libuv=False + ) group = StatelessP2PProcessGroup.create( - src_id=node_info.prefill_id, - dest_id=node_info.decode_id, - is_server=False, - store=store_client) + src_id=node_info.prefill_id, dest_id=node_info.decode_id, is_server=False, store=store_client ) comm = PyNcclCommunicator(group, node_info.decode_device_id) prefill_to_comm[node_info.prefill_id] = comm logger.info(f"{node_info} kv trans connected") - task_out_queue.put('nccl_ok') + task_out_queue.put("nccl_ok") except Exception as e: logger.warning(f"error while connecting to prefill node: {e}") -def _init_env( - args, - task_in_queue: mp.Queue, - task_out_queue: mp.Queue, - mem_queues: List[mp.Queue]): + +def _init_env(args, task_in_queue: mp.Queue, task_out_queue: mp.Queue, mem_queues: List[mp.Queue]): dp_size_in_node = max(1, args.dp // args.nnodes) node_world_size = args.tp // args.nnodes @@ -80,7 +88,7 @@ def _init_env( prefill_to_comm[task.prefill_id].destroy() logger.info(f"destroy {task.prefill_id} nccl communicator.") else: - logger.warning(f'unexpected task type: {task}') + logger.warning(f"unexpected task type: {task}") except Exception as e: logger.error(f"Fatal error happened in kv trans process: {e}") @@ -93,10 +101,8 @@ def _init_env( task_out_queue: mp.Queue, mem_queues: List[mp.Queue], ): - proc = mp.Process( - target=_init_env, args=(args, task_in_queue, task_out_queue, mem_queues) - ) + proc = mp.Process(target=_init_env, args=(args, task_in_queue, task_out_queue, mem_queues)) proc.start() assert proc.is_alive() - logger.info(f"decode trans kv process start!") + logger.info("decode trans kv process started!") return proc diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py index e0c342654..5c4e946cf 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py @@ -60,18 +60,22 @@ def create( task_in_queue = manager.kv_trans_task_in_queue task_out_queue = manager.kv_trans_task_out_queue - task_in_queue.put(PDTransJoinInfo( - 
prefill_id=prefill_node_id, - prefill_device_id=device_index, - prefill_ip=manager.host_ip, - prefill_port=manager.kv_trans_port, - decode_id=decode_node_id, - decode_device_id=-1 - )) + task_in_queue.put( + PDTransJoinInfo( + prefill_id=prefill_node_id, + prefill_device_id=device_index, + prefill_ip=manager.host_ip, + prefill_port=manager.kv_trans_port, + decode_id=decode_node_id, + decode_device_id=-1, + ) + ) # async call: have the decode node set up the process for nccl communication with the prefill node max_kv_trans_token_num = obtain( - con.root.build_trans_process(prefill_node_id, manager.host_ip, manager.kv_trans_port, manager.args.max_total_token_num) + con.root.build_trans_process( + prefill_node_id, manager.host_ip, manager.kv_trans_port, manager.args.max_total_token_num + ) ) self.max_kv_trans_token_num = max_kv_trans_token_num assert task_out_queue.get(timeout=60) == "nccl_ok" @@ -107,7 +111,6 @@ def _get_request_tasks(self, datas: List[KVMoveTask]): break return ans_list - def check_connect(self, raise_exception=True): try: self.rpyc_conn.root.check_alive() @@ -234,8 +237,7 @@ def kv_trans_handle_loop(self): self.manager.put_to_release_task_queue(move_tasks) logger.error(f"trans kv thread, decode id {self.decode_node_id} device_index {self.device_index} thread quit") - self.task_in_queue.put(PDTransLeaveInfo( - decode_id=self.decode_node_id, prefill_id=self.prefill_node_id)) + self.task_in_queue.put(PDTransLeaveInfo(decode_id=self.decode_node_id, prefill_id=self.prefill_node_id)) return def wait_thread_quit(self): @@ -326,13 +328,20 @@ def __init__(self, args, info_queue: mp.Queue, mem_queues: List[mp.Queue]): self.release_tasks_thread = threading.Thread(target=self.handle_release_task_loop, daemon=True) self.release_tasks_thread.start() - # start a single kv trans process + # start a single kv trans process self.kv_trans_task_in_queue = mp.Queue() self.kv_trans_task_out_queue = mp.Queue() from .prefill_trans_process import start_prefill_trans_process + self.kv_trans_port = find_available_port(self.args.pd_p_allowed_port_min, self.args.pd_p_allowed_port_max) self.kv_trans_process = start_prefill_trans_process( - self.args, self.host_ip, self.kv_trans_port, self.kv_trans_task_in_queue, self.kv_trans_task_out_queue, self.mem_queues) + self.args, + self.host_ip, + self.kv_trans_port, + self.kv_trans_task_in_queue, + self.kv_trans_task_out_queue, + self.mem_queues, + ) assert self.kv_trans_task_out_queue.get(timeout=30) == "proc_start" self._put_mem_manager_to_mem_queue() diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py index b6fa0f032..c6def3b3c 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py @@ -15,9 +15,14 @@ logger = init_logger(__name__) -def _handle_kvmove_task(move_tasks: List[KVMoveTask], task_out_queue: mp.Queue, - mem_managers: List[MemoryManager], decode_to_comm: Dict[int, PyNcclCommunicator], - dp_size_in_node: int): + +def _handle_kvmove_task( + move_tasks: List[KVMoveTask], + task_out_queue: mp.Queue, + mem_managers: List[MemoryManager], + decode_to_comm: Dict[int, PyNcclCommunicator], + dp_size_in_node: int, +): total_move_kv_len = sum([task.move_kv_len for task in move_tasks]) try: decode_id = move_tasks[0].decode_node.node_id @@ 
-42,7 +47,10 @@ def _handle_kvmove_task(move_tasks: List[KVMoveTask], task_out_queue: mp.Queue, logger.exception(str(e)) task_out_queue.put("fail") -def _handle_decode_join(node_info: PDTransJoinInfo, task_out_queue: mp.Queue, decode_to_comm: Dict[str, PyNcclCommunicator], store: TCPStore): + +def _handle_decode_join( + node_info: PDTransJoinInfo, task_out_queue: mp.Queue, decode_to_comm: Dict[str, PyNcclCommunicator], store: TCPStore +): try: group = StatelessP2PProcessGroup.create(node_info.prefill_id, node_info.decode_id, True, store) comm = PyNcclCommunicator(group, node_info.prefill_device_id) @@ -52,13 +60,15 @@ def _handle_decode_join(node_info: PDTransJoinInfo, task_out_queue: mp.Queue, de except Exception as e: logger.warning(f"error while connecting to decode node: {e}") + def _init_env( args, store_ip, store_port, task_in_queue: mp.Queue, task_out_queue: mp.Queue, - mem_queues: List[mp.Queue],): + mem_queues: List[mp.Queue], +): try: graceful_registry(inspect.currentframe().f_code.co_name) master_store = TCPStore(host_name=store_ip, port=store_port, is_master=True, use_libuv=True) @@ -80,7 +90,7 @@ def _init_env( decode_to_comm[task.decode_id].destroy() logger.info(f"destroy {task.decode_id} nccl communicator.") else: - logger.warning(f'unexpected task type: {task}') + logger.warning(f"unexpected task type: {task}") except Exception as e: logger.error(f"Fatal error happened in kv trans process: {e}") @@ -95,10 +105,8 @@ def _init_env( task_out_queue: mp.Queue, mem_queues: List[mp.Queue], ): - proc = mp.Process( - target=_init_env, args=(args, store_ip, store_port, task_in_queue, task_out_queue, mem_queues) - ) + proc = mp.Process(target=_init_env, args=(args, store_ip, store_port, task_in_queue, task_out_queue, mem_queues)) proc.start() assert proc.is_alive() - logger.info(f"trans kv process started!") - return proc \ No newline at end of file + logger.info("prefill trans kv process started!") + return proc From ac076f3a67e17a9a1595b164538ea8f6d761c321 Mon Sep 17 00:00:00 2001 From: Weichao Luo Date: Tue, 11 Mar 2025 17:57:08 +0800 Subject: [PATCH 04/20] one kv trans process per tp. 
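One kv trans process is now started per tp rank, each pinned to its own device with a dedicated task queue pair, so a dead or wedged transfer process only disables that device instead of the whole node. Below is a minimal sketch of the resulting fan-out and device-selection pattern, with illustrative names only: start_trans_process stands in for the real start_decode_trans_process / start_prefill_trans_process helpers, and alive / used_counts for the manager's bookkeeping.

import numpy as np
import torch.multiprocessing as mp


def start_per_device_workers(args, node_world_size, mem_queues, start_trans_process):
    # One worker process and one in/out queue pair per device; the alive
    # flags let the manager disable a device whose process has died.
    procs, in_qs, out_qs, alive = [], [], [], []
    for device_id in range(node_world_size):
        in_q, out_q = mp.Queue(), mp.Queue()
        proc = start_trans_process(args, device_id, in_q, out_q, mem_queues)
        assert out_q.get(timeout=30) == "proc_start"
        procs.append(proc)
        in_qs.append(in_q)
        out_qs.append(out_q)
        alive.append(True)
    return procs, in_qs, out_qs, alive


def pick_device(alive, used_counts):
    # Mirrors get_next_device_index below: dead devices get a huge penalty
    # so np.argmin only ever selects a device with a live trans process.
    counts = [c if alive[i] else (1 << 20) for i, c in enumerate(used_counts)]
    return int(np.argmin(counts))

New connections then land on the least-loaded healthy device, which matches the (1 << 20) penalty used in get_next_device_index in the diff below.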
--- lightllm/distributed/pynccl.py | 71 ----------------- .../decode_kv_move_manager.py | 59 ++++++++++---- .../decode_node_impl/decode_trans_process.py | 12 +-- .../prefill_kv_move_manager.py | 79 +++++++++++++------ .../prefill_trans_process.py | 12 +-- 5 files changed, 110 insertions(+), 123 deletions(-) diff --git a/lightllm/distributed/pynccl.py b/lightllm/distributed/pynccl.py index 3637b04dd..b96e0d1ba 100644 --- a/lightllm/distributed/pynccl.py +++ b/lightllm/distributed/pynccl.py @@ -248,51 +248,6 @@ def all_reduce(self, in_tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM, strea ) return out_tensor - def all_gather(self, output_tensor: torch.Tensor, input_tensor: torch.Tensor, stream=None): - if self.disabled: - return - # nccl communicator created on a specific device - # will only work on tensors on the same device - # otherwise it will cause "illegal memory access" - assert input_tensor.device == self.device, ( - f"this nccl communicator is created to work on {self.device}, " - f"but the input tensor is on {input_tensor.device}" - ) - if stream is None: - stream = current_stream() - self.nccl.ncclAllGather( - buffer_type(input_tensor.data_ptr()), - buffer_type(output_tensor.data_ptr()), - input_tensor.numel(), - ncclDataTypeEnum.from_torch(input_tensor.dtype), - self.comm, - cudaStream_t(stream.cuda_stream), - ) - - def reduce_scatter( - self, output_tensor: torch.Tensor, input_tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM, stream=None - ): - if self.disabled: - return - # nccl communicator created on a specific device - # will only work on tensors on the same device - # otherwise it will cause "illegal memory access" - assert input_tensor.device == self.device, ( - f"this nccl communicator is created to work on {self.device}, " - f"but the input tensor is on {input_tensor.device}" - ) - if stream is None: - stream = current_stream() - self.nccl.ncclReduceScatter( - buffer_type(input_tensor.data_ptr()), - buffer_type(output_tensor.data_ptr()), - output_tensor.numel(), - ncclDataTypeEnum.from_torch(input_tensor.dtype), - ncclRedOpTypeEnum.from_torch(op), - self.comm, - cudaStream_t(stream.cuda_stream), - ) - def send(self, tensor: torch.Tensor, dst: int, stream=None): if self.disabled: return @@ -328,29 +283,3 @@ def recv(self, tensor: torch.Tensor, src: int, stream=None): self.comm, cudaStream_t(stream.cuda_stream), ) - - def broadcast(self, tensor: torch.Tensor, src: int, stream=None): - if self.disabled: - return - assert tensor.device == self.device, ( - f"this nccl communicator is created to work on {self.device}, " - f"but the input tensor is on {tensor.device}" - ) - if stream is None: - stream = current_stream() - if src == self.rank: - sendbuff = buffer_type(tensor.data_ptr()) - # NCCL requires the sender also to have a receive buffer - recvbuff = buffer_type(tensor.data_ptr()) - else: - sendbuff = buffer_type() - recvbuff = buffer_type(tensor.data_ptr()) - self.nccl.ncclBroadcast( - sendbuff, - recvbuff, - tensor.numel(), - ncclDataTypeEnum.from_torch(tensor.dtype), - src, - self.comm, - cudaStream_t(stream.cuda_stream), - ) diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_kv_move_manager.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_kv_move_manager.py index 8af8952fa..7a6f120cb 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_kv_move_manager.py +++ 
b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_kv_move_manager.py @@ -51,8 +51,8 @@ def create(self, prefill_node_id: str, prefill_ip: str, prefill_port: int, manag device_index = manager.get_next_device_index() decode_node_id = manager.args.pd_node_id - task_in_queue = manager.kv_trans_task_in_queue - task_out_queue = manager.kv_trans_task_out_queue + task_in_queue = manager.kv_trans_task_in_queues[device_index] + task_out_queue = manager.kv_trans_task_out_queues[device_index] task_in_queue.put( PDTransJoinInfo( @@ -136,7 +136,6 @@ def kv_move_loop(self): self.manager.put_to_fail_release_task_queue(move_tasks) logger.error(f"{func_name} prefill id {self.prefill_node_id} device_index {self.device_index} thread quit") - self.task_in_queue.put(PDTransLeaveInfo(decode_id=self.decode_node_id, prefill_id=self.prefill_node_id)) return def put_to_radix_loop(self): @@ -217,6 +216,7 @@ def __del__(self): try: self.set_has_error() self.wait_thread_quit() + self.task_in_queue.put(PDTransLeaveInfo(decode_id=self.decode_node_id, prefill_id=self.prefill_node_id)) if self.ready_to_move_queue is not None: self.ready_to_move_queue.clear_tasks() if self.move_finished_queue is not None: @@ -266,18 +266,31 @@ def __init__(self, args, info_queue: mp.Queue, mem_queues: List[mp.Queue]): # 需要每个卡有一个锁来规划每次只能有一个tran obj 操作对应显卡上的传输任务。 self.device_locks = [threading.Lock() for _ in range(self.node_world_size)] - # start a single kv trans process - self.kv_trans_task_in_queue = mp.Queue() - self.kv_trans_task_out_queue = mp.Queue() from .decode_trans_process import start_decode_trans_process - self.kv_trans_process = start_decode_trans_process( - self.args, self.kv_trans_task_in_queue, self.kv_trans_task_out_queue, self.mem_queues - ) + self.kv_trans_processes = [] + self.kv_trans_task_in_queues = [] + self.kv_trans_task_out_queues = [] + self.kv_trans_process_alive = [] + + for device_index in range(self.node_world_size): + kv_trans_task_in_queue = mp.Queue() + kv_trans_task_out_queue = mp.Queue() + kv_trans_process = start_decode_trans_process( + self.args, + device_index, + kv_trans_task_in_queue, + kv_trans_task_out_queue, + self.mem_queues, + ) + assert kv_trans_task_out_queue.get(timeout=30) == "proc_start" + self._put_mem_manager_to_mem_queue() + assert kv_trans_task_out_queue.get(timeout=60) == "get_mem_managers_ok" - assert self.kv_trans_task_out_queue.get(timeout=30) == "proc_start" - self._put_mem_manager_to_mem_queue() - assert self.kv_trans_task_out_queue.get(timeout=60) == "get_mem_managers_ok" + self.kv_trans_processes.append(kv_trans_process) + self.kv_trans_task_in_queues.append(kv_trans_task_in_queue) + self.kv_trans_task_out_queues.append(kv_trans_task_out_queue) + self.kv_trans_process_alive.append(True) return @@ -462,7 +475,9 @@ def exposed_request_data_transfer(self, tasks: List[KVMoveTask]) -> List[Optiona return ans_list def get_next_device_index(self): - counts = [0 for _ in range(self.node_world_size)] + counts = [ + 0 if self.kv_trans_process_alive[device_id] else (1 << 20) for device_id in range(self.node_world_size) + ] for obj in self.node_id_to_trans_obj.values(): counts[obj.device_index] += 1 device_index = int(np.argmin(counts)) @@ -495,10 +510,22 @@ def remove_trans_obj(self, prefill_node_id): return def check_trans_process(self, raise_exception=True): - process = psutil.Process(self.kv_trans_process.pid) - if not (process.is_running() and process.status() != psutil.STATUS_ZOMBIE): + at_least_one_alive = False + for device_id in 
range(self.node_world_size): + if not self.kv_trans_process_alive[device_id]: + continue + + process = psutil.Process(self.kv_trans_processes[device_id].pid) + if not (process.is_running() and process.status() != psutil.STATUS_ZOMBIE): + self.kv_trans_process_alive[device_id] = False + logger.error(f"kv trans process for device: {device_id} dead!!!") + else: + at_least_one_alive = True + + if not at_least_one_alive: if raise_exception: - raise Exception(f"trans process: {self.kv_trans_process.pid} is dead") + raise Exception("All trans process are dead!!!") + return def timer_loop(self): diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py index 7f3f9c676..100b05eaf 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py @@ -66,16 +66,17 @@ def _handle_prefill_join( logger.warning(f"error while connect to prefill node: {e}") -def _init_env(args, task_in_queue: mp.Queue, task_out_queue: mp.Queue, mem_queues: List[mp.Queue]): +def _init_env(args, device_id: int, task_in_queue: mp.Queue, task_out_queue: mp.Queue, mem_queues: List[mp.Queue]): dp_size_in_node = max(1, args.dp // args.nnodes) - node_world_size = args.tp // args.nnodes try: + torch.cuda.set_device(device_id) graceful_registry(inspect.currentframe().f_code.co_name) task_out_queue.put("proc_start") + mem_managers: List[MemoryManager] = [mem_queue.get(timeout=60) for mem_queue in mem_queues] - assert len(mem_managers) == node_world_size + task_out_queue.put("get_mem_managers_ok") prefill_to_comm: Dict[int, PyNcclCommunicator] = {} while True: @@ -97,12 +98,13 @@ def _init_env(args, task_in_queue: mp.Queue, task_out_queue: mp.Queue, mem_queue def start_decode_trans_process( args, + device_id: int, task_in_queue: mp.Queue, task_out_queue: mp.Queue, mem_queues: List[mp.Queue], ): - proc = mp.Process(target=_init_env, args=(args, task_in_queue, task_out_queue, mem_queues)) + proc = mp.Process(target=_init_env, args=(args, device_id, task_in_queue, task_out_queue, mem_queues)) proc.start() assert proc.is_alive() - logger.info("decode trans kv process start!") + logger.info(f"decode trans kv process for device: {device_id} start!") return proc diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py index 5c4e946cf..5ebce1021 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py @@ -39,7 +39,7 @@ class TransProcessObj: rpyc_conn: object = None # rpyc_con 的连接对象 task_in_queue: mp.Queue = None task_out_queue: mp.Queue = None - device_index: str = None # 使用的gpu序号 + device_index: int = None # 使用的gpu序号 manager: "PrefillKVMoveManager" = None has_error: bool = False request_kv_trans_task_queue: TaskQueue = None @@ -57,15 +57,15 @@ def create( device_index = manager.get_next_device_index() # 分配 trans 进程使用的显卡 prefill_node_id = manager.args.pd_node_id - task_in_queue = 
manager.kv_trans_task_in_queue - task_out_queue = manager.kv_trans_task_out_queue + task_in_queue = manager.kv_trans_task_in_queues[device_index] + task_out_queue = manager.kv_trans_task_out_queues[device_index] task_in_queue.put( PDTransJoinInfo( prefill_id=prefill_node_id, prefill_device_id=device_index, prefill_ip=manager.host_ip, - prefill_port=manager.kv_trans_port, + prefill_port=manager.kv_trans_ports[device_index], decode_id=decode_node_id, decode_device_id=-1, ) @@ -74,7 +74,7 @@ def create( # 异步调用, 让decode节点建立与prefill节点进行nccl通信的进程 max_kv_trans_token_num = obtain( con.root.build_trans_process( - prefill_node_id, manager.host_ip, manager.kv_trans_port, manager.args.max_total_token_num + prefill_node_id, manager.host_ip, manager.kv_trans_ports[device_index], manager.args.max_total_token_num ) ) self.max_kv_trans_token_num = max_kv_trans_token_num @@ -237,7 +237,6 @@ def kv_trans_handle_loop(self): self.manager.put_to_release_task_queue(move_tasks) logger.error(f"trans kv thread, decode id {self.decode_node_id} device_index {self.device_index} thread quit") - self.task_in_queue.put(PDTransLeaveInfo(decode_id=self.decode_node_id, prefill_id=self.prefill_node_id)) return def wait_thread_quit(self): @@ -282,6 +281,7 @@ def __del__(self): try: self.set_has_error() self.wait_thread_quit() + self.task_in_queue.put(PDTransLeaveInfo(decode_id=self.decode_node_id, prefill_id=self.prefill_node_id)) if self.request_kv_trans_task_queue is not None: self.request_kv_trans_task_queue.clear_tasks() if self.ready_kv_trans_task_queue is not None: @@ -329,24 +329,37 @@ def __init__(self, args, info_queue: mp.Queue, mem_queues: List[mp.Queue]): self.release_tasks_thread.start() # start a single kv trans process - self.kv_trans_task_in_queue = mp.Queue() - self.kv_trans_task_out_queue = mp.Queue() - from .prefill_trans_process import start_prefill_trans_process - - self.kv_trans_port = find_available_port(self.args.pd_p_allowed_port_min, self.args.pd_p_allowed_port_max) - self.kv_trans_process = start_prefill_trans_process( - self.args, - self.host_ip, - self.kv_trans_port, - self.kv_trans_task_in_queue, - self.kv_trans_task_out_queue, - self.mem_queues, - ) - assert self.kv_trans_task_out_queue.get(timeout=30) == "proc_start" - self._put_mem_manager_to_mem_queue() - assert self.kv_trans_task_out_queue.get(timeout=60) == "get_mem_managers_ok" + from .prefill_trans_process import start_prefill_trans_process + self.kv_trans_ports = [] + self.kv_trans_processes = [] + self.kv_trans_task_in_queues = [] + self.kv_trans_task_out_queues = [] + self.kv_trans_process_alive = [] + + for device_id in range(self.node_world_size): + kv_trans_task_in_queue = mp.Queue() + kv_trans_task_out_queue = mp.Queue() + kv_trans_port = find_available_port(self.args.pd_p_allowed_port_min, self.args.pd_p_allowed_port_max) + kv_trans_process = start_prefill_trans_process( + self.args, + self.host_ip, + kv_trans_port, + device_id, + kv_trans_task_in_queue, + kv_trans_task_out_queue, + self.mem_queues, + ) + assert kv_trans_task_out_queue.get(timeout=30) == "proc_start" + self._put_mem_manager_to_mem_queue() + assert kv_trans_task_out_queue.get(timeout=60) == "get_mem_managers_ok" + + self.kv_trans_ports.append(kv_trans_port) + self.kv_trans_processes.append(kv_trans_process) + self.kv_trans_task_in_queues.append(kv_trans_task_in_queue) + self.kv_trans_task_out_queues.append(kv_trans_task_out_queue) + self.kv_trans_process_alive.append(True) return def put_to_release_task_queue(self, task: Union[KVMoveTask, List[KVMoveTask]]): @@ 
-368,14 +381,28 @@ def handle_release_task_loop(self): return def check_trans_process(self, raise_exception=True): - process = psutil.Process(self.kv_trans_process.pid) - if not (process.is_running() and process.status() != psutil.STATUS_ZOMBIE): + at_least_one_alive = False + for device_id in range(self.node_world_size): + if not self.kv_trans_process_alive[device_id]: + continue + + process = psutil.Process(self.kv_trans_processes[device_id].pid) + if not (process.is_running() and process.status() != psutil.STATUS_ZOMBIE): + self.kv_trans_process_alive[device_id] = False + logger.error(f"kv trans process for device: {device_id} dead!!!") + else: + at_least_one_alive = True + + if not at_least_one_alive: if raise_exception: - raise Exception(f"trans process: {self.kv_trans_process.pid} is dead") + raise Exception("All trans process are dead!!!") + return def get_next_device_index(self): - counts = [0 for _ in range(self.node_world_size)] + counts = [ + 0 if self.kv_trans_process_alive[device_id] else (1 << 20) for device_id in range(self.node_world_size) + ] for obj in self.node_id_to_trans_obj.values(): counts[obj.device_index] += 1 device_index = int(np.argmin(counts)) diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py index c6def3b3c..62327a11c 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py @@ -27,7 +27,6 @@ def _handle_kvmove_task( try: decode_id = move_tasks[0].decode_node.node_id device_index = decode_to_comm[decode_id].device.index - torch.cuda.set_device(device_index) start = time.time() if total_move_kv_len != 0: logger.info(f"trans start: {move_tasks[0].to_prefill_log_info()}") @@ -65,18 +64,18 @@ def _init_env( args, store_ip, store_port, + device_id, task_in_queue: mp.Queue, task_out_queue: mp.Queue, mem_queues: List[mp.Queue], ): try: + torch.cuda.set_device(device_id) graceful_registry(inspect.currentframe().f_code.co_name) master_store = TCPStore(host_name=store_ip, port=store_port, is_master=True, use_libuv=True) dp_size_in_node = max(1, args.dp // args.nnodes) - node_world_size = args.tp // args.nnodes task_out_queue.put("proc_start") mem_managers: List[MemoryManager] = [mem_queue.get(timeout=60) for mem_queue in mem_queues] - assert len(mem_managers) == node_world_size task_out_queue.put("get_mem_managers_ok") decode_to_comm: Dict[int, PyNcclCommunicator] = {} @@ -101,12 +100,15 @@ def start_prefill_trans_process( args, store_ip, store_port, + device_id, task_in_queue: mp.Queue, task_out_queue: mp.Queue, mem_queues: List[mp.Queue], ): - proc = mp.Process(target=_init_env, args=(args, store_ip, store_port, task_in_queue, task_out_queue, mem_queues)) + proc = mp.Process( + target=_init_env, args=(args, store_ip, store_port, device_id, task_in_queue, task_out_queue, mem_queues) + ) proc.start() assert proc.is_alive() - logger.info("prefill trans kv process started!") + logger.info(f"prefill trans kv process for device: {device_id} started!") return proc From a7367e6c9af2d2c8d2b57ba74fca52a149521c9f Mon Sep 17 00:00:00 2001 From: Weichao Luo Date: Wed, 19 Mar 2025 16:06:19 +0800 Subject: [PATCH 05/20] fix. 
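
The restart policy this patch introduces is worth spelling out: each GPU gets its own kv trans
process, a dead process is force-killed and restarted in place, and a per-device start counter
caps how often that may happen before the device is written off. A minimal, self-contained
sketch of that policy (illustrative names only -- RestartableWorker and _worker are not classes
from this diff):

    import multiprocessing as mp
    import psutil

    KV_MOVE_MAX_RESTART_CNT = 3  # same bound the patch uses

    def _worker(task_in: mp.Queue, task_out: mp.Queue):
        # stand-in for the real kv trans process loop
        task_out.put("proc_start")
        while True:
            task_in.get()

    class RestartableWorker:
        """Per-device worker with a capped number of (re)starts."""

        def __init__(self):
            self.restart_cnt = 0
            self.process = None
            self.task_in = mp.Queue()
            self.task_out = mp.Queue()

        def is_alive_by_policy(self) -> bool:
            # a device stays schedulable while it has restart budget left
            return self.restart_cnt <= KV_MOVE_MAX_RESTART_CNT

        def start(self) -> bool:
            self.restart_cnt += 1
            if self.process is not None:
                try:
                    psutil.Process(self.process.pid).kill()  # force kill the old instance
                except psutil.NoSuchProcess:
                    pass
            self.process = mp.Process(target=_worker, args=(self.task_in, self.task_out), daemon=True)
            self.process.start()
            return self.task_out.get(timeout=30) == "proc_start"

        def check(self):
            p = psutil.Process(self.process.pid)
            if not (p.is_running() and p.status() != psutil.STATUS_ZOMBIE):
                if self.is_alive_by_policy():
                    self.start()  # dead but within budget: restart in place

    if __name__ == "__main__":
        w = RestartableWorker()
        assert w.start()
        w.check()

Scheduling then prefers devices whose budget is intact: get_next_device_index in this patch
weights exhausted devices with a huge count (1 << 20) so np.argmin avoids them unless nothing
else is left.
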
--- lightllm/server/pd_io_struct.py | 4 +- .../decode_kv_move_manager.py | 162 ++++++++++++------ .../decode_node_impl/decode_trans_process.py | 2 +- .../prefill_kv_move_manager.py | 150 +++++++++++----- 4 files changed, 221 insertions(+), 97 deletions(-) diff --git a/lightllm/server/pd_io_struct.py b/lightllm/server/pd_io_struct.py index d63a38e99..44a712e75 100644 --- a/lightllm/server/pd_io_struct.py +++ b/lightllm/server/pd_io_struct.py @@ -81,8 +81,8 @@ class PDTransJoinInfo: decode_device_id: int prefill_id: int prefill_device_id: int - prefill_ip: str - prefill_port: int + pd_prefill_nccl_ip: str + pd_prefill_nccl_port: int @dataclass diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_kv_move_manager.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_kv_move_manager.py index 7a6f120cb..d870fb16a 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_kv_move_manager.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_kv_move_manager.py @@ -29,15 +29,17 @@ thread_local_data = threading.local() KV_MOVE_MAX_NUM = 16 +KV_MOVE_MAX_RESTART_CNT = 3 @dataclass class TransProcessObj: prefill_node_id: int = None + process: mp.Process = None task_in_queue: mp.Queue = None task_out_queue: mp.Queue = None - prefill_ip: str = None - prefill_port: int = None + pd_prefill_nccl_ip: str = None + pd_prefill_nccl_port: int = None device_index: int = None manager: "DecodeKVMoveManager" = None has_error: bool = False @@ -47,32 +49,36 @@ class TransProcessObj: put_to_radix_thread: threading.Thread = None latest_check_time: float = None - def create(self, prefill_node_id: str, prefill_ip: str, prefill_port: int, manager: "DecodeKVMoveManager"): + def create( + self, prefill_node_id: str, pd_prefill_nccl_ip: str, pd_prefill_nccl_port: int, manager: "DecodeKVMoveManager" + ): device_index = manager.get_next_device_index() decode_node_id = manager.args.pd_node_id task_in_queue = manager.kv_trans_task_in_queues[device_index] task_out_queue = manager.kv_trans_task_out_queues[device_index] - task_in_queue.put( - PDTransJoinInfo( - prefill_id=prefill_node_id, - prefill_device_id=-1, - prefill_ip=prefill_ip, - prefill_port=prefill_port, - decode_id=decode_node_id, - decode_device_id=device_index, + with manager.device_locks[device_index]: + task_in_queue.put( + PDTransJoinInfo( + prefill_id=prefill_node_id, + prefill_device_id=-1, + pd_prefill_nccl_ip=pd_prefill_nccl_ip, + pd_prefill_nccl_port=pd_prefill_nccl_port, + decode_id=decode_node_id, + decode_device_id=device_index, + ) ) - ) - assert task_out_queue.get(timeout=60) == "nccl_ok" + assert task_out_queue.get(timeout=60) == "nccl_ok" self.prefill_node_id = prefill_node_id self.decode_node_id = decode_node_id self.task_in_queue = task_in_queue self.task_out_queue = task_out_queue - self.prefill_ip = prefill_ip - self.prefill_port = prefill_port + self.pd_prefill_nccl_ip = pd_prefill_nccl_ip + self.pd_prefill_nccl_port = pd_prefill_nccl_port self.device_index = device_index + self.process = manager.kv_trans_processes[device_index] self.manager = manager self.latest_check_time = time.time() @@ -90,6 +96,20 @@ def create(self, prefill_node_id: str, prefill_ip: str, prefill_port: int, manag self.put_to_radix_thread.start() return + def check_trans_process(self, raise_exception=True): + process = psutil.Process(self.process.pid) + if not 
(process.is_running() and process.status() != psutil.STATUS_ZOMBIE): + self.set_has_error() + if raise_exception: + raise Exception(f"trans process: {self.process.pid} is dead") + return + + def timer_to_check_status(self, raise_exception=True): + if time.time() - self.latest_check_time >= 2.0: + self.latest_check_time = time.time() + self.check_trans_process(raise_exception=raise_exception) + return + def _transfer_kv(self, move_tasks: List[KVMoveTask]): with self.manager.device_locks[self.device_index]: self.task_in_queue.put(move_tasks.copy(), timeout=10) @@ -120,6 +140,7 @@ def kv_move_loop(self): logger.info(f"{func_name} get task {task.to_decode_log_info()}") try: + self.timer_to_check_status(raise_exception=True) if not kv_trans_use_p2p(): with self.manager.kv_trans_lock: self._transfer_kv(move_tasks) @@ -150,6 +171,7 @@ def put_to_radix_loop(self): logger.info(f"{func_name} get put radix task {task.to_decode_log_info()}") try: + self.timer_to_check_status(raise_exception=True) # random to check stats self.manager._put_kv_received_to_radix_cache(move_tasks.copy()) for task in move_tasks.copy(): @@ -266,31 +288,17 @@ def __init__(self, args, info_queue: mp.Queue, mem_queues: List[mp.Queue]): # 需要每个卡有一个锁来规划每次只能有一个tran obj 操作对应显卡上的传输任务。 self.device_locks = [threading.Lock() for _ in range(self.node_world_size)] - from .decode_trans_process import start_decode_trans_process - self.kv_trans_processes = [] self.kv_trans_task_in_queues = [] self.kv_trans_task_out_queues = [] - self.kv_trans_process_alive = [] - - for device_index in range(self.node_world_size): - kv_trans_task_in_queue = mp.Queue() - kv_trans_task_out_queue = mp.Queue() - kv_trans_process = start_decode_trans_process( - self.args, - device_index, - kv_trans_task_in_queue, - kv_trans_task_out_queue, - self.mem_queues, - ) - assert kv_trans_task_out_queue.get(timeout=30) == "proc_start" - self._put_mem_manager_to_mem_queue() - assert kv_trans_task_out_queue.get(timeout=60) == "get_mem_managers_ok" + self.kv_trans_process_restart_cnt = [] - self.kv_trans_processes.append(kv_trans_process) - self.kv_trans_task_in_queues.append(kv_trans_task_in_queue) - self.kv_trans_task_out_queues.append(kv_trans_task_out_queue) - self.kv_trans_process_alive.append(True) + for device_id in range(self.node_world_size): + self.kv_trans_task_in_queues.append(mp.Queue()) + self.kv_trans_task_out_queues.append(mp.Queue()) + self.kv_trans_process_restart_cnt.append(0) + self.kv_trans_processes.append(None) + assert self.start_trans_process(device_id) return @@ -400,17 +408,19 @@ def exposed_check_alive(self): # 用于 prefill node check 通信连接的状态。 return - def exposed_build_trans_process(self, prefill_node_id, prefill_ip, prefill_port, prefill_node_max_kv_trans_num): - prefill_node_id, prefill_ip, prefill_port, prefill_node_max_kv_trans_num = list( - map(obtain, [prefill_node_id, prefill_ip, prefill_port, prefill_node_max_kv_trans_num]) + def exposed_build_trans_process( + self, prefill_node_id, pd_prefill_nccl_ip, pd_prefill_nccl_port, prefill_node_max_kv_trans_num + ): + prefill_node_id, pd_prefill_nccl_ip, pd_prefill_nccl_port, prefill_node_max_kv_trans_num = list( + map(obtain, [prefill_node_id, pd_prefill_nccl_ip, pd_prefill_nccl_port, prefill_node_max_kv_trans_num]) ) thread_local_data.prefill_node_id = prefill_node_id - logger.info(f"build trans infos {prefill_node_id} {prefill_ip} {prefill_port}") + logger.info(f"build trans infos {prefill_node_id} {pd_prefill_nccl_ip} {pd_prefill_nccl_port}") # 如果有历史残留,一并移除 self.remove_trans_obj(prefill_node_id) 
tran_obj = TransProcessObj() - tran_obj.create(prefill_node_id, prefill_ip, prefill_port, self) + tran_obj.create(prefill_node_id, pd_prefill_nccl_ip, pd_prefill_nccl_port, self) self.node_id_to_trans_obj[prefill_node_id] = tran_obj return min(prefill_node_max_kv_trans_num, self.args.max_total_token_num) @@ -476,7 +486,7 @@ def exposed_request_data_transfer(self, tasks: List[KVMoveTask]) -> List[Optiona def get_next_device_index(self): counts = [ - 0 if self.kv_trans_process_alive[device_id] else (1 << 20) for device_id in range(self.node_world_size) + 0 if self.is_kv_trans_process_alive(device_id) else (1 << 20) for device_id in range(self.node_world_size) ] for obj in self.node_id_to_trans_obj.values(): counts[obj.device_index] += 1 @@ -509,16 +519,60 @@ def remove_trans_obj(self, prefill_node_id): trans_obj.set_has_error() return + def remove_trans_obj_by_deviceid(self, device_id): + for node_id, t_obj in self.node_id_to_trans_obj.items(): + if t_obj.device_index == device_id: + self.remove_dead_trans_obj(node_id) + + def start_trans_process(self, device_id: int): + task_in_queue = self.kv_trans_task_in_queues[device_id] + task_out_queue = self.kv_trans_task_out_queues[device_id] + self.kv_trans_process_restart_cnt[device_id] += 1 + + if self.kv_trans_processes[device_id]: + # force kill + try: + self.remove_trans_obj_by_deviceid(device_id) + process = psutil.Process(self.kv_trans_processes[device_id].pid) + process.kill() + self.kv_trans_processes[device_id] = None + except Exception: + pass + + try: + from .decode_trans_process import start_decode_trans_process + + kv_trans_process = start_decode_trans_process( + self.args, + device_id, + task_in_queue, + task_out_queue, + self.mem_queues, + ) + assert task_out_queue.get(timeout=30) == "proc_start" + self._put_mem_manager_to_mem_queue() + assert task_out_queue.get(timeout=60) == "get_mem_managers_ok" + + self.kv_trans_processes[device_id] = kv_trans_process + + return True + except Exception as e: + logger.warning(f"Failed start kv trans process for device {device_id}: {e}") + return False + + def is_kv_trans_process_alive(self, device_id): + return self.kv_trans_process_restart_cnt[device_id] <= KV_MOVE_MAX_RESTART_CNT + def check_trans_process(self, raise_exception=True): at_least_one_alive = False for device_id in range(self.node_world_size): - if not self.kv_trans_process_alive[device_id]: + if not self.is_kv_trans_process_alive(device_id): continue process = psutil.Process(self.kv_trans_processes[device_id].pid) if not (process.is_running() and process.status() != psutil.STATUS_ZOMBIE): - self.kv_trans_process_alive[device_id] = False - logger.error(f"kv trans process for device: {device_id} dead!!!") + logger.error(f"kv trans process for device: {device_id} dead!!!, try start again...") + self.start_trans_process(device_id) else: at_least_one_alive = True @@ -530,17 +584,24 @@ def check_trans_process(self, raise_exception=True): def timer_loop(self): try: - last_check_time = time.time() while True: self._unfrozen_time_out_reqs_tokens() time.sleep(3.5) - if last_check_time - time.time() > 10.0: - self.check_trans_process() - last_check_time = time.time() except (BaseException, RuntimeError) as e: logger.exception(str(e)) raise e + def check_trans_process_loop(self): + try: + while True: + self.check_trans_process() + time.sleep(10.0) + except (BaseException, RuntimeError) as e: + logger.exception(str(e)) + # kill parent process if any exception occurred + os.kill(os.getppid(), signal.SIGTERM) + raise e + def _init_env(args, 
info_queue: mp.Queue, mem_queues: List[mp.Queue], event: mp.Event): import lightllm.utils.rpyc_fix_utils as _ @@ -552,6 +613,9 @@ def _init_env(args, info_queue: mp.Queue, mem_queues: List[mp.Queue], event: mp. t = ThreadedServer(manager, port=args.pd_decode_rpyc_port, protocol_config={"allow_pickle": True}) threading.Thread(target=lambda: t.start(), daemon=True).start() + kv_trans_process_check = threading.Thread(target=manager.check_trans_process_loop, daemon=True) + kv_trans_process_check.start() + event.set() manager.timer_loop() return diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py index 100b05eaf..ad23fd38f 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py @@ -53,7 +53,7 @@ def _handle_prefill_join( ): try: store_client = TCPStore( - host_name=node_info.prefill_ip, port=node_info.prefill_port, is_master=False, use_libuv=False + host_name=node_info.pd_prefill_nccl_ip, port=node_info.pd_prefill_nccl_port, is_master=False, use_libuv=True ) group = StatelessP2PProcessGroup.create( src_id=node_info.prefill_id, dest_id=node_info.decode_id, is_server=False, store=store_client diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py index 5ebce1021..cad199fc6 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py @@ -29,6 +29,7 @@ from lightllm.utils.envs_utils import get_unique_server_name KV_MOVE_MAX_NUM = 16 +KV_MOVE_MAX_RESTART_CNT = 3 logger = init_logger(__name__) @@ -37,6 +38,7 @@ class TransProcessObj: decode_node_id: int = None rpyc_conn: object = None # rpyc_con 的连接对象 + process: mp.Process = None task_in_queue: mp.Queue = None task_out_queue: mp.Queue = None device_index: int = None # 使用的gpu序号 @@ -60,25 +62,29 @@ def create( task_in_queue = manager.kv_trans_task_in_queues[device_index] task_out_queue = manager.kv_trans_task_out_queues[device_index] - task_in_queue.put( - PDTransJoinInfo( - prefill_id=prefill_node_id, - prefill_device_id=device_index, - prefill_ip=manager.host_ip, - prefill_port=manager.kv_trans_ports[device_index], - decode_id=decode_node_id, - decode_device_id=-1, + with manager.device_locks[device_index]: + task_in_queue.put( + PDTransJoinInfo( + prefill_id=prefill_node_id, + prefill_device_id=device_index, + pd_prefill_nccl_ip=manager.host_ip, + pd_prefill_nccl_port=manager.kv_trans_ports[device_index], + decode_id=decode_node_id, + decode_device_id=-1, + ) ) - ) - # 异步调用, 让decode节点建立与prefill节点进行nccl通信的进程 - max_kv_trans_token_num = obtain( - con.root.build_trans_process( - prefill_node_id, manager.host_ip, manager.kv_trans_ports[device_index], manager.args.max_total_token_num + # 异步调用, 让decode节点建立与prefill节点进行nccl通信的进程 + max_kv_trans_token_num = obtain( + con.root.build_trans_process( + prefill_node_id, + manager.host_ip, + manager.kv_trans_ports[device_index], + manager.args.max_total_token_num, + 
) ) - ) - self.max_kv_trans_token_num = max_kv_trans_token_num - assert task_out_queue.get(timeout=60) == "nccl_ok" + self.max_kv_trans_token_num = max_kv_trans_token_num + assert task_out_queue.get(timeout=60) == "nccl_ok" self.decode_node_id = decode_node_id self.prefill_node_id = prefill_node_id @@ -88,6 +94,7 @@ def create( self.device_index = device_index self.manager = manager self.latest_check_time = time.time() + self.process = manager.kv_trans_processes[device_index] self.request_kv_trans_task_queue = TaskQueue( get_func=self._get_request_tasks, fail_func=self.manager.put_to_release_task_queue @@ -120,9 +127,18 @@ def check_connect(self, raise_exception=True): raise e return + def check_trans_process(self, raise_exception=True): + process = psutil.Process(self.process.pid) + if not (process.is_running() and process.status() != psutil.STATUS_ZOMBIE): + self.set_has_error() + if raise_exception: + raise Exception(f"trans process: {self.process.pid} is dead") + return + def timer_check_status(self, raise_exception=True): if time.time() - self.latest_check_time >= 2.0: self.latest_check_time = time.time() + self.check_trans_process(raise_exception=raise_exception) self.check_connect(raise_exception=raise_exception) if self.has_error: self.manager.remove_trans_obj(self.decode_node_id) @@ -336,30 +352,18 @@ def __init__(self, args, info_queue: mp.Queue, mem_queues: List[mp.Queue]): self.kv_trans_processes = [] self.kv_trans_task_in_queues = [] self.kv_trans_task_out_queues = [] - self.kv_trans_process_alive = [] + self.kv_trans_process_restart_cnt = [] for device_id in range(self.node_world_size): - kv_trans_task_in_queue = mp.Queue() - kv_trans_task_out_queue = mp.Queue() - kv_trans_port = find_available_port(self.args.pd_p_allowed_port_min, self.args.pd_p_allowed_port_max) - kv_trans_process = start_prefill_trans_process( - self.args, - self.host_ip, - kv_trans_port, - device_id, - kv_trans_task_in_queue, - kv_trans_task_out_queue, - self.mem_queues, + self.kv_trans_task_in_queues.append(mp.Queue()) + self.kv_trans_task_out_queues.append(mp.Queue()) + self.kv_trans_ports.append( + find_available_port(self.args.pd_p_allowed_port_min, self.args.pd_p_allowed_port_max) ) - assert kv_trans_task_out_queue.get(timeout=30) == "proc_start" - self._put_mem_manager_to_mem_queue() - assert kv_trans_task_out_queue.get(timeout=60) == "get_mem_managers_ok" + self.kv_trans_process_restart_cnt.append(0) + self.kv_trans_processes.append(None) + assert self.start_trans_process(device_id) - self.kv_trans_ports.append(kv_trans_port) - self.kv_trans_processes.append(kv_trans_process) - self.kv_trans_task_in_queues.append(kv_trans_task_in_queue) - self.kv_trans_task_out_queues.append(kv_trans_task_out_queue) - self.kv_trans_process_alive.append(True) return def put_to_release_task_queue(self, task: Union[KVMoveTask, List[KVMoveTask]]): @@ -380,16 +384,55 @@ def handle_release_task_loop(self): self._remove_req_refs_from_prompt_cache(handle_list) return + def start_trans_process(self, device_id: int): + task_in_queue = self.kv_trans_task_in_queues[device_id] + task_out_queue = self.kv_trans_task_out_queues[device_id] + kv_trans_port = self.kv_trans_ports[device_id] + self.kv_trans_process_restart_cnt[device_id] += 1 + + if self.kv_trans_processes[device_id]: + # force kill + try: + self.remove_trans_obj_by_deviceid(device_id) + process = psutil.Process(self.kv_trans_processes[device_id].pid) + process.kill() + self.kv_trans_processes[device_id] = None + except Exception: + pass + + try: + from 
.prefill_trans_process import start_prefill_trans_process + + kv_trans_process = start_prefill_trans_process( + self.args, + self.host_ip, + kv_trans_port, + device_id, + task_in_queue, + task_out_queue, + self.mem_queues, + ) + assert task_out_queue.get(timeout=30) == "proc_start" + self._put_mem_manager_to_mem_queue() + assert task_out_queue.get(timeout=60) == "get_mem_managers_ok" + + self.kv_trans_processes[device_id] = kv_trans_process + + return True + except Exception as e: + logger.warning(f"Failed start kv trans process for device {device_id}: {e}") + return False + def check_trans_process(self, raise_exception=True): at_least_one_alive = False for device_id in range(self.node_world_size): - if not self.kv_trans_process_alive[device_id]: + if not self.is_kv_trans_process_alive(device_id): continue process = psutil.Process(self.kv_trans_processes[device_id].pid) if not (process.is_running() and process.status() != psutil.STATUS_ZOMBIE): - self.kv_trans_process_alive[device_id] = False - logger.error(f"kv trans process for device: {device_id} dead!!!") + logger.error(f"kv trans process for device: {device_id} dead!!!, try start again...") + self.start_trans_process(device_id) else: at_least_one_alive = True @@ -399,9 +442,24 @@ def check_trans_process(self, raise_exception=True): return + def check_trans_process_loop(self): + try: + while True: + self.check_trans_process() + time.sleep(10.0) + except (BaseException, RuntimeError) as e: + logger.exception(str(e)) + # kill parent process if any exception occurred + os.kill(os.getppid(), signal.SIGTERM) + raise e + + def is_kv_trans_process_alive(self, device_id): + return self.kv_trans_process_restart_cnt[device_id] <= KV_MOVE_MAX_RESTART_CNT + def get_next_device_index(self): + counts = [ - 0 if self.kv_trans_process_alive[device_id] else (1 << 20) for device_id in range(self.node_world_size) + 0 if self.is_kv_trans_process_alive(device_id) else (1 << 20) for device_id in range(self.node_world_size) ] for obj in self.node_id_to_trans_obj.values(): counts[obj.device_index] += 1 @@ -438,9 +496,13 @@ def remove_dead_trans_obj(self): gc.collect() return + def remove_trans_obj_by_deviceid(self, device_id): + for node_id, t_obj in self.node_id_to_trans_obj.items(): + if t_obj.device_index == device_id: + self.remove_dead_trans_obj(node_id) + def task_dispatcher_loop(self): try: - last_check_time = time.time() # 获取任务,并分发给相关卡的处理队列 while True: move_task: KVMoveTask = self.info_queue.get() @@ -453,10 +515,6 @@ def task_dispatcher_loop(self): finally: trans_obj = None - if time.time() - last_check_time > 10.0: - self.check_trans_process() - last_check_time = time.time() - except (BaseException, RuntimeError) as e: logger.exception(str(e)) raise e @@ -496,6 +554,8 @@ def _init_env(args, info_queue: mp.Queue, mem_queues: List[mp.Queue], event: mp. graceful_registry(inspect.currentframe().f_code.co_name) manager = PrefillKVMoveManager(args, info_queue, mem_queues) + kv_trans_process_check = threading.Thread(target=manager.check_trans_process_loop, daemon=True) + kv_trans_process_check.start() event.set() # 进入主循环 manager.task_dispatcher_loop() From efb6c5061c268768c8762bc6d95b24cda0d0fb0b Mon Sep 17 00:00:00 2001 From: Weichao Luo Date: Wed, 19 Mar 2025 16:35:56 +0800 Subject: [PATCH 06/20] fixup. 
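
This fixup changes when the queues (and, on the prefill side, the nccl port) are allocated:
instead of being created once in __init__ and reused across restarts, start_trans_process now
allocates fresh mp.Queue objects per attempt and only publishes them to the manager's
per-device lists after the "proc_start" / "get_mem_managers_ok" handshake succeeds. A queue
that fed a crashed process may still hold stale tasks, so reusing it across restarts risks
replaying them into the new process. A minimal sketch of the pattern (hypothetical
Manager/_worker names, not this diff's code):

    import multiprocessing as mp

    def _worker(task_in: mp.Queue, task_out: mp.Queue):
        task_out.put("proc_start")
        while True:
            task_in.get()

    class Manager:
        def __init__(self, n_devices: int):
            self.processes = [None] * n_devices
            self.task_in_queues = [None] * n_devices
            self.task_out_queues = [None] * n_devices
            self.start_cnt = [0] * n_devices

        def start_trans_process(self, device_id: int) -> bool:
            # fresh queues per attempt: never reuse the endpoints of a dead process
            task_in, task_out = mp.Queue(), mp.Queue()
            self.start_cnt[device_id] += 1
            proc = mp.Process(target=_worker, args=(task_in, task_out), daemon=True)
            proc.start()
            try:
                assert task_out.get(timeout=30) == "proc_start"
            except Exception:
                proc.kill()
                return False
            # publish the new endpoints only after the handshake succeeded
            self.processes[device_id] = proc
            self.task_in_queues[device_id] = task_in
            self.task_out_queues[device_id] = task_out
            return True

    if __name__ == "__main__":
        m = Manager(2)
        assert m.start_trans_process(0)
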
--- .../decode_kv_move_manager.py | 24 +++++++------- .../prefill_kv_move_manager.py | 32 ++++++++----------- 2 files changed, 25 insertions(+), 31 deletions(-) diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_kv_move_manager.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_kv_move_manager.py index d870fb16a..05cf5eed2 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_kv_move_manager.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_kv_move_manager.py @@ -29,7 +29,7 @@ thread_local_data = threading.local() KV_MOVE_MAX_NUM = 16 -KV_MOVE_MAX_RESTART_CNT = 3 +KV_MOVE_MAX_START_CNT = 3 @dataclass @@ -288,16 +288,12 @@ def __init__(self, args, info_queue: mp.Queue, mem_queues: List[mp.Queue]): # 需要每个卡有一个锁来规划每次只能有一个tran obj 操作对应显卡上的传输任务。 self.device_locks = [threading.Lock() for _ in range(self.node_world_size)] - self.kv_trans_processes = [] - self.kv_trans_task_in_queues = [] - self.kv_trans_task_out_queues = [] - self.kv_trans_process_restart_cnt = [] + self.kv_trans_processes = [None] * self.node_world_size + self.kv_trans_task_in_queues = [None] * self.node_world_size + self.kv_trans_task_out_queues = [None] * self.node_world_size + self.kv_trans_process_start_cnt = [0] * self.node_world_size for device_id in range(self.node_world_size): - self.kv_trans_task_in_queues.append(mp.Queue()) - self.kv_trans_task_out_queues.append(mp.Queue()) - self.kv_trans_process_restart_cnt.append(0) - self.kv_trans_processes.append(None) assert self.start_trans_process(device_id) return @@ -525,9 +521,9 @@ def remove_trans_obj_by_deviceid(self, device_id): self.remove_dead_trans_obj(node_id) def start_trans_process(self, device_id: int): - task_in_queue = self.kv_trans_task_in_queues[device_id] - task_out_queue = self.kv_trans_task_out_queues[device_id] - self.kv_trans_process_restart_cnt[device_id] += 1 + task_in_queue = mp.Queue() + task_out_queue = mp.Queue() + self.kv_trans_process_start_cnt[device_id] += 1 if self.kv_trans_processes[device_id]: # force kill @@ -554,6 +550,8 @@ def start_trans_process(self, device_id: int): assert task_out_queue.get(timeout=60) == "get_mem_managers_ok" self.kv_trans_processes[device_id] = kv_trans_process + self.kv_trans_task_in_queues[device_id] = task_in_queue + self.kv_trans_task_out_queues[device_id] = task_out_queue return True except Exception as e: @@ -561,7 +559,7 @@ def start_trans_process(self, device_id: int): return False def is_kv_trans_process_alive(self, device_id): - return self.kv_trans_process_restart_cnt[device_id] <= KV_MOVE_MAX_RESTART_CNT + return self.kv_trans_process_start_cnt[device_id] <= KV_MOVE_MAX_START_CNT def check_trans_process(self, raise_exception=True): at_least_one_alive = False diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py index cad199fc6..4196c81d1 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py @@ -29,7 +29,7 @@ from lightllm.utils.envs_utils import get_unique_server_name KV_MOVE_MAX_NUM = 16 
-KV_MOVE_MAX_RESTART_CNT = 3 +KV_MOVE_MAX_START_CNT = 3 logger = init_logger(__name__) @@ -348,20 +348,13 @@ def __init__(self, args, info_queue: mp.Queue, mem_queues: List[mp.Queue]): from .prefill_trans_process import start_prefill_trans_process - self.kv_trans_ports = [] - self.kv_trans_processes = [] - self.kv_trans_task_in_queues = [] - self.kv_trans_task_out_queues = [] - self.kv_trans_process_restart_cnt = [] + self.kv_trans_ports = [None] * self.node_world_size + self.kv_trans_processes = [None] * self.node_world_size + self.kv_trans_task_in_queues = [None] * self.node_world_size + self.kv_trans_task_out_queues = [None] * self.node_world_size + self.kv_trans_process_start_cnt = [0] * self.node_world_size for device_id in range(self.node_world_size): - self.kv_trans_task_in_queues.append(mp.Queue()) - self.kv_trans_task_out_queues.append(mp.Queue()) - self.kv_trans_ports.append( - find_available_port(self.args.pd_p_allowed_port_min, self.args.pd_p_allowed_port_max) - ) - self.kv_trans_process_restart_cnt.append(0) - self.kv_trans_processes.append(None) assert self.start_trans_process(device_id) return @@ -385,10 +378,10 @@ def handle_release_task_loop(self): return def start_trans_process(self, device_id: int): - task_in_queue = self.kv_trans_task_in_queues[device_id] - task_out_queue = self.kv_trans_task_out_queues[device_id] - kv_trans_port = self.kv_trans_ports[device_id] - self.kv_trans_process_restart_cnt[device_id] += 1 + task_in_queue = mp.Queue() + task_out_queue = mp.Queue() + kv_trans_port = find_available_port(self.args.pd_p_allowed_port_min, self.args.pd_p_allowed_port_max) + self.kv_trans_process_start_cnt[device_id] += 1 if self.kv_trans_processes[device_id]: # force kill @@ -417,6 +410,9 @@ def start_trans_process(self, device_id: int): assert task_out_queue.get(timeout=60) == "get_mem_managers_ok" self.kv_trans_processes[device_id] = kv_trans_process + self.kv_trans_task_in_queues[device_id] = task_in_queue + self.kv_trans_task_out_queues[device_id] = task_out_queue + self.kv_trans_ports[device_id] = kv_trans_port return True except Exception as e: @@ -454,7 +450,7 @@ def check_trans_process_loop(self): raise e def is_kv_trans_process_alive(self, device_id): - return self.kv_trans_process_restart_cnt[device_id] <= KV_MOVE_MAX_RESTART_CNT + return self.kv_trans_process_start_cnt[device_id] <= KV_MOVE_MAX_START_CNT def get_next_device_index(self): From 2092ea6d43a39fd29c0a1611979197e96943955d Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 9 Apr 2025 17:38:25 +0800 Subject: [PATCH 07/20] new pd code. 
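
The central new idea here is the connect_id: connection identity is divorced from the
(prefill_id, decode_id) pair, because the same pair can legitimately reconnect after a network
hiccup and messages from the stale connection must not be attributed to the new one. Every
join carries a fresh uuid, and the bookkeeping dictionaries are keyed by it. A minimal sketch
of the idea (JoinInfo/ConnectRegistry are illustrative, not the dataclasses in this diff):

    import uuid
    from dataclasses import dataclass, field
    from typing import Dict

    @dataclass
    class JoinInfo:
        prefill_id: int
        decode_id: int
        connect_id: str  # unique per logical connection, not per node pair

    @dataclass
    class ConnectRegistry:
        # keyed by connect_id: a reconnect between the same node pair gets a new
        # uuid, so a late message from the old connection can never be mistaken
        # for one that belongs to the new connection
        objs: Dict[str, JoinInfo] = field(default_factory=dict)

        def build(self, prefill_id: int, decode_id: int) -> JoinInfo:
            info = JoinInfo(prefill_id, decode_id, connect_id=str(uuid.uuid4()))
            self.objs[info.connect_id] = info
            return info

        def drop(self, connect_id: str) -> None:
            self.objs.pop(connect_id, None)

    if __name__ == "__main__":
        reg = ConnectRegistry()
        a = reg.build(prefill_id=1, decode_id=2)
        b = reg.build(prefill_id=1, decode_id=2)  # same pair, distinct identity
        assert a.connect_id != b.connect_id

The same motive drives the split into KVTransConnectObj (one per connection) and
KVTransProcess (one per GPU): the per-device transfer process outlives any single connection,
so per-connection state can be torn down without touching the process.
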
---
 lightllm/server/pd_io_struct.py               |  17 +-
 .../decode_kv_move_manager.py                 | 495 +++++-----------
 .../decode_node_impl/decode_trans_obj.py      | 292 ++++++++++
 .../decode_node_impl/decode_trans_process.py  |  35 +-
 .../prefill_kv_move_manager.py                | 528 ++++--------------
 .../prefill_node_impl/prefill_trans_obj.py    | 372 ++++++++++++
 .../prefill_trans_process.py                  |  37 +-
 .../continues_batch/pd_mode/task_queue.py     |   4 +-
 .../continues_batch/pd_mode/utils.py          |   9 +
 lightllm/utils/process_check.py               |   2 +-
 lightllm/utils/time_utils.py                  |  16 +
 11 files changed, 978 insertions(+), 829 deletions(-)
 create mode 100644 lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_obj.py
 create mode 100644 lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_obj.py
 create mode 100644 lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/utils.py
 create mode 100644 lightllm/utils/time_utils.py

diff --git a/lightllm/server/pd_io_struct.py b/lightllm/server/pd_io_struct.py
index 44a712e75..ad228ff8f 100644
--- a/lightllm/server/pd_io_struct.py
+++ b/lightllm/server/pd_io_struct.py
@@ -83,12 +83,18 @@ class PDTransJoinInfo:
     prefill_device_id: int
     pd_prefill_nccl_ip: str
     pd_prefill_nccl_port: int
+    # Identifies one unique connection. Even when prefill_id and decode_id are the same,
+    # a reconnect may happen for network reasons, so a uuid is used to tell connections apart.
+    connect_id: str


 @dataclass
 class PDTransLeaveInfo:
     decode_id: int
     prefill_id: int
+    # Identifies one unique connection. Even when prefill_id and decode_id are the same,
+    # a reconnect may happen for network reasons, so a uuid is used to tell connections apart.
+    connect_id: str


 @dataclass
@@ -106,6 +112,8 @@ class KVMoveTask:
     prefill_dp_index: int
     decode_dp_index: int
     mark_start_time: float = None
+    # marks which connection id this task uses for its transfer
+    connect_id: str = None

     def __post_init__(self):
         if len(self.input_tokens) <= 0:
@@ -118,14 +126,14 @@ def to_prefill_log_info(self):
         d_i = self.prefill_dp_index
         id = self.group_request_id
         log = f"id: {id} in_len:{len(self.input_tokens)} v_len: {v_len} move_len: {self.move_kv_len} dp_index:{d_i}"
-        return log
+        return log + f" connect_id: {self.connect_id}"

     def to_decode_log_info(self):
         v_len = None if self.decode_token_indexes is None else len(self.decode_token_indexes)
         d_i = self.decode_dp_index
         id = self.group_request_id
         log = f"id: {id} in_len:{len(self.input_tokens)} v_len: {v_len} move_len: {self.move_kv_len} dp_index:{d_i}"
-        return log
+        return log + f" connect_id: {self.connect_id}"

     def id(self):
         return self.group_request_id
@@ -135,3 +143,8 @@ def get_cost_time(self):
             return time.time() - self.mark_start_time
         else:
             return 100000000000
+
+@dataclass
+class KVMoveTaskGroup:
+    tasks: List[KVMoveTask]
+    connect_id: str
\ No newline at end of file
diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_kv_move_manager.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_kv_move_manager.py
index 05cf5eed2..efa34c0d5 100644
--- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_kv_move_manager.py
+++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_kv_move_manager.py
@@ -29,225 +29,6 @@
 thread_local_data = threading.local()

 KV_MOVE_MAX_NUM = 16
-KV_MOVE_MAX_START_CNT = 3
-
-
-@dataclass
-class TransProcessObj:
-    prefill_node_id: int = None
-    process: mp.Process = None
-    task_in_queue: mp.Queue = None
-    task_out_queue: mp.Queue = None
-    pd_prefill_nccl_ip: str = None
-    pd_prefill_nccl_port: int = None
-    
device_index: int = None - manager: "DecodeKVMoveManager" = None - has_error: bool = False - ready_to_move_queue: TaskQueue = None - kv_move_thread: threading.Thread = None - move_finished_queue: TaskQueue = None - put_to_radix_thread: threading.Thread = None - latest_check_time: float = None - - def create( - self, prefill_node_id: str, pd_prefill_nccl_ip: str, pd_prefill_nccl_port: int, manager: "DecodeKVMoveManager" - ): - - device_index = manager.get_next_device_index() - decode_node_id = manager.args.pd_node_id - task_in_queue = manager.kv_trans_task_in_queues[device_index] - task_out_queue = manager.kv_trans_task_out_queues[device_index] - - with manager.device_locks[device_index]: - task_in_queue.put( - PDTransJoinInfo( - prefill_id=prefill_node_id, - prefill_device_id=-1, - pd_prefill_nccl_ip=pd_prefill_nccl_ip, - pd_prefill_nccl_port=pd_prefill_nccl_port, - decode_id=decode_node_id, - decode_device_id=device_index, - ) - ) - assert task_out_queue.get(timeout=60) == "nccl_ok" - - self.prefill_node_id = prefill_node_id - self.decode_node_id = decode_node_id - self.task_in_queue = task_in_queue - self.task_out_queue = task_out_queue - self.pd_prefill_nccl_ip = pd_prefill_nccl_ip - self.pd_prefill_nccl_port = pd_prefill_nccl_port - self.device_index = device_index - self.process = manager.kv_trans_processes[device_index] - - self.manager = manager - self.latest_check_time = time.time() - - self.ready_to_move_queue = TaskQueue( - get_func=lambda datas: datas[0:1], fail_func=self.manager.put_to_fail_release_task_queue - ) - self.kv_move_thread = threading.Thread(target=self.kv_move_loop, daemon=True) - self.kv_move_thread.start() - - self.move_finished_queue = TaskQueue( - get_func=lambda datas: datas[0:KV_MOVE_MAX_NUM], fail_func=self.manager.put_to_fail_release_task_queue - ) - self.put_to_radix_thread = threading.Thread(target=self.put_to_radix_loop, daemon=True) - self.put_to_radix_thread.start() - return - - def check_trans_process(self, raise_exception=True): - process = psutil.Process(self.process.pid) - if not (process.is_running() and process.status() != psutil.STATUS_ZOMBIE): - self.set_has_error() - if raise_exception: - raise Exception(f"trans process: {self.process.pid} is dead") - return - - def timer_to_check_status(self, raise_exception=True): - if time.time() - self.latest_check_time >= 2.0: - self.latest_check_time = time.time() - self.check_trans_process(raise_exception=raise_exception) - return - - def _transfer_kv(self, move_tasks: List[KVMoveTask]): - with self.manager.device_locks[self.device_index]: - self.task_in_queue.put(move_tasks.copy(), timeout=10) - assert self.task_out_queue.get(timeout=60) == "ok" - logger.info(f"_transfer_kv ok {move_tasks[0].to_decode_log_info()}") - - # 标记 decode 接收到 kv cache 的时间 - for move_task in move_tasks: - move_task.mark_start_time = time.time() - - self.move_finished_queue.put_list(move_tasks) - move_tasks.clear() - - def kv_move_loop(self): - func_name = self.kv_move_loop.__name__ - while not self.has_error: - move_tasks: List[List[KVMoveTask]] = self.ready_to_move_queue.get_tasks(log_tag="ready_to_move_queue") - if len(move_tasks) == 0: - time.sleep(0.01) - continue - - if len(move_tasks) != 1: - logger.error(f"error get need 1, but get {len(move_tasks)}") - assert False - - move_tasks = move_tasks[0] - for task in move_tasks: - logger.info(f"{func_name} get task {task.to_decode_log_info()}") - - try: - self.timer_to_check_status(raise_exception=True) - if not kv_trans_use_p2p(): - with self.manager.kv_trans_lock: - 
self._transfer_kv(move_tasks) - else: - self._transfer_kv(move_tasks) - - except BaseException as e: - logger.exception(str(e)) - self.set_has_error() - self.ready_to_move_queue.clear_tasks() - self.manager.remove_trans_obj(self.prefill_node_id) - - finally: - self.manager.put_to_fail_release_task_queue(move_tasks) - - logger.error(f"{func_name} prefill id {self.prefill_node_id} device_index {self.device_index} thread quit") - return - - def put_to_radix_loop(self): - func_name = self.put_to_radix_loop.__name__ - while not self.has_error: - move_tasks: List[KVMoveTask] = self.move_finished_queue.get_tasks(log_tag="move_finished_queue") - if len(move_tasks) == 0: - time.sleep(0.01) - continue - - for task in move_tasks: - logger.info(f"{func_name} get put radix task {task.to_decode_log_info()}") - - try: - self.timer_to_check_status(raise_exception=True) - # random to check stats - self.manager._put_kv_received_to_radix_cache(move_tasks.copy()) - for task in move_tasks.copy(): - logger.info( - f"{func_name} put kv to radix cache ok, req_id: {task.id()} cost_time {task.get_cost_time()} s" - ) - self.manager.up_status_in_queue.put( - UpKVStatus(group_request_id=task.group_request_id, dp_index=task.decode_dp_index) - ) - logger.info(f"{func_name} up kv status req_id: {task.id()} finished") - move_tasks.clear() - - except BaseException as e: - logger.exception(str(e)) - self.set_has_error() - self.move_finished_queue.clear_tasks() - self.manager.remove_trans_obj(self.prefill_node_id) - - finally: - self.manager.put_to_fail_release_task_queue(move_tasks) - - logger.error(f"{func_name}, prefill id {self.prefill_node_id} device_index {self.device_index} thread quit") - return - - def wait_thread_quit(self): - if self.kv_move_thread is not None: - if self.kv_move_thread.is_alive(): - try: - self.kv_move_thread.join() - except: - pass - if self.put_to_radix_thread is not None: - if self.put_to_radix_thread.is_alive(): - try: - self.put_to_radix_thread.join() - except: - pass - return - - def has_error_status(self): - try: - assert self.has_error is False - assert self.kv_move_thread.is_alive() - assert self.put_to_radix_thread.is_alive() - except BaseException as e: - logger.exception(str(e)) - self.set_has_error() - return True - - return False - - def set_has_error(self): - self.has_error = True - try: - self.ready_to_move_queue.has_error = True - self.move_finished_queue.has_error = True - except: - pass - return - - def __del__(self): - logger.error(f"trans obj del start, prefill node id {self.prefill_node_id} device_index {self.device_index}") - - try: - self.set_has_error() - self.wait_thread_quit() - self.task_in_queue.put(PDTransLeaveInfo(decode_id=self.decode_node_id, prefill_id=self.prefill_node_id)) - if self.ready_to_move_queue is not None: - self.ready_to_move_queue.clear_tasks() - if self.move_finished_queue is not None: - self.move_finished_queue.clear_tasks() - except BaseException as e: - logger.exception(str(e)) - - logger.error(f"trans obj deled, prefill node id {self.prefill_node_id} device_index {self.device_index}") - class DecodeKVMoveManager(rpyc.Service): def __init__(self, args, info_queue: mp.Queue, mem_queues: List[mp.Queue]): @@ -264,7 +45,10 @@ def __init__(self, args, info_queue: mp.Queue, mem_queues: List[mp.Queue]): self.mem_queues = mem_queues self.infer_rpyc_lock = threading.Lock() self.infer_rpyc_objs: List[PDDecodeInferRpcServer] = [] - self.node_id_to_trans_obj: Dict[str, TransProcessObj] = {} + + from .decode_trans_obj import KVTransConnectObj + + 
self.connect_id_to_trans_obj: Dict[str, KVTransConnectObj] = {}
         for port in self.args.pd_node_infer_rpyc_ports:
             socket_path = f"/tmp/{get_unique_server_name()}_decode_node_infer_rpyc_{port}"
             from rpyc.utils.factory import unix_connect
@@ -284,37 +68,27 @@ def __init__(self, args, info_queue: mp.Queue, mem_queues: List[mp.Queue]):
         self.fail_to_release_thread = threading.Thread(target=self.handle_fail_release_task_loop, daemon=True)
         self.fail_to_release_thread.start()

+        # When the p2p kv-copy scheme is not used, a global transfer lock has to serialize
+        # the transfers, and kv transfer efficiency drops accordingly.
         self.kv_trans_lock = threading.Lock()
-        # 需要每个卡有一个锁来规划每次只能有一个tran obj 操作对应显卡上的传输任务。
-        self.device_locks = [threading.Lock() for _ in range(self.node_world_size)]
-
-        self.kv_trans_processes = [None] * self.node_world_size
-        self.kv_trans_task_in_queues = [None] * self.node_world_size
-        self.kv_trans_task_out_queues = [None] * self.node_world_size
-        self.kv_trans_process_start_cnt = [0] * self.node_world_size
-
+
+        from .decode_trans_obj import KVTransProcess
+
+        self.kv_trans_processes: List[KVTransProcess] = [None] * self.node_world_size
         for device_id in range(self.node_world_size):
-            assert self.start_trans_process(device_id)
+            self.kv_trans_processes[device_id] = KVTransProcess()
+            assert self.kv_trans_processes[device_id].init_all(device_id, self)

         return
-
-    def put_to_fail_release_task_queue(self, task: Union[KVMoveTask, List[KVMoveTask]]):
-        if isinstance(task, KVMoveTask):
-            self.fail_to_release_queue.put(task)
-        elif isinstance(task, list):
-            self.fail_to_release_queue.put_list(task)
-        else:
-            assert False, "error input"
-        return
-
-    def handle_fail_release_task_loop(self):
-        while True:
-            handle_list: List[KVMoveTask] = self.fail_to_release_queue.get_tasks(log_tag="fail_to_release_queue")
-            if len(handle_list) == 0:
-                time.sleep(0.01)
-            else:
-                self._fail_to_realese_forzen_tokens(handle_list)
-        return
+
+    # ==================================================================================
+    # _dp_alloc_to_frozen_some_tokens
+    # _put_kv_received_to_radix_cache
+    # _fail_to_realese_forzen_tokens
+    # _unfrozen_time_out_reqs_tokens
+    # _put_mem_manager_to_mem_queue
+    # The interfaces above are how the kv move manager talks to the inference processes,
+    # mainly to request that kv resources be locked, or to release them.
+    # ==================================================================================

     async def wait_all_future_finish(self, futures: List[AsyncResult]):
         await asyncio.gather(*[asyncio.to_thread(future.wait) for future in futures])
@@ -384,17 +158,51 @@ def _put_mem_manager_to_mem_queue(self) -> None:
         for obj in self.infer_rpyc_objs:
             obj.put_mem_manager_to_mem_queue()
         return
+
+    # ==================================================================================
+    # put_to_fail_release_task_queue puts requests that failed for some reason and whose
+    # locked kv resources must be released into the corresponding handling queue.
+    # handle_fail_release_task_loop is a looping thread dedicated to these failed
+    # requests: it releases the locked kv resources by calling the interfaces that talk
+    # to the inference processes.
+    # ==================================================================================
+
+    def put_to_fail_release_task_queue(self, task: Union[KVMoveTask, List[KVMoveTask]]):
+        if isinstance(task, KVMoveTask):
+            self.fail_to_release_queue.put(task)
+        elif isinstance(task, list):
+            self.fail_to_release_queue.put_list(task)
+        else:
+            assert False, "error input"
+        return
+
+    def handle_fail_release_task_loop(self):
+        while True:
+            handle_list: List[KVMoveTask] = self.fail_to_release_queue.get_tasks(log_tag="fail_to_release_queue")
+            if len(handle_list) == 0:
+                time.sleep(0.01)
+            else:
+                self._fail_to_realese_forzen_tokens(handle_list)
+        return
+
+    # ==================================================================================
+    # on_connect
+    # on_disconnect
+    # exposed_check_alive
+    # exposed_build_trans_connect
+    # exposed_request_data_transfer
+    # The interfaces above are the rpyc endpoints exposed by the decode kv move manager;
+    # the prefill kv move manager connects through them to exchange metadata.
+    # ==================================================================================

     def on_connect(self, conn):
         # 用于处理连接断开的时候,自动删除资源
-        thread_local_data.prefill_node_id = None
+        thread_local_data.connect_id = None
         pass

     def on_disconnect(self, conn):
         # 用于处理连接断开的时候,自动删除资源
-        if thread_local_data.prefill_node_id is not None:
-            self.remove_trans_obj(thread_local_data.prefill_node_id)
-            logger.info(f"prefill node id {thread_local_data.prefill_node_id} disconnect")
+        if thread_local_data.connect_id is not None:
+            self.remove_trans_obj(thread_local_data.connect_id)
+            logger.info(f"connect id {thread_local_data.connect_id} disconnect")
             import gc

             gc.collect()
@@ -404,20 +212,22 @@ def exposed_check_alive(self):
         # 用于 prefill node check 通信连接的状态。
         return

-    def exposed_build_trans_process(
-        self, prefill_node_id, pd_prefill_nccl_ip, pd_prefill_nccl_port, prefill_node_max_kv_trans_num
+    def exposed_build_trans_connect(
+        self, prefill_node_id, pd_prefill_nccl_ip, pd_prefill_nccl_port, prefill_node_max_kv_trans_num, connect_id
     ):
         prefill_node_id, pd_prefill_nccl_ip, pd_prefill_nccl_port, prefill_node_max_kv_trans_num = list(
             map(obtain, [prefill_node_id, pd_prefill_nccl_ip, pd_prefill_nccl_port, prefill_node_max_kv_trans_num])
         )
-        thread_local_data.prefill_node_id = prefill_node_id
-
-        logger.info(f"build trans infos {prefill_node_id} {pd_prefill_nccl_ip} {pd_prefill_nccl_port}")
-        # 如果有历史残留,一并移除
-        self.remove_trans_obj(prefill_node_id)
-        tran_obj = TransProcessObj()
-        tran_obj.create(prefill_node_id, pd_prefill_nccl_ip, pd_prefill_nccl_port, self)
-        self.node_id_to_trans_obj[prefill_node_id] = tran_obj
+        connect_id = obtain(connect_id)
+        thread_local_data.connect_id = connect_id
+
+        logger.info(f"build trans infos {prefill_node_id} {pd_prefill_nccl_ip} {pd_prefill_nccl_port} {connect_id}")
+
+        from .decode_trans_obj import KVTransConnectObj
+
+        tran_obj = KVTransConnectObj()
+        tran_obj.create(connect_id, prefill_node_id, pd_prefill_nccl_ip, pd_prefill_nccl_port, self)
+        self.connect_id_to_trans_obj[connect_id] = tran_obj
         return min(prefill_node_max_kv_trans_num, self.args.max_total_token_num)

     # 返回 None 代表繁忙, 放弃该任务的 kv 传送
@@ -465,141 +275,88 @@ def exposed_request_data_transfer(self, tasks: List[KVMoveTask]) -> List[Optiona
         except BaseException as e:
             self.put_to_fail_release_task_queue(alloc_tokened_tasks)
             alloc_tokened_tasks = []
-            self.remove_trans_obj(tasks[0].prefill_node_id)
+            self.remove_trans_obj(tasks[0].connect_id)
             logger.exception(str(e))
             raise e
+
+        if alloc_tokened_tasks:
+            trans_obj.ready_to_move_queue.put(alloc_tokened_tasks, error_handle_func=self.put_to_fail_release_task_queue)
+        return ans_list
+
+    # ==================================================================================
+    # Periodically detect requests whose kv transfer succeeded but which the pd master
+    # has not triggered for inference for a long time, and release the kv resources
+    # held by these timed-out requests.
+    # ==================================================================================
+
+    def timer_loop(self):
         try:
-            if len(alloc_tokened_tasks) != 0:
-                trans_obj.ready_to_move_queue.put(alloc_tokened_tasks)
-        except BaseException as e:
+            while True:
+                self._unfrozen_time_out_reqs_tokens()
+                time.sleep(3.5)
+        except (BaseException, RuntimeError) as e:
             logger.exception(str(e))
-            self.put_to_fail_release_task_queue(alloc_tokened_tasks)
-            alloc_tokened_tasks = []
             raise e

-        return ans_list
+    # ==================================================================================
+    # Periodically check the health of the transfer processes; if anything goes wrong,
+    # bring the whole system down to trigger a restart.
+    # ==================================================================================

     def check_trans_process_loop(self):
+        try:
+            while True:
+                for device_id in range(self.node_world_size):
+                    if not self.kv_trans_processes[device_id].is_trans_process_health():
+                        raise Exception(f"device_id {device_id} kv process is unhealthy")
+
+                time.sleep(10.0)
+        except (BaseException, RuntimeError) as e:
+            logger.exception(str(e))
+
+            for device_id in range(self.node_world_size):
+                self.kv_trans_processes[device_id].killself()
+
+            # kill the parent (router) process to trigger a global crash
+            os.kill(os.getppid(), signal.SIGKILL)
+            os.kill(os.getpid(), signal.SIGKILL)
+            raise e
+
+    # ==================================================================================
+    # Common helper functions
+    # ==================================================================================

     def get_next_device_index(self):
-        counts = [
-            0 if self.is_kv_trans_process_alive(device_id) else (1 << 20) for device_id in range(self.node_world_size)
-        ]
-        for obj in self.node_id_to_trans_obj.values():
+        counts = [0 for _ in range(self.node_world_size)]
+        for obj in self.connect_id_to_trans_obj.values():
             counts[obj.device_index] += 1
         device_index = int(np.argmin(counts))
         return device_index

     def get_trans_obj(self, task: KVMoveTask):
-        self.remove_dead_trans_obj()
-        return self.node_id_to_trans_obj[task.prefill_node_id]
+        self.__remove_dead_trans_obj()
+        return self.connect_id_to_trans_obj[task.connect_id]

-    def remove_dead_trans_obj(self):
-        del_node_ids = []
-        for node_id, t_obj in self.node_id_to_trans_obj.items():
+    def __remove_dead_trans_obj(self):
+        del_connect_ids = []
+        for connect_id, t_obj in self.connect_id_to_trans_obj.items():
             if t_obj.has_error_status():
-                del_node_ids.append(node_id)
+                del_connect_ids.append(connect_id)

-        for node_id in del_node_ids:
-            self.node_id_to_trans_obj.pop(node_id, None)
+        for connect_id in del_connect_ids:
+            self.connect_id_to_trans_obj.pop(connect_id, None)

-        if len(del_node_ids) != 0:
+        if del_connect_ids:
             import gc

             gc.collect()
         return

-    def remove_trans_obj(self, prefill_node_id):
-        if prefill_node_id in self.node_id_to_trans_obj:
-            trans_obj = self.node_id_to_trans_obj.pop(prefill_node_id, None)
+    def remove_trans_obj(self, connect_id):
+        if connect_id in self.connect_id_to_trans_obj:
+            trans_obj = self.connect_id_to_trans_obj.pop(connect_id, None)
             if trans_obj is not None:
                 trans_obj.set_has_error()
         return

-    def remove_trans_obj_by_deviceid(self, device_id):
-        for node_id, t_obj in self.node_id_to_trans_obj.items():
-            if t_obj.device_index == device_id:
-                self.remove_dead_trans_obj(node_id)
-
-    def start_trans_process(self, device_id: int):
-        task_in_queue = mp.Queue()
-        task_out_queue = mp.Queue()
-        self.kv_trans_process_start_cnt[device_id] += 1
-
-        if self.kv_trans_processes[device_id]:
-            # force kill
-            try:
-                self.remove_trans_obj_by_deviceid(device_id)
-                process = psutil.Process(self.kv_trans_processes[device_id].pid)
-                process.kill()
-                self.kv_trans_processes[device_id] = None
-            except Exception:
-                pass
-
-        try:
-            from .decode_trans_process import start_decode_trans_process
-
-            kv_trans_process = start_decode_trans_process(
-                self.args,
-                device_id,
-                task_in_queue,
-                task_out_queue,
-                self.mem_queues,
-            )
-            assert task_out_queue.get(timeout=30) == "proc_start"
-            self._put_mem_manager_to_mem_queue()
-            assert task_out_queue.get(timeout=60) == 
"get_mem_managers_ok" - - self.kv_trans_processes[device_id] = kv_trans_process - self.kv_trans_task_in_queues[device_id] = task_in_queue - self.kv_trans_task_out_queues[device_id] = task_out_queue - - return True - except Exception as e: - logger.warning(f"Failed start kv trans process for device {device_id}: {e}") - return False - - def is_kv_trans_process_alive(self, device_id): - return self.kv_trans_process_start_cnt[device_id] <= KV_MOVE_MAX_START_CNT - - def check_trans_process(self, raise_exception=True): - at_least_one_alive = False - for device_id in range(self.node_world_size): - if not self.is_kv_trans_process_alive(device_id): - continue - - process = psutil.Process(self.kv_trans_processes[device_id].pid) - if not (process.is_running() and process.status() != psutil.STATUS_ZOMBIE): - logger.error(f"kv trans process for device: {device_id} dead!!!, try start again...") - self.start_trans_process(device_id) - else: - at_least_one_alive = True - - if not at_least_one_alive: - if raise_exception: - raise Exception("All trans process are dead!!!") - - return - - def timer_loop(self): - try: - while True: - self._unfrozen_time_out_reqs_tokens() - time.sleep(3.5) - except (BaseException, RuntimeError) as e: - logger.exception(str(e)) - raise e - - def check_trans_process_loop(self): - try: - while True: - self.check_trans_process() - time.sleep(10.0) - except (BaseException, RuntimeError) as e: - logger.exception(str(e)) - # kill parent process if any exception occurred - os.kill(os.getppid(), signal.SIGTERM) - raise e - def _init_env(args, info_queue: mp.Queue, mem_queues: List[mp.Queue], event: mp.Event): import lightllm.utils.rpyc_fix_utils as _ diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_obj.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_obj.py new file mode 100644 index 000000000..4127c3545 --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_obj.py @@ -0,0 +1,292 @@ +import time +import psutil +import threading +from typing import List +from dataclasses import dataclass +from lightllm.utils.log_utils import init_logger +from ..task_queue import TaskQueue +import torch.multiprocessing as mp +from lightllm.server.pd_io_struct import KVMoveTask, UpKVStatus, PDTransJoinInfo, PDTransLeaveInfo, KVMoveTaskGroup +from lightllm.utils.device_utils import kv_trans_use_p2p +from .decode_kv_move_manager import DecodeKVMoveManager +from lightllm.utils.time_utils import TimeChecker +from ..utils import join_if_alive + +logger = init_logger(__name__) + +KV_MOVE_MAX_NUM = 16 + +@dataclass +class KVTransConnectObj: + connect_id: str = None + prefill_node_id: int = None + kv_trans_process: 'KVTransProcess' = None + pd_prefill_nccl_ip: str = None + pd_prefill_nccl_port: int = None + device_index: int = None + manager: "DecodeKVMoveManager" = None + has_error: bool = False + ready_to_move_queue: TaskQueue = None + kv_move_thread: threading.Thread = None + move_finished_queue: TaskQueue = None + put_to_radix_thread: threading.Thread = None + timer_checker: TimeChecker = None + + def create( + self, connect_id: str, prefill_node_id: str, pd_prefill_nccl_ip: str, pd_prefill_nccl_port: int, manager: "DecodeKVMoveManager" + ): + self.connect_id = connect_id + self.device_index = manager.get_next_device_index() + self.kv_trans_process = manager.kv_trans_processes[self.device_index] + decode_node_id = 
manager.args.pd_node_id
+
+        with self.kv_trans_process.device_lock:
+            self.kv_trans_process.task_in_queue.put(
+                PDTransJoinInfo(
+                    prefill_id=prefill_node_id,
+                    prefill_device_id=-1,
+                    pd_prefill_nccl_ip=pd_prefill_nccl_ip,
+                    pd_prefill_nccl_port=pd_prefill_nccl_port,
+                    decode_id=decode_node_id,
+                    decode_device_id=self.device_index,
+                    connect_id=self.connect_id
+                )
+            )
+            assert self.kv_trans_process.task_out_queue.get(timeout=60) == "nccl_ok"
+
+        self.prefill_node_id = prefill_node_id
+        self.decode_node_id = decode_node_id
+        self.pd_prefill_nccl_ip = pd_prefill_nccl_ip
+        self.pd_prefill_nccl_port = pd_prefill_nccl_port
+
+        self.manager = manager
+        self.timer_checker = TimeChecker(3)
+
+        self.ready_to_move_queue = TaskQueue(
+            get_func=lambda datas: datas[0:1], fail_func=self.manager.put_to_fail_release_task_queue
+        )
+        self.kv_move_thread = threading.Thread(target=self.kv_move_loop, daemon=True)
+        self.kv_move_thread.start()
+
+        self.move_finished_queue = TaskQueue(
+            get_func=lambda datas: datas[0:KV_MOVE_MAX_NUM], fail_func=self.manager.put_to_fail_release_task_queue
+        )
+        self.put_to_radix_thread = threading.Thread(target=self.put_to_radix_loop, daemon=True)
+        self.put_to_radix_thread.start()
+        return
+
+    # ==================================================================================
+    # Accept and handle all requests that perform kv transfer; once a transfer is done,
+    # the requests are moved into move_finished_queue.
+    # ==================================================================================
+
+    def _transfer_kv(self, move_tasks: List[KVMoveTask]):
+        with self.kv_trans_process.device_lock:
+            kv_move_group = KVMoveTaskGroup(tasks=move_tasks.copy(), connect_id=self.connect_id)
+            self.kv_trans_process.task_in_queue.put(kv_move_group, timeout=10)
+            assert self.kv_trans_process.task_out_queue.get(timeout=60) == "ok"
+        logger.info(f"_transfer_kv ok {move_tasks[0].to_decode_log_info()}")
+
+        # record the time when the decode node received the kv cache
+        for move_task in move_tasks:
+            move_task.mark_start_time = time.time()
+
+        self.move_finished_queue.put_list(move_tasks)
+        move_tasks.clear()
+
+    def kv_move_loop(self):
+        func_name = self.kv_move_loop.__name__
+        while not self.has_error:
+            move_tasks: List[List[KVMoveTask]] = self.ready_to_move_queue.get_tasks(log_tag="ready_to_move_queue")
+            if len(move_tasks) == 0:
+                time.sleep(0.01)
+                continue
+
+            if len(move_tasks) != 1:
+                logger.error(f"error get need 1, but get {len(move_tasks)}")
+                assert False
+
+            move_tasks: List[KVMoveTask] = move_tasks[0]
+            for task in move_tasks:
+                logger.info(f"{func_name} get task {task.to_decode_log_info()}")
+
+            try:
+                self.timer_to_check_status(raise_exception=True)
+                if not kv_trans_use_p2p():
+                    with self.manager.kv_trans_lock:
+                        self._transfer_kv(move_tasks)
+                else:
+                    self._transfer_kv(move_tasks)
+
+            except BaseException as e:
+                logger.exception(str(e))
+                self.set_has_error()
+                self.ready_to_move_queue.clear_tasks()
+
+            finally:
+                self.manager.put_to_fail_release_task_queue(move_tasks)
+
+        logger.error(f"{func_name} thread quit")
+        return
+
+    # ==================================================================================
+    # Put the requests whose transfer has finished into the radix cache for management.
+    # ==================================================================================
+
+    def put_to_radix_loop(self):
+        func_name = self.put_to_radix_loop.__name__
+        while not self.has_error:
+            move_tasks: List[KVMoveTask] = self.move_finished_queue.get_tasks(log_tag="move_finished_queue")
+            if len(move_tasks) == 0:
+                time.sleep(0.01)
+                continue
+
+            for task in move_tasks:
+                
logger.info(f"{func_name} get put radix task {task.to_decode_log_info()}") + + try: + self.timer_to_check_status(raise_exception=True) + # random to check stats + self.manager._put_kv_received_to_radix_cache(move_tasks.copy()) + for task in move_tasks.copy(): + logger.info( + f"{func_name} put kv to radix cache ok, req_id: {task.id()} cost_time {task.get_cost_time()} s" + ) + self.manager.up_status_in_queue.put( + UpKVStatus(group_request_id=task.group_request_id, dp_index=task.decode_dp_index) + ) + logger.info(f"{func_name} up kv status req_id: {task.id()} finished") + move_tasks.clear() + + except BaseException as e: + logger.exception(str(e)) + self.set_has_error() + self.move_finished_queue.clear_tasks() + + finally: + self.manager.put_to_fail_release_task_queue(move_tasks) + + logger.error(f"{func_name} thread quit, info: {self.to_log_info()}") + return + + # ================================================================================== + # 错误处理检测操作的一些通用函数 + # ================================================================================== + + def timer_to_check_status(self, raise_exception=True): + if self.timer_checker.has_exceeded(): + try: + assert self.kv_trans_process.is_trans_process_health() + except BaseException as e: + logger.error(f"pid {self.kv_trans_process.process.pid} check failed") + logger.exception(str(e)) + + self.set_has_error() + if raise_exception: + raise e + return + + def has_error_status(self): + try: + assert self.has_error is False + assert self.kv_move_thread.is_alive() + assert self.put_to_radix_thread.is_alive() + except BaseException as e: + logger.exception(str(e)) + self.set_has_error() + return True + + return False + + def set_has_error(self): + self.has_error = True + + if self.ready_to_move_queue is not None: + self.ready_to_move_queue.has_error = True + + if self.move_finished_queue is not None: + self.move_finished_queue.has_error = True + + if self.manager is not None: + self.manager.remove_trans_obj(self.connect_id) + return + + def __del__(self): + logger.error(f"trans obj del start, info: {self.to_log_info()}") + + try: + self.set_has_error() + + join_if_alive(self.kv_move_thread) + join_if_alive(self.put_to_radix_thread) + + if self.connect_id is not None and self.kv_trans_process is not None: + self.kv_trans_process.task_in_queue.put(PDTransLeaveInfo(decode_id=self.decode_node_id, prefill_id=self.prefill_node_id, connect_id=self.connect_id)) + + if self.ready_to_move_queue is not None: + self.ready_to_move_queue.clear_tasks() + if self.move_finished_queue is not None: + self.move_finished_queue.clear_tasks() + + except BaseException as e: + logger.exception(str(e)) + + logger.error(f"trans obj deled, info: {self.to_log_info()}") + + def to_log_info(self): + log = f"connect_id: {self.connect_id} " + log += f"decode_node_id: {self.decode_node_id} " + log += f"prefill_node_id: {self.prefill_node_id} " + log += f"device_index: {self.device_index} " + return log + +@dataclass +class KVTransProcess: + process: mp.Process = None + # 需要每个卡有一个锁来规划每次只能有一个 connection obj 操作对应显卡上的传输任务。 + device_lock: threading.Lock = None + task_in_queue: mp.Queue = None + task_out_queue: mp.Queue = None + device_id: int = None + + + def init_all(self, device_id: int, manager: "DecodeKVMoveManager"): + self.device_lock = threading.Lock() + self.device_id = device_id + self.task_in_queue = mp.Queue() + self.task_out_queue = mp.Queue() + + try: + from .decode_trans_process import start_decode_trans_process + + self.process = start_decode_trans_process( + 
manager.args, + device_id, + self.task_in_queue, + self.task_out_queue, + manager.mem_queues, + ) + assert self.task_out_queue.get(timeout=30) == "proc_start" + manager._put_mem_manager_to_mem_queue() + assert self.task_out_queue.get(timeout=60) == "get_mem_managers_ok" + + return True + + except Exception as e: + logger.warning(f"Failed start kv trans process for device {device_id}: {e}") + logger.exception(str(e)) + return False + + def is_trans_process_health(self): + try: + process = psutil.Process(self.process.pid) + if not (process.is_running() and process.status() != psutil.STATUS_ZOMBIE): + logger.error(f"kv trans process for device: {self.device_id} dead!!!") + return False + else: + return True + except: + return False + + def killself(self): + self.process.kill() diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py index ad23fd38f..eec2a27d0 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py @@ -7,7 +7,7 @@ from typing import List, Dict, Union from lightllm.utils.log_utils import init_logger from lightllm.common.mem_manager import MemoryManager -from lightllm.server.pd_io_struct import KVMoveTask, PDTransJoinInfo, PDTransLeaveInfo +from lightllm.server.pd_io_struct import KVMoveTask, PDTransJoinInfo, PDTransLeaveInfo, KVMoveTaskGroup from lightllm.utils.device_utils import kv_trans_use_p2p from lightllm.utils.graceful_utils import graceful_registry from lightllm.distributed.pynccl import PyNcclCommunicator, StatelessP2PProcessGroup @@ -19,24 +19,24 @@ def _handle_kvmove_task( move_tasks: List[KVMoveTask], task_out_queue: mp.Queue, mem_managers: List[MemoryManager], - prefill_to_comm: Dict[int, PyNcclCommunicator], + connect_id_to_comm: Dict[str, PyNcclCommunicator], + connect_id: str, dp_size_in_node: int, ): total_move_kv_len = sum([task.move_kv_len for task in move_tasks]) try: - prefill_id = move_tasks[0].prefill_node_id - device_index = prefill_to_comm[prefill_id].device.index + device_index = connect_id_to_comm[connect_id].device.index start = time.time() if total_move_kv_len != 0: cur_mem = mem_managers[device_index] logger.info(f"trans start: {move_tasks[0].to_decode_log_info()}") if kv_trans_use_p2p(): cur_mem.receive_from_prefill_node_p2p( - move_tasks, mem_managers, dp_size_in_node, prefill_to_comm[prefill_id] + move_tasks, mem_managers, dp_size_in_node, connect_id_to_comm[connect_id] ) else: cur_mem.receive_from_prefill_node( - move_tasks, mem_managers, dp_size_in_node, prefill_to_comm[prefill_id] + move_tasks, mem_managers, dp_size_in_node, connect_id_to_comm[connect_id] ) logger.info(f"trans finished: {move_tasks[0].to_decode_log_info()} move len: {total_move_kv_len}") torch.cuda.synchronize() @@ -49,7 +49,7 @@ def _handle_kvmove_task( def _handle_prefill_join( - node_info: PDTransJoinInfo, task_out_queue: mp.Queue, prefill_to_comm: Dict[int, PyNcclCommunicator] + node_info: PDTransJoinInfo, task_out_queue: mp.Queue, connect_id_to_comm: Dict[str, PyNcclCommunicator] ): try: store_client = TCPStore( @@ -59,10 +59,11 @@ def _handle_prefill_join( src_id=node_info.prefill_id, dest_id=node_info.decode_id, is_server=False, store=store_client ) comm = PyNcclCommunicator(group, 
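
The liveness probe behind is_trans_process_health has to treat zombies as dead: psutil still reports an exited-but-unreaped child via is_running(), so the explicit status check matters. A small demonstration, assuming only psutil and multiprocessing (_worker is illustrative):

    import multiprocessing as mp
    import time

    import psutil

    def _worker() -> None:
        time.sleep(0.5)

    def process_alive(pid: int) -> bool:
        try:
            p = psutil.Process(pid)
            return p.is_running() and p.status() != psutil.STATUS_ZOMBIE
        except psutil.NoSuchProcess:
            return False

    if __name__ == "__main__":
        child = mp.Process(target=_worker)
        child.start()
        print(process_alive(child.pid))   # True: child is running
        child.join()                      # join() reaps the exited child
        print(process_alive(child.pid))   # False: gone (or zombie before reaping)
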
node_info.decode_device_id)
-        prefill_to_comm[node_info.prefill_id] = comm
+        connect_id_to_comm[node_info.connect_id] = comm
         logger.info(f"{node_info} kv trans connected")
         task_out_queue.put("nccl_ok")
     except Exception as e:
+        task_out_queue.put("nccl_fail")
         logger.warning(f"error while connect to prefill node: {e}")
 
 
@@ -78,16 +79,21 @@ def _init_env(args, device_id: int, task_in_queue: mp.Queue, task_out_queue: mp.
 
     mem_managers: List[MemoryManager] = [mem_queue.get(timeout=60) for mem_queue in mem_queues]
     task_out_queue.put("get_mem_managers_ok")
-    prefill_to_comm: Dict[int, PyNcclCommunicator] = {}
+    connect_id_to_comm: Dict[str, PyNcclCommunicator] = {}
     while True:
-        task: Union[List, PDTransJoinInfo, PDTransLeaveInfo] = task_in_queue.get()
-        if isinstance(task, List):
-            _handle_kvmove_task(task, task_out_queue, mem_managers, prefill_to_comm, dp_size_in_node)
+        task: Union[KVMoveTaskGroup, PDTransJoinInfo, PDTransLeaveInfo] = task_in_queue.get()
+        if isinstance(task, KVMoveTaskGroup):
+            _handle_kvmove_task(task, task_out_queue, mem_managers, connect_id_to_comm, task.connect_id, dp_size_in_node)
         elif isinstance(task, PDTransJoinInfo):
-            _handle_prefill_join(task, task_out_queue, prefill_to_comm)
+            _handle_prefill_join(task, task_out_queue, connect_id_to_comm)
         elif isinstance(task, PDTransLeaveInfo):
-            prefill_to_comm[task.prefill_id].destroy()
-            logger.info(f"destory {task.prefill_id} nccl communicator.")
+            if task.connect_id in connect_id_to_comm:
+                connect_id_to_comm[task.connect_id].destroy()
+                connect_id_to_comm.pop(task.connect_id, None)
+                logger.info(f"destroy {task} nccl communicator.")
+            else:
+                logger.info(f"no connect_id {task.connect_id} found in connect_id_to_comm")
+
         else:
             logger.warning(f"unexpected task type: {task}")
 
diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py
index 4196c81d1..9bf2a6847 100644
--- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py
+++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py
@@ -11,303 +11,23 @@
 import threading
 import inspect
 import collections
-from dataclasses import dataclass
 from typing import List, Dict, Union
 from lightllm.utils.log_utils import init_logger
 from .prefill_infer_rpyc import PDPrefillInferRpcServer
-from lightllm.common.mem_manager import MemoryManager
 import torch.multiprocessing as mp
-from lightllm.server.pd_io_struct import KVMoveTask, PDTransJoinInfo, PDTransLeaveInfo
-from lightllm.utils.net_utils import find_available_port
+from lightllm.server.pd_io_struct import KVMoveTask
 from lightllm.utils.retry_utils import retry
-from rpyc.utils.classic import obtain
 from rpyc import AsyncResult
 from lightllm.utils.net_utils import get_hostname_ip
 from ..task_queue import TaskQueue
-from lightllm.utils.device_utils import kv_trans_use_p2p
 from lightllm.utils.graceful_utils import graceful_registry
 from lightllm.utils.envs_utils import get_unique_server_name
 
 KV_MOVE_MAX_NUM = 16
-KV_MOVE_MAX_START_CNT = 3
 
 logger = init_logger(__name__)
 
 
-@dataclass
-class TransProcessObj:
-    decode_node_id: int = None
-    rpyc_conn: object = None  # rpyc_con 的连接对象
-    process: mp.Process = None
-    task_in_queue: mp.Queue = None
-    task_out_queue: mp.Queue = None
-    device_index: int = None  # 使用的gpu序号
-    manager: "PrefillKVMoveManager" = None
-    has_error: bool = False
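
Keying every communicator by the per-connection connect_id (a fresh UUID per KVTransConnectObj) instead of the peer node id means a reconnecting peer never collides with a stale entry that is still being torn down. A minimal sketch of that registry and its join/task/leave lifecycle; JoinInfo, LeaveInfo, TaskGroup and FakeComm are illustrative stand-ins, not LightLLM types:

    import queue
    from dataclasses import dataclass, field
    from typing import Dict, List

    @dataclass
    class JoinInfo:
        connect_id: str

    @dataclass
    class LeaveInfo:
        connect_id: str

    @dataclass
    class TaskGroup:
        connect_id: str
        tasks: List[str] = field(default_factory=list)

    class FakeComm:
        """Illustrative stand-in for PyNcclCommunicator."""
        def destroy(self) -> None:
            pass

    def handle_one(task, comms: Dict[str, FakeComm], task_out: queue.Queue) -> None:
        if isinstance(task, TaskGroup):
            _ = comms[task.connect_id]    # keyed by connection, not node id
            # ... kv transfer over the communicator would happen here ...
            task_out.put("ok")
        elif isinstance(task, JoinInfo):
            comms[task.connect_id] = FakeComm()
            task_out.put("nccl_ok")
        elif isinstance(task, LeaveInfo):
            comm = comms.pop(task.connect_id, None)
            if comm is not None:
                comm.destroy()

    comms: Dict[str, FakeComm] = {}
    out: queue.Queue = queue.Queue()
    handle_one(JoinInfo("c1"), comms, out)
    handle_one(TaskGroup("c1", ["req-7"]), comms, out)
    handle_one(LeaveInfo("c1"), comms, out)
    assert "c1" not in comms and [out.get(), out.get()] == ["nccl_ok", "ok"]

-    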
request_kv_trans_task_queue: TaskQueue = None - request_thread: threading.Thread = None - ready_kv_trans_task_queue: TaskQueue = None - kv_trans_thread: threading.Thread = None - latest_check_time: float = None - - def create( - self, decode_node_id: int, decode_node_ip: str, decode_node_rpyc_port: int, manager: "PrefillKVMoveManager" - ): - con = rpyc.connect( - host=decode_node_ip, port=decode_node_rpyc_port, config={"allow_pickle": True}, keepalive=True - ) - - device_index = manager.get_next_device_index() # 分配 trans 进程使用的显卡 - prefill_node_id = manager.args.pd_node_id - task_in_queue = manager.kv_trans_task_in_queues[device_index] - task_out_queue = manager.kv_trans_task_out_queues[device_index] - - with manager.device_locks[device_index]: - task_in_queue.put( - PDTransJoinInfo( - prefill_id=prefill_node_id, - prefill_device_id=device_index, - pd_prefill_nccl_ip=manager.host_ip, - pd_prefill_nccl_port=manager.kv_trans_ports[device_index], - decode_id=decode_node_id, - decode_device_id=-1, - ) - ) - - # 异步调用, 让decode节点建立与prefill节点进行nccl通信的进程 - max_kv_trans_token_num = obtain( - con.root.build_trans_process( - prefill_node_id, - manager.host_ip, - manager.kv_trans_ports[device_index], - manager.args.max_total_token_num, - ) - ) - self.max_kv_trans_token_num = max_kv_trans_token_num - assert task_out_queue.get(timeout=60) == "nccl_ok" - - self.decode_node_id = decode_node_id - self.prefill_node_id = prefill_node_id - self.rpyc_conn = con - self.task_in_queue = task_in_queue - self.task_out_queue = task_out_queue - self.device_index = device_index - self.manager = manager - self.latest_check_time = time.time() - self.process = manager.kv_trans_processes[device_index] - - self.request_kv_trans_task_queue = TaskQueue( - get_func=self._get_request_tasks, fail_func=self.manager.put_to_release_task_queue - ) - self.request_thread = threading.Thread(target=self.request_kv_trans_loop, daemon=True) - self.request_thread.start() - - self.ready_kv_trans_task_queue = TaskQueue(lambda datas: datas[0:1], self.manager.put_to_release_task_queue) - self.kv_trans_thread = threading.Thread(target=self.kv_trans_handle_loop, daemon=True) - self.kv_trans_thread.start() - return - - def _get_request_tasks(self, datas: List[KVMoveTask]): - ans_list = [] - token_num = 0 - for task in datas: - if token_num + len(task.prefill_token_indexes) <= self.max_kv_trans_token_num: - ans_list.append(task) - token_num += len(task.prefill_token_indexes) - else: - break - return ans_list - - def check_connect(self, raise_exception=True): - try: - self.rpyc_conn.root.check_alive() - except BaseException as e: - self.set_has_error() - if raise_exception: - raise e - return - - def check_trans_process(self, raise_exception=True): - process = psutil.Process(self.process.pid) - if not (process.is_running() and process.status() != psutil.STATUS_ZOMBIE): - self.set_has_error() - if raise_exception: - raise Exception(f"trans process: {self.process.pid} is dead") - return - - def timer_check_status(self, raise_exception=True): - if time.time() - self.latest_check_time >= 2.0: - self.latest_check_time = time.time() - self.check_trans_process(raise_exception=raise_exception) - self.check_connect(raise_exception=raise_exception) - if self.has_error: - self.manager.remove_trans_obj(self.decode_node_id) - return - - def request_kv_trans_loop(self): - func_name = self.request_kv_trans_loop.__name__ - - while not self.has_error: - move_tasks: List[KVMoveTask] = self.request_kv_trans_task_queue.get_tasks( - log_tag="request_kv_trans_task_queue" 
- ) - if len(move_tasks) == 0: - # 周期检查通信状态 - self.timer_check_status(raise_exception=False) - time.sleep(0.01) - continue - try: - self.timer_check_status(raise_exception=True) - for move_task in move_tasks: - logger.info( - f"{func_name} get task {move_task.to_prefill_log_info()} " - f"queue time {move_task.get_cost_time()} s " - ) - - trans_move_tasks = [copy.copy(move_task) for move_task in move_tasks] - for trans_move_task in trans_move_tasks: - trans_move_task.prefill_token_indexes = None - - mark_start = time.time() - move_kv_lens = self.rpyc_conn.root.request_data_transfer(trans_move_tasks) - move_kv_lens = obtain(move_kv_lens) - request_data_transfer_cost_time = time.time() - mark_start - - logger.info( - f"{func_name} request_data_transfer ok, {move_tasks[0].to_prefill_log_info()}" - f" cost time: {request_data_transfer_cost_time} s" - ) - - ok_trans_list = [] - for i, move_task in enumerate(move_tasks.copy()): - if move_kv_lens[i] is not None: - move_task.move_kv_len = move_kv_lens[i] - ok_trans_list.append(move_task) - move_tasks.remove(move_task) - else: - logger.info(f"prefill node kv move task req_id: {move_task.id()} not send, decode is busy") - - if len(ok_trans_list) != 0: - self.ready_kv_trans_task_queue.put(ok_trans_list) - - except BaseException as e: - logger.exception(str(e)) - self.set_has_error() - self.manager.remove_trans_obj(self.decode_node_id) - self.request_kv_trans_task_queue.clear_tasks() - - finally: - self.manager.put_to_release_task_queue(move_tasks) - - logger.error(f"{func_name}, decode id {self.decode_node_id} device_index {self.device_index} thread quit") - return - - def _transfer_kv(self, move_tasks: List[KVMoveTask]): - with self.manager.device_locks[self.device_index]: - self.task_in_queue.put(move_tasks.copy(), timeout=10) - assert self.task_out_queue.get(timeout=60) == "ok" - self.manager.put_to_release_task_queue(move_tasks) - - logger.info( - f"_transfer_kv data ok, req_id: {move_tasks[0].id()}" - f" cost total time: {move_tasks[0].get_cost_time()} s" - ) - move_tasks.clear() - - def kv_trans_handle_loop(self): - func_name = self.kv_trans_handle_loop.__name__ - while not self.has_error: - move_tasks: List[List[KVMoveTask]] = self.ready_kv_trans_task_queue.get_tasks( - log_tag="ready_kv_trans_task_queue" - ) - if len(move_tasks) == 0: - self.timer_check_status(raise_exception=False) - time.sleep(0.01) - continue - - if len(move_tasks) != 1: - logger.error(f"error get kv trans move_tasks, must be 1, get {len(move_tasks)}") - assert len(move_tasks) == 1 - - move_tasks = move_tasks[0] - - try: - self.timer_check_status(raise_exception=True) - for move_task in move_tasks: - logger.info( - f"{func_name} get task {move_task.to_prefill_log_info()} to start kv move" - f"queue time {move_task.get_cost_time()} s " - ) - - if not kv_trans_use_p2p(): - with self.manager.kv_trans_lock: - self._transfer_kv(move_tasks) - else: - self._transfer_kv(move_tasks) - except BaseException as e: - logger.exception(str(e)) - self.set_has_error() - self.manager.remove_trans_obj(self.decode_node_id) - self.ready_kv_trans_task_queue.clear_tasks() - finally: - self.manager.put_to_release_task_queue(move_tasks) - - logger.error(f"trans kv thread, decode id {self.decode_node_id} device_index {self.device_index} thread quit") - return - - def wait_thread_quit(self): - if self.request_thread is not None: - if self.request_thread.is_alive(): - try: - self.request_thread.join() - except: - pass - if self.kv_trans_thread is not None: - if self.kv_trans_thread.is_alive(): - try: 
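
The inline is_alive/try/join cleanup being deleted here is exactly what the new pd_mode/utils.py helper factors out; the helper as added by this patch, with a short usage example:

    import threading
    import time

    def join_if_alive(thread: threading.Thread):
        if thread is not None and thread.is_alive():
            try:
                thread.join()
            except Exception:
                pass

    worker = threading.Thread(target=lambda: time.sleep(0.1), daemon=True)
    worker.start()
    join_if_alive(worker)   # joins the live thread
    join_if_alive(worker)   # already finished: is_alive() is False, no-op
    join_if_alive(None)     # None is tolerated, so callers need no checks
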
- self.kv_trans_thread.join() - except: - pass - return - - def has_error_status(self): - try: - assert self.has_error is False - assert self.request_thread.is_alive() - assert self.kv_trans_thread.is_alive() - except BaseException as e: - logger.exception(str(e)) - self.set_has_error() - return True - - return False - - def set_has_error(self): - self.has_error = True - try: - self.request_kv_trans_task_queue.has_error = True - self.ready_kv_trans_task_queue.has_error = True - except: - pass - return - - def __del__(self): - logger.error(f"trans obj del start, decode node id {self.decode_node_id} device_index {self.device_index}") - - try: - self.set_has_error() - self.wait_thread_quit() - self.task_in_queue.put(PDTransLeaveInfo(decode_id=self.decode_node_id, prefill_id=self.prefill_node_id)) - if self.request_kv_trans_task_queue is not None: - self.request_kv_trans_task_queue.clear_tasks() - if self.ready_kv_trans_task_queue is not None: - self.ready_kv_trans_task_queue.clear_tasks() - except BaseException as e: - logger.exception(str(e)) - - logger.error(f"trans obj deled, decode node id {self.decode_node_id} device_index {self.device_index}") - - class PrefillKVMoveManager: def __init__(self, args, info_queue: mp.Queue, mem_queues: List[mp.Queue]): self.args = args @@ -321,7 +41,11 @@ def __init__(self, args, info_queue: mp.Queue, mem_queues: List[mp.Queue]): self.info_queue = info_queue self.mem_queues = mem_queues self.infer_rpyc_objs: List[PDPrefillInferRpcServer] = [] - self.node_id_to_trans_obj: Dict[str, TransProcessObj] = {} + + from .prefill_trans_obj import KVTransConnectObj + + self.connect_id_to_trans_obj: Dict[str, KVTransConnectObj] = {} + for port in self.args.pd_node_infer_rpyc_ports: socket_path = f"/tmp/{get_unique_server_name()}_prefill_node_infer_rpyc_{port}" from rpyc.utils.factory import unix_connect @@ -336,28 +60,45 @@ def __init__(self, args, info_queue: mp.Queue, mem_queues: List[mp.Queue]): self.infer_rpyc_lock = threading.Lock() self.kv_trans_lock = threading.Lock() - # 需要每个卡有一个锁来规划每次只能有一个tran obj 操作对应显卡上的传输任务。 - self.device_locks = [threading.Lock() for _ in range(self.node_world_size)] - # 释放token的task队列 self.release_task_queue = TaskQueue(lambda datas: datas[0:KV_MOVE_MAX_NUM], fail_func=None) self.release_tasks_thread = threading.Thread(target=self.handle_release_task_loop, daemon=True) self.release_tasks_thread.start() - # start a single kv trans process - - from .prefill_trans_process import start_prefill_trans_process - - self.kv_trans_ports = [None] * self.node_world_size - self.kv_trans_processes = [None] * self.node_world_size - self.kv_trans_task_in_queues = [None] * self.node_world_size - self.kv_trans_task_out_queues = [None] * self.node_world_size - self.kv_trans_process_start_cnt = [0] * self.node_world_size - + from .prefill_trans_obj import KVTransProcess + + self.kv_trans_processes: List[KVTransProcess] = [None] * self.node_world_size for device_id in range(self.node_world_size): - assert self.start_trans_process(device_id) + self.kv_trans_processes[device_id] = KVTransProcess() + assert self.kv_trans_processes[device_id].init_all(device_id, self) return + + # ================================================================================== + # 主任务循环,接收需要进行kv传输的请求进行处理 + # ================================================================================== + + def task_dispatcher_loop(self): + try: + # 获取任务,并分发给相关卡的处理队列 + while True: + move_task: KVMoveTask = self.info_queue.get() + try: + trans_obj = self.__get_trans_obj(move_task) + 
trans_obj.request_kv_trans_task_queue.put(move_task) + except BaseException as e: + logger.exception(str(e)) + self.put_to_release_task_queue(move_task) + finally: + trans_obj = None + + except (BaseException, RuntimeError) as e: + logger.exception(str(e)) + raise e + + # ================================================================================== + # 请求出错或者完成kv传输后的处理队列和线程loop + # ================================================================================== def put_to_release_task_queue(self, task: Union[KVMoveTask, List[KVMoveTask]]): if isinstance(task, KVMoveTask): @@ -376,144 +117,34 @@ def handle_release_task_loop(self): else: self._remove_req_refs_from_prompt_cache(handle_list) return - - def start_trans_process(self, device_id: int): - task_in_queue = mp.Queue() - task_out_queue = mp.Queue() - kv_trans_port = find_available_port(self.args.pd_p_allowed_port_min, self.args.pd_p_allowed_port_max) - self.kv_trans_process_start_cnt[device_id] += 1 - - if self.kv_trans_processes[device_id]: - # force kill - try: - self.remove_trans_obj_by_deviceid(device_id) - process = psutil.Process(self.kv_trans_processes[device_id].pid) - process.kill() - self.kv_trans_processes[device_id] = None - except Exception: - pass - - try: - from .prefill_trans_process import start_prefill_trans_process - - kv_trans_process = start_prefill_trans_process( - self.args, - self.host_ip, - kv_trans_port, - device_id, - task_in_queue, - task_out_queue, - self.mem_queues, - ) - assert task_out_queue.get(timeout=30) == "proc_start" - self._put_mem_manager_to_mem_queue() - assert task_out_queue.get(timeout=60) == "get_mem_managers_ok" - - self.kv_trans_processes[device_id] = kv_trans_process - self.kv_trans_task_in_queues[device_id] = task_in_queue - self.kv_trans_task_out_queues[device_id] = task_out_queue - self.kv_trans_ports[device_id] = kv_trans_port - - return True - except Exception as e: - logger.warning(f"Failed start kv trans process for device {device_id}: {e}") - return False - - def check_trans_process(self, raise_exception=True): - at_least_one_alive = False - for device_id in range(self.node_world_size): - if not self.is_kv_trans_process_alive(device_id): - continue - - process = psutil.Process(self.kv_trans_processes[device_id].pid) - if not (process.is_running() and process.status() != psutil.STATUS_ZOMBIE): - logger.error(f"kv trans process for device: {device_id} dead!!!, try start again...") - self.start_trans_process(device_id) - else: - at_least_one_alive = True - - if not at_least_one_alive: - if raise_exception: - raise Exception("All trans process are dead!!!") - - return + + # ================================================================================== + # 定时检测传输进程的健康状态,出现问题拉崩整个系统触发重启 + # ================================================================================== def check_trans_process_loop(self): try: while True: - self.check_trans_process() + for device_id in range(self.node_world_size): + if not self.kv_trans_processes[device_id].is_trans_process_health(): + raise Exception(f"device_id {device_id} kv process is unhealth") + time.sleep(10.0) except (BaseException, RuntimeError) as e: logger.exception(str(e)) - # kill parent process if any exception occurred - os.kill(os.getppid(), signal.SIGTERM) - raise e - - def is_kv_trans_process_alive(self, device_id): - return self.kv_trans_process_start_cnt[device_id] <= KV_MOVE_MAX_START_CNT - - def get_next_device_index(self): - - counts = [ - 0 if self.is_kv_trans_process_alive(device_id) else (1 << 20) for device_id in 
range(self.node_world_size) - ] - for obj in self.node_id_to_trans_obj.values(): - counts[obj.device_index] += 1 - device_index = int(np.argmin(counts)) - return device_index - - def get_trans_obj(self, task: KVMoveTask): - self.remove_dead_trans_obj() - if task.decode_node.node_id not in self.node_id_to_trans_obj: - gc.collect() - trans_obj = TransProcessObj() - trans_obj.create(task.decode_node.node_id, task.decode_node.ip, task.decode_node.rpyc_port, self) - self.node_id_to_trans_obj[task.decode_node.node_id] = trans_obj - return self.node_id_to_trans_obj[task.decode_node.node_id] - - def remove_trans_obj(self, decode_node_id): - if decode_node_id in self.node_id_to_trans_obj: - trans_obj = self.node_id_to_trans_obj.pop(decode_node_id, None) - if trans_obj is not None: - trans_obj.set_has_error() - logger.error(f"remove tran obj id {trans_obj.decode_node_id}") - return - - def remove_dead_trans_obj(self): - del_node_ids = [] - for node_id, t_obj in self.node_id_to_trans_obj.items(): - if t_obj.has_error_status(): - del_node_ids.append(node_id) - - for node_id in del_node_ids: - self.node_id_to_trans_obj.pop(node_id, None) - - if len(del_node_ids) != 0: - gc.collect() - return - - def remove_trans_obj_by_deviceid(self, device_id): - for node_id, t_obj in self.node_id_to_trans_obj.items(): - if t_obj.device_index == device_id: - self.remove_dead_trans_obj(node_id) - - def task_dispatcher_loop(self): - try: - # 获取任务,并分发给相关卡的处理队列 - while True: - move_task: KVMoveTask = self.info_queue.get() - try: - trans_obj = self.get_trans_obj(move_task) - trans_obj.request_kv_trans_task_queue.put(move_task) - except BaseException as e: - logger.exception(str(e)) - self.put_to_release_task_queue(move_task) - finally: - trans_obj = None + + for device_id in range(self.node_world_size): + self.kv_trans_processes[device_id].killself() - except (BaseException, RuntimeError) as e: - logger.exception(str(e)) + # 杀掉当前进程的父进程(router), 触发全局崩溃 + os.kill(os.getppid(), signal.SIGKILL) + os.kill(os.getpid(), signal.SIGKILL) raise e + + # ================================================================================== + # 与推理进程交互接口, _remove_req_refs_from_prompt_cache 和 + # _put_mem_manager_to_mem_queue 都是通过 rpyc 与推理进程进行交互的接口 + # ================================================================================== def _remove_req_refs_from_prompt_cache(self, tasks: List[KVMoveTask]): with self.infer_rpyc_lock: @@ -541,7 +172,54 @@ def _put_mem_manager_to_mem_queue(self): async def wait_all_future_finish(self, futures: List[AsyncResult]): await asyncio.gather(*[asyncio.to_thread(future.wait) for future in futures]) return + + # ================================================================================== + # 辅助功能接口 + # ================================================================================== + + def get_next_device_index(self): + counts = [0 for _ in range(self.node_world_size)] + for obj in self.connect_id_to_trans_obj.values(): + counts[obj.device_index] += 1 + device_index = int(np.argmin(counts)) + return device_index + def remove_trans_obj(self, connect_id): + if connect_id in self.connect_id_to_trans_obj: + trans_obj = self.connect_id_to_trans_obj.pop(connect_id, None) + if trans_obj is not None: + trans_obj.set_has_error() + logger.error(f"remove tran obj id {trans_obj.decode_node_id}") + return + + def __get_trans_obj(self, task: KVMoveTask): + self.__remove_dead_trans_obj() + # 如果已经存在连接对象,直接返回 + for obj in self.connect_id_to_trans_obj.values(): + if obj.decode_node_id == task.decode_node.node_id: + 
return obj + + # 如果不存在连接对象,创建新的连接对象 + gc.collect() + from .prefill_trans_obj import KVTransConnectObj + + trans_obj = KVTransConnectObj() + trans_obj.create(task.decode_node.node_id, task.decode_node.ip, task.decode_node.rpyc_port, self) + self.connect_id_to_trans_obj[trans_obj.connect_id] = trans_obj + return trans_obj + + def __remove_dead_trans_obj(self): + del_connect_ids = [] + for connect_id, t_obj in self.connect_id_to_trans_obj.items(): + if t_obj.has_error_status(): + del_connect_ids.append(connect_id) + + for connect_id in del_connect_ids: + self.connect_id_to_trans_obj.pop(connect_id, None) + + if del_connect_ids: + gc.collect() + return def _init_env(args, info_queue: mp.Queue, mem_queues: List[mp.Queue], event: mp.Event): import lightllm.utils.rpyc_fix_utils as _ diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_obj.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_obj.py new file mode 100644 index 000000000..4ab56f242 --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_obj.py @@ -0,0 +1,372 @@ +import time +import rpyc +import copy +import uuid +import numpy as np +import psutil +import threading +from dataclasses import dataclass +from typing import List, Dict, Union +from lightllm.utils.log_utils import init_logger +import torch.multiprocessing as mp +from lightllm.server.pd_io_struct import KVMoveTask, PDTransJoinInfo, PDTransLeaveInfo, KVMoveTaskGroup +from rpyc.utils.classic import obtain +from ..task_queue import TaskQueue +from lightllm.utils.device_utils import kv_trans_use_p2p +from lightllm.utils.time_utils import TimeChecker +from .prefill_kv_move_manager import PrefillKVMoveManager +from lightllm.utils.net_utils import find_available_port +from ..utils import join_if_alive + +logger = init_logger(__name__) + + +@dataclass +class KVTransConnectObj: + connect_id: str = None + decode_node_id: int = None + rpyc_conn: object = None # rpyc_con 的连接对象 + kv_trans_process: 'KVTransProcess' = None + device_index: int = None # 使用的gpu序号 + manager: "PrefillKVMoveManager" = None + has_error: bool = False + request_kv_trans_task_queue: TaskQueue = None + request_thread: threading.Thread = None + ready_kv_trans_task_queue: TaskQueue = None + kv_trans_thread: threading.Thread = None + timer_checker: TimeChecker = None + + # ================================================================================== + # 构建传输通信对象 + # ================================================================================== + + def create( + self, decode_node_id: int, decode_node_ip: str, decode_node_rpyc_port: int, manager: "PrefillKVMoveManager" + ): + device_index = manager.get_next_device_index() # 分配使用的显卡index + self.kv_trans_process = manager.kv_trans_processes[device_index] + prefill_node_id = manager.args.pd_node_id + self.connect_id = str(uuid.uuid4()) + + con = rpyc.connect( + host=decode_node_ip, port=decode_node_rpyc_port, config={"allow_pickle": True}, keepalive=True + ) + + # 创建 nccl 连接 + with self.kv_trans_process.device_lock: + self.kv_trans_process.task_in_queue.put( + PDTransJoinInfo( + prefill_id=prefill_node_id, + prefill_device_id=device_index, + pd_prefill_nccl_ip=manager.host_ip, + pd_prefill_nccl_port=self.kv_trans_process.kv_trans_port, + decode_id=decode_node_id, + decode_device_id=-1, + connect_id=self.connect_id + ) + ) + + # 异步调用, 让decode节点建立与prefill节点进行nccl通信的进程 + 
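
get_next_device_index, used above to place this connection, is a least-loaded pick: count the live connection objects per GPU and take the argmin, matching the manager code shown earlier. A compact sketch of the policy (pick_device and busy_device_indexes are illustrative names):

    import numpy as np

    def pick_device(node_world_size: int, busy_device_indexes: list) -> int:
        counts = [0] * node_world_size
        for device_index in busy_device_indexes:
            counts[device_index] += 1
        return int(np.argmin(counts))

    # three connections pinned to devices 0, 0 and 1: device 2 is chosen next
    assert pick_device(4, [0, 0, 1]) == 2
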
max_kv_trans_token_num = obtain( + con.root.build_trans_connect( + prefill_node_id, + manager.host_ip, + self.kv_trans_process.kv_trans_port, + manager.args.max_total_token_num, + self.connect_id, + ) + ) + self.max_kv_trans_token_num = max_kv_trans_token_num + assert self.kv_trans_process.task_out_queue.get(timeout=60) == "nccl_ok" + + self.decode_node_id = decode_node_id + self.prefill_node_id = prefill_node_id + self.rpyc_conn = con + self.device_index = device_index + self.manager = manager + self.timer_checker = TimeChecker(3) + + self.request_kv_trans_task_queue = TaskQueue( + get_func=self._get_request_tasks, fail_func=self.manager.put_to_release_task_queue + ) + self.request_thread = threading.Thread(target=self.request_kv_trans_loop, daemon=True) + self.request_thread.start() + + self.ready_kv_trans_task_queue = TaskQueue(lambda datas: datas[0:1], self.manager.put_to_release_task_queue) + self.kv_trans_thread = threading.Thread(target=self.kv_trans_handle_loop, daemon=True) + self.kv_trans_thread.start() + + logger.info(f"create KVTransConnectObj success: connect_id : {self.connect_id} prefill_id: {prefill_node_id}" + f" decode_id: {decode_node_id} device_index: {device_index} ") + return + + def _get_request_tasks(self, datas: List[KVMoveTask]): + """ + 根据可以p和d节点间协商得到的 max_kv_trans_token_num 限制,将排队等待 + 传输的请求打包成一个可以传输的list组。 + """ + ans_list = [] + token_num = 0 + for task in datas: + if token_num + len(task.prefill_token_indexes) <= self.max_kv_trans_token_num: + ans_list.append(task) + token_num += len(task.prefill_token_indexes) + else: + break + return ans_list + + # ================================================================================== + # 与 decode 节点进行元数据交互,申请锁定资源准备进行kv的传输 + # ================================================================================== + def request_kv_trans_loop(self): + func_name = self.request_kv_trans_loop.__name__ + + while not self.has_error: + move_tasks: List[KVMoveTask] = self.request_kv_trans_task_queue.get_tasks( + log_tag="request_kv_trans_task_queue" + ) + if len(move_tasks) == 0: + self.timer_check_status(raise_exception=False) + time.sleep(0.01) + continue + try: + self.timer_check_status(raise_exception=True) + for move_task in move_tasks: + move_task.connect_id = self.connect_id + logger.info( + f"{func_name} get task {move_task.to_prefill_log_info()} " + f"queue time {move_task.get_cost_time()} s " + ) + + trans_move_tasks = [copy.copy(move_task) for move_task in move_tasks] + for trans_move_task in trans_move_tasks: + trans_move_task.prefill_token_indexes = None + + mark_start = time.time() + move_kv_lens = self.rpyc_conn.root.request_data_transfer(trans_move_tasks) + move_kv_lens = obtain(move_kv_lens) + request_data_transfer_cost_time = time.time() - mark_start + + logger.info( + f"{func_name} request_data_transfer ok, {move_tasks[0].to_prefill_log_info()}" + f" cost time: {request_data_transfer_cost_time} s" + ) + + ok_trans_list = [] + for i, move_task in enumerate(move_tasks.copy()): + if move_kv_lens[i] is not None: + move_task.move_kv_len = move_kv_lens[i] + ok_trans_list.append(move_task) + move_tasks.remove(move_task) + else: + logger.info(f"prefill node kv move task req_id: {move_task.id()} not send, decode is busy") + + if ok_trans_list: + self.ready_kv_trans_task_queue.put(ok_trans_list, error_handle_func=self.manager.put_to_release_task_queue) + + except BaseException as e: + logger.exception(str(e)) + self.set_has_error() + self.request_kv_trans_task_queue.clear_tasks() + + finally: + # 将没有申请成功的请求放入到释放队列中 + 
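
_get_request_tasks applies a greedy FIFO packing rule against the max_kv_trans_token_num budget negotiated above: accept queued tasks in order while the running token total still fits, and stop at the first task that would overflow rather than skipping ahead. A self-contained sketch of the rule:

    from dataclasses import dataclass
    from typing import List

    @dataclass
    class Task:
        req_id: int
        token_num: int

    def pack_tasks(datas: List[Task], max_tokens: int) -> List[Task]:
        ans_list: List[Task] = []
        token_num = 0
        for task in datas:
            if token_num + task.token_num <= max_tokens:
                ans_list.append(task)
                token_num += task.token_num
            else:
                break  # preserve FIFO order: no skipping to smaller tasks
        return ans_list

    tasks = [Task(1, 600), Task(2, 300), Task(3, 200)]
    assert [t.req_id for t in pack_tasks(tasks, 1000)] == [1, 2]
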
self.manager.put_to_release_task_queue(move_tasks) + + logger.error(f"{func_name}, {self.to_log_info()} thread quit") + return + + # ================================================================================== + # 将准备好 kv 传输的请求进行 kv 传输 + # ================================================================================== + def _transfer_kv(self, move_tasks: List[KVMoveTask]): + with self.kv_trans_process.device_lock: + kv_move_group = KVMoveTaskGroup(tasks=move_tasks.copy(), connect_id=self.connect_id) + self.kv_trans_process.task_in_queue.put(kv_move_group, timeout=10) + assert self.kv_trans_process.task_out_queue.get(timeout=60) == "ok" + self.manager.put_to_release_task_queue(move_tasks) + + logger.info( + f"_transfer_kv data ok, req_id: {move_tasks[0].id()}" + f" cost total time: {move_tasks[0].get_cost_time()} s" + ) + move_tasks.clear() + + def kv_trans_handle_loop(self): + func_name = self.kv_trans_handle_loop.__name__ + while not self.has_error: + move_tasks: List[List[KVMoveTask]] = self.ready_kv_trans_task_queue.get_tasks( + log_tag="ready_kv_trans_task_queue" + ) + if len(move_tasks) == 0: + self.timer_check_status(raise_exception=False) + time.sleep(0.01) + continue + + if len(move_tasks) != 1: + logger.error(f"error get kv trans move_tasks, must be 1, get {len(move_tasks)}") + assert len(move_tasks) == 1 + + move_tasks: List[KVMoveTask] = move_tasks[0] + + try: + self.timer_check_status(raise_exception=True) + for move_task in move_tasks: + logger.info( + f"{func_name} get task {move_task.to_prefill_log_info()} to start kv move" + f"queue time {move_task.get_cost_time()} s " + ) + + if not kv_trans_use_p2p(): + with self.manager.kv_trans_lock: + self._transfer_kv(move_tasks) + else: + self._transfer_kv(move_tasks) + except BaseException as e: + logger.exception(str(e)) + self.set_has_error() + self.ready_kv_trans_task_queue.clear_tasks() + finally: + self.manager.put_to_release_task_queue(move_tasks) + + logger.error(f"trans kv thread, {self.to_log_info()} thread quit") + return + + # ================================================================================== + # 错误处理检测操作的一些通用函数 + # ================================================================================== + + def has_error_status(self): + try: + assert self.has_error is False + assert self.request_thread.is_alive() + assert self.kv_trans_thread.is_alive() + except BaseException as e: + logger.exception(str(e)) + self.set_has_error() + return True + + return False + + def timer_check_status(self, raise_exception=True): + if self.timer_checker.has_exceeded(): + try: + self.rpyc_conn.root.check_alive() + assert self.kv_trans_process.is_trans_process_health() + except BaseException as e: + logger.error(f"pid {self.kv_trans_process.process.pid} check failed") + logger.exception(str(e)) + + self.set_has_error() + if raise_exception: + raise e + + return + + def set_has_error(self): + """ + 将当前传输对象标记为有错误,这样可以防止请求放入到处理队列中 + """ + self.has_error = True + + if self.request_kv_trans_task_queue is not None: + self.request_kv_trans_task_queue.has_error = True + + if self.ready_kv_trans_task_queue is not None: + self.ready_kv_trans_task_queue.has_error = True + + if self.manager is not None: + self.manager.remove_trans_obj(self.connect_id) + return + + def __del__(self): + """ + 函数中有很多判断是否是None的操作,主要是为了避免一些异常流程的del行为不报错。 + """ + logger.error(f"trans obj del start, info: {self.to_log_info()}") + + try: + self.set_has_error() + + join_if_alive(self.request_thread) + join_if_alive(self.kv_trans_thread) + + # 
将未处理的请求,清理掉,clear_tasks 会将没处理完的请求 + # 放入到 manager 资源释放队列中 + if self.request_kv_trans_task_queue is not None: + self.request_kv_trans_task_queue.clear_tasks() + if self.ready_kv_trans_task_queue is not None: + self.ready_kv_trans_task_queue.clear_tasks() + + # 传输进程清理掉 nccl 连接 + if self.connect_id is not None: + self.kv_trans_process.task_in_queue.put(PDTransLeaveInfo(decode_id=self.decode_node_id, + prefill_id=self.prefill_node_id, + connect_id=self.connect_id)) + + except BaseException as e: + logger.exception(str(e)) + + logger.error(f"trans obj deled, info: {self.to_log_info()}") + + + def to_log_info(self): + log = f"connect_id: {self.connect_id} " + log += f"decode_node_id: {self.decode_node_id} " + log += f"prefill_node_id: {self.prefill_node_id} " + log += f"device_index: {self.device_index} " + return log + + +@dataclass +class KVTransProcess: + process: mp.Process = None + # 需要每个卡有一个锁来规划每次只能有一个 connection obj 操作对应显卡上的传输任务。 + device_lock: threading.Lock = None + task_in_queue: mp.Queue = None + task_out_queue: mp.Queue = None + device_id: int = None + kv_trans_port: int = None + + def init_all(self, device_id: int, manager: "PrefillKVMoveManager"): + self.device_id = device_id + self.device_lock = threading.Lock() + self.task_in_queue = mp.Queue() + self.task_out_queue = mp.Queue() + self.kv_trans_port = find_available_port(manager.args.pd_p_allowed_port_min, manager.args.pd_p_allowed_port_max) + + try: + from .prefill_trans_process import start_prefill_trans_process + + self.process = start_prefill_trans_process( + manager.args, + manager.host_ip, + self.kv_trans_port, + device_id, + self.task_in_queue, + self.task_out_queue, + manager.mem_queues, + ) + assert self.task_out_queue.get(timeout=30) == "proc_start" + manager._put_mem_manager_to_mem_queue() + assert self.task_out_queue.get(timeout=60) == "get_mem_managers_ok" + + return True + except Exception as e: + logger.warning(f"Failed start kv trans process for device {device_id}: {e}") + logger.exception(str(e)) + return False + + def is_trans_process_health(self): + try: + process = psutil.Process(self.process.pid) + if not (process.is_running() and process.status() != psutil.STATUS_ZOMBIE): + logger.error(f"kv trans process for device: {self.device_id} dead!!!") + return False + else: + return True + except: + return False + + def killself(self): + self.process.kill() diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py index 62327a11c..c9cd9c664 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py @@ -7,7 +7,7 @@ from typing import List, Dict, Union from lightllm.utils.log_utils import init_logger from lightllm.common.mem_manager import MemoryManager -from lightllm.server.pd_io_struct import KVMoveTask, PDTransJoinInfo, PDTransLeaveInfo +from lightllm.server.pd_io_struct import KVMoveTask, PDTransJoinInfo, PDTransLeaveInfo, KVMoveTaskGroup from lightllm.utils.device_utils import kv_trans_use_p2p from lightllm.utils.graceful_utils import graceful_registry from lightllm.distributed.pynccl import StatelessP2PProcessGroup, PyNcclCommunicator @@ -20,21 +20,21 @@ def _handle_kvmove_task( move_tasks: List[KVMoveTask], task_out_queue: mp.Queue, 
mem_managers: List[MemoryManager],
-    decode_to_comm: Dict[int, PyNcclCommunicator],
+    connect_id_to_comm: Dict[str, PyNcclCommunicator],
+    connect_id: str,
     dp_size_in_node: int,
 ):
     total_move_kv_len = sum([task.move_kv_len for task in move_tasks])
     try:
-        decode_id = move_tasks[0].decode_node.node_id
-        device_index = decode_to_comm[decode_id].device.index
+        device_index = connect_id_to_comm[connect_id].device.index
         start = time.time()
         if total_move_kv_len != 0:
             logger.info(f"trans start: {move_tasks[0].to_prefill_log_info()}")
             cur_mem = mem_managers[device_index]
             if kv_trans_use_p2p():
-                cur_mem.send_to_decode_node_p2p(move_tasks, mem_managers, dp_size_in_node, decode_to_comm[decode_id])
+                cur_mem.send_to_decode_node_p2p(move_tasks, mem_managers, dp_size_in_node, connect_id_to_comm[connect_id])
             else:
-                cur_mem.send_to_decode_node(move_tasks, mem_managers, dp_size_in_node, decode_to_comm[decode_id])
+                cur_mem.send_to_decode_node(move_tasks, mem_managers, dp_size_in_node, connect_id_to_comm[connect_id])
             logger.info(f"trans finished: {move_tasks[0].to_prefill_log_info()} move len: {total_move_kv_len}")
         torch.cuda.synchronize()
         logger.info(
@@ -48,16 +48,17 @@
 
 
 def _handle_decode_join(
-    node_info: PDTransJoinInfo, task_out_queue: mp.Queue, decode_to_comm: Dict[str, PyNcclCommunicator], store: TCPStore
+    node_info: PDTransJoinInfo, task_out_queue: mp.Queue, connect_id_to_comm: Dict[str, PyNcclCommunicator], store: TCPStore
 ):
     try:
         group = StatelessP2PProcessGroup.create(node_info.prefill_id, node_info.decode_id, True, store)
         comm = PyNcclCommunicator(group, node_info.prefill_device_id)
-        decode_to_comm[node_info.decode_id] = comm
+        connect_id_to_comm[node_info.connect_id] = comm
         logger.info(f"{node_info} kv trans connected!")
         task_out_queue.put("nccl_ok")
     except Exception as e:
-        logger.warning(f"error while connect to decode node: {e}")
+        task_out_queue.put("nccl_fail")
+        logger.warning(f"error while connect to decode node: {e} node_info {node_info}")
 
 
 def _init_env(
@@ -77,17 +78,21 @@ def _init_env(
     task_out_queue.put("proc_start")
     mem_managers: List[MemoryManager] = [mem_queue.get(timeout=60) for mem_queue in mem_queues]
     task_out_queue.put("get_mem_managers_ok")
-    decode_to_comm: Dict[int, PyNcclCommunicator] = {}
+    connect_id_to_comm: Dict[str, PyNcclCommunicator] = {}
     while True:
-        task: Union[List, PDTransJoinInfo, PDTransLeaveInfo] = task_in_queue.get()
-        if isinstance(task, List):
-            _handle_kvmove_task(task, task_out_queue, mem_managers, decode_to_comm, dp_size_in_node)
+        task: Union[KVMoveTaskGroup, PDTransJoinInfo, PDTransLeaveInfo] = task_in_queue.get()
+        if isinstance(task, KVMoveTaskGroup):
+            _handle_kvmove_task(task.tasks, task_out_queue, mem_managers, connect_id_to_comm, task.connect_id, dp_size_in_node)
         elif isinstance(task, PDTransJoinInfo):
-            _handle_decode_join(task, task_out_queue, decode_to_comm, master_store)
+            _handle_decode_join(task, task_out_queue, connect_id_to_comm, master_store)
         elif isinstance(task, PDTransLeaveInfo):
-            decode_to_comm[task.decode_id].destroy()
-            logger.info(f"destory {task.decode_id} nccl communicator.")
+            if task.connect_id in connect_id_to_comm:
+                connect_id_to_comm[task.connect_id].destroy()
+                connect_id_to_comm.pop(task.connect_id, None)
+                logger.info(f"destroy {task} nccl communicator.")
+            else:
+                logger.error(f"connect id {task.connect_id} does not exist in connect_id_to_comm")
         else:
             logger.warning(f"unexpected task type: {task}")
 
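
Both join handlers now acknowledge failure with "nccl_fail" instead of staying silent, so a manager blocked on task_out_queue.get(timeout=60) fails fast instead of waiting out the whole timeout. The ack protocol in miniature (handle_join and broken_connect are illustrative stand-ins):

    import queue

    def handle_join(task_out: queue.Queue, connect) -> None:
        try:
            connect()                     # NCCL group setup in the real code
            task_out.put("nccl_ok")
        except Exception:
            task_out.put("nccl_fail")     # unblock the waiting manager

    def broken_connect() -> None:
        raise RuntimeError("simulated NCCL failure")

    out: queue.Queue = queue.Queue()
    handle_join(out, broken_connect)
    assert out.get(timeout=1) == "nccl_fail"

diff --git 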
a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/task_queue.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/task_queue.py index 9dd4b3c5f..7b856e54a 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/task_queue.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/task_queue.py @@ -15,8 +15,10 @@ def __init__(self, get_func, fail_func): def size(self): return len(self.datas) - def put(self, obj): + def put(self, obj, error_handle_func=None): if self.has_error: + if error_handle_func is not None: + error_handle_func(obj) raise Exception("has error") with self.lock: diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/utils.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/utils.py new file mode 100644 index 000000000..241cf93e8 --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/utils.py @@ -0,0 +1,9 @@ +import threading + +def join_if_alive(thread:threading.Thread): + if thread is not None and thread.is_alive(): + try: + thread.join() + except Exception: + pass + return \ No newline at end of file diff --git a/lightllm/utils/process_check.py b/lightllm/utils/process_check.py index 75a8f890d..00cc258bf 100644 --- a/lightllm/utils/process_check.py +++ b/lightllm/utils/process_check.py @@ -42,5 +42,5 @@ def start_parent_check_thread(): """ 检测父进程是否健康,如果出现问题,清理退出所有进程 """ - thread = threading.Thread(target=check_parent_alive) + thread = threading.Thread(target=check_parent_alive, daemon=True) thread.start() diff --git a/lightllm/utils/time_utils.py b/lightllm/utils/time_utils.py new file mode 100644 index 000000000..6b313c093 --- /dev/null +++ b/lightllm/utils/time_utils.py @@ -0,0 +1,16 @@ +import time + +class TimeChecker: + def __init__(self, threshold): + self.threshold = threshold + self.last_checked = time.time() + + def has_exceeded(self): + current_time = time.time() + if (current_time - self.last_checked) > self.threshold: + self._reset() + return True + return False + + def _reset(self): + self.last_checked = time.time() \ No newline at end of file From a3831ed0f81c1356eda8753dc8258a7642d26126 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Thu, 10 Apr 2025 09:52:59 +0800 Subject: [PATCH 08/20] format --- .../decode_kv_move_manager.py | 33 +++++++------ .../decode_node_impl/decode_trans_obj.py | 46 +++++++++++------- .../decode_node_impl/decode_trans_process.py | 4 +- .../prefill_kv_move_manager.py | 25 +++++----- .../prefill_node_impl/prefill_trans_obj.py | 47 ++++++++++--------- .../prefill_trans_process.py | 13 +++-- .../continues_batch/pd_mode/utils.py | 5 +- 7 files changed, 101 insertions(+), 72 deletions(-) diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_kv_move_manager.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_kv_move_manager.py index efa34c0d5..457fd1b9c 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_kv_move_manager.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_kv_move_manager.py @@ -30,6 +30,7 @@ KV_MOVE_MAX_NUM = 16 + class DecodeKVMoveManager(rpyc.Service): def __init__(self, args, info_queue: mp.Queue, mem_queues: List[mp.Queue]): super().__init__() @@ -45,7 +46,7 @@ def __init__(self, args, info_queue: mp.Queue, mem_queues: 
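
The TimeChecker added in lightllm/utils/time_utils.py rate-limits the periodic health checks: has_exceeded() returns True at most once per threshold window and then resets itself. A usage sketch; the class body is copied from this patch, and the 0.05 s threshold is only for the demo:

    import time

    class TimeChecker:
        def __init__(self, threshold):
            self.threshold = threshold
            self.last_checked = time.time()

        def has_exceeded(self):
            current_time = time.time()
            if (current_time - self.last_checked) > self.threshold:
                self._reset()
                return True
            return False

        def _reset(self):
            self.last_checked = time.time()

    checker = TimeChecker(threshold=0.05)
    assert not checker.has_exceeded()  # window not yet elapsed
    time.sleep(0.06)
    assert checker.has_exceeded()      # elapsed: run the expensive probes now
    assert not checker.has_exceeded()  # has_exceeded() reset the window
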
List[mp.Queue]): self.mem_queues = mem_queues self.infer_rpyc_lock = threading.Lock() self.infer_rpyc_objs: List[PDDecodeInferRpcServer] = [] - + from .decode_trans_obj import KVTransConnectObj self.connect_id_to_trans_obj: Dict[str, KVTransConnectObj] = {} @@ -70,16 +71,16 @@ def __init__(self, args, info_queue: mp.Queue, mem_queues: List[mp.Queue]): # 在不使用p2p 复制kv 的方案时,需要全局的传输锁进行控制。这个时候kv传输的效率会下降。 self.kv_trans_lock = threading.Lock() - + from .decode_trans_obj import KVTransProcess - + self.kv_trans_processes: List[KVTransProcess] = [None] * self.node_world_size for device_id in range(self.node_world_size): self.kv_trans_processes[device_id] = KVTransProcess() assert self.kv_trans_processes[device_id].init_all(device_id, self) return - + # ================================================================================== # _dp_alloc_to_frozen_some_tokens # _put_kv_received_to_radix_cache @@ -158,13 +159,13 @@ def _put_mem_manager_to_mem_queue(self) -> None: for obj in self.infer_rpyc_objs: obj.put_mem_manager_to_mem_queue() return - + # ================================================================================== # put_to_fail_release_task_queue 将因为一些原因失败,需要释放锁定的kv资源的请求放入到 # 对应的处理队列中,handle_fail_release_task_loop 是一个循环的线程,专门处理这些失败的请求 # 通过调用与推理进程交互的接口,释放掉申请锁定的 kv 资源。 # ================================================================================== - + def put_to_fail_release_task_queue(self, task: Union[KVMoveTask, List[KVMoveTask]]): if isinstance(task, KVMoveTask): self.fail_to_release_queue.put(task) @@ -182,9 +183,9 @@ def handle_fail_release_task_loop(self): else: self._fail_to_realese_forzen_tokens(handle_list) return - + # ================================================================================== - # on_connect + # on_connect # on_disconnect # exposed_check_alive # exposed_build_trans_process @@ -278,12 +279,14 @@ def exposed_request_data_transfer(self, tasks: List[KVMoveTask]) -> List[Optiona self.remove_trans_obj(tasks[0].connect_id) logger.exception(str(e)) raise e - + if alloc_tokened_tasks: - trans_obj.ready_to_move_queue.put(alloc_tokened_tasks, error_handle_func=self.put_to_fail_release_task_queue) + trans_obj.ready_to_move_queue.put( + alloc_tokened_tasks, error_handle_func=self.put_to_fail_release_task_queue + ) return ans_list - + # ================================================================================== # 定时检测kv 传输成功,但是长时间没有pd master来触发推理的请求, # 释放这些超时请求占用的kv资源 @@ -308,11 +311,11 @@ def check_trans_process_loop(self): for device_id in range(self.node_world_size): if not self.kv_trans_processes[device_id].is_trans_process_health(): raise Exception(f"device_id {device_id} kv process is unhealth") - + time.sleep(10.0) except (BaseException, RuntimeError) as e: logger.exception(str(e)) - + for device_id in range(self.node_world_size): self.kv_trans_processes[device_id].killself() @@ -320,12 +323,12 @@ def check_trans_process_loop(self): os.kill(os.getppid(), signal.SIGKILL) os.kill(os.getpid(), signal.SIGKILL) raise e - + # ================================================================================== # 常用辅助功能函数 # ================================================================================== def get_next_device_index(self): - counts = [0 for _ in range(self.node_world_size)] + counts = [0 for _ in range(self.node_world_size)] for obj in self.connect_id_to_trans_obj.values(): counts[obj.device_index] += 1 device_index = int(np.argmin(counts)) diff --git 
a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_obj.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_obj.py index 4127c3545..d2497a77a 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_obj.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_obj.py @@ -16,11 +16,12 @@ KV_MOVE_MAX_NUM = 16 + @dataclass class KVTransConnectObj: connect_id: str = None prefill_node_id: int = None - kv_trans_process: 'KVTransProcess' = None + kv_trans_process: "KVTransProcess" = None pd_prefill_nccl_ip: str = None pd_prefill_nccl_port: int = None device_index: int = None @@ -33,8 +34,13 @@ class KVTransConnectObj: timer_checker: TimeChecker = None def create( - self, connect_id: str, prefill_node_id: str, pd_prefill_nccl_ip: str, pd_prefill_nccl_port: int, manager: "DecodeKVMoveManager" - ): + self, + connect_id: str, + prefill_node_id: str, + pd_prefill_nccl_ip: str, + pd_prefill_nccl_port: int, + manager: "DecodeKVMoveManager", + ): self.connect_id = connect_id self.device_index = manager.get_next_device_index() self.kv_trans_process = manager.kv_trans_processes[self.device_index] @@ -49,7 +55,7 @@ def create( pd_prefill_nccl_port=pd_prefill_nccl_port, decode_id=decode_node_id, decode_device_id=self.device_index, - connect_id=self.connect_id + connect_id=self.connect_id, ) ) assert self.kv_trans_process.task_out_queue.get(timeout=60) == "nccl_ok" @@ -74,7 +80,7 @@ def create( self.put_to_radix_thread = threading.Thread(target=self.put_to_radix_loop, daemon=True) self.put_to_radix_thread.start() return - + # ================================================================================== # 处理接受所有进行 kv 传输的请求,完成后,将请求放入到 move_finished_queue 中 # ================================================================================== @@ -106,7 +112,7 @@ def kv_move_loop(self): logger.error(f"error get need 1, but get {len(move_tasks)}") assert False - move_tasks:List[KVMoveTask] = move_tasks[0] + move_tasks: List[KVMoveTask] = move_tasks[0] for task in move_tasks: logger.info(f"{func_name} get task {task.to_decode_log_info()}") @@ -128,7 +134,7 @@ def kv_move_loop(self): logger.error(f"{func_name} thread quit") return - + # ================================================================================== # 将传输完成的请求,放入到 radix cache 中进行管理。 # ================================================================================== @@ -168,11 +174,11 @@ def put_to_radix_loop(self): logger.error(f"{func_name} thread quit, info: {self.to_log_info()}") return - + # ================================================================================== # 错误处理检测操作的一些通用函数 # ================================================================================== - + def timer_to_check_status(self, raise_exception=True): if self.timer_checker.has_exceeded(): try: @@ -203,10 +209,10 @@ def set_has_error(self): if self.ready_to_move_queue is not None: self.ready_to_move_queue.has_error = True - + if self.move_finished_queue is not None: self.move_finished_queue.has_error = True - + if self.manager is not None: self.manager.remove_trans_obj(self.connect_id) return @@ -219,15 +225,19 @@ def __del__(self): join_if_alive(self.kv_move_thread) join_if_alive(self.put_to_radix_thread) - + if self.connect_id is not None and self.kv_trans_process is not None: - 
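
The error_handle_func hook added to TaskQueue.put, and threaded through the queue puts above, lets a producer route a rejected object into resource release before the has_error exception propagates, so locked kv resources are not leaked. A minimal sketch mirroring the task_queue.py change (MiniTaskQueue is an illustrative reduction):

    import threading

    class MiniTaskQueue:
        def __init__(self):
            self.datas = []
            self.lock = threading.Lock()
            self.has_error = False

        def put(self, obj, error_handle_func=None):
            if self.has_error:
                if error_handle_func is not None:
                    error_handle_func(obj)   # hand the task to release logic
                raise Exception("has error")
            with self.lock:
                self.datas.append(obj)

    released = []
    q = MiniTaskQueue()
    q.has_error = True
    try:
        q.put("req-9", error_handle_func=released.append)
    except Exception:
        pass
    assert released == ["req-9"]   # failed task was routed to release handling
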
self.kv_trans_process.task_in_queue.put(PDTransLeaveInfo(decode_id=self.decode_node_id, prefill_id=self.prefill_node_id, connect_id=self.connect_id)) + self.kv_trans_process.task_in_queue.put( + PDTransLeaveInfo( + decode_id=self.decode_node_id, prefill_id=self.prefill_node_id, connect_id=self.connect_id + ) + ) if self.ready_to_move_queue is not None: self.ready_to_move_queue.clear_tasks() if self.move_finished_queue is not None: self.move_finished_queue.clear_tasks() - + except BaseException as e: logger.exception(str(e)) @@ -240,6 +250,7 @@ def to_log_info(self): log += f"device_index: {self.device_index} " return log + @dataclass class KVTransProcess: process: mp.Process = None @@ -249,7 +260,6 @@ class KVTransProcess: task_out_queue: mp.Queue = None device_id: int = None - def init_all(self, device_id: int, manager: "DecodeKVMoveManager"): self.device_lock = threading.Lock() self.device_id = device_id @@ -271,12 +281,12 @@ def init_all(self, device_id: int, manager: "DecodeKVMoveManager"): assert self.task_out_queue.get(timeout=60) == "get_mem_managers_ok" return True - + except Exception as e: logger.warning(f"Failed start kv trans process for device {device_id}: {e}") logger.exception(str(e)) return False - + def is_trans_process_health(self): try: process = psutil.Process(self.process.pid) @@ -287,6 +297,6 @@ def is_trans_process_health(self): return True except: return False - + def killself(self): self.process.kill() diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py index eec2a27d0..2c140e7ee 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py @@ -83,7 +83,9 @@ def _init_env(args, device_id: int, task_in_queue: mp.Queue, task_out_queue: mp. 
while True: task: Union[KVMoveTaskGroup, PDTransJoinInfo, PDTransLeaveInfo] = task_in_queue.get() if isinstance(task, KVMoveTaskGroup): - _handle_kvmove_task(task, task_out_queue, mem_managers, connect_id_to_comm, task.connect_id, dp_size_in_node) + _handle_kvmove_task( + task, task_out_queue, mem_managers, connect_id_to_comm, task.connect_id, dp_size_in_node + ) elif isinstance(task, PDTransJoinInfo): _handle_prefill_join(task, task_out_queue, connect_id_to_comm) elif isinstance(task, PDTransLeaveInfo): diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py index 9bf2a6847..6b9b17182 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py @@ -66,14 +66,14 @@ def __init__(self, args, info_queue: mp.Queue, mem_queues: List[mp.Queue]): self.release_tasks_thread.start() from .prefill_trans_obj import KVTransProcess - + self.kv_trans_processes: List[KVTransProcess] = [None] * self.node_world_size for device_id in range(self.node_world_size): self.kv_trans_processes[device_id] = KVTransProcess() assert self.kv_trans_processes[device_id].init_all(device_id, self) return - + # ================================================================================== # 主任务循环,接收需要进行kv传输的请求进行处理 # ================================================================================== @@ -95,7 +95,7 @@ def task_dispatcher_loop(self): except (BaseException, RuntimeError) as e: logger.exception(str(e)) raise e - + # ================================================================================== # 请求出错或者完成kv传输后的处理队列和线程loop # ================================================================================== @@ -117,7 +117,7 @@ def handle_release_task_loop(self): else: self._remove_req_refs_from_prompt_cache(handle_list) return - + # ================================================================================== # 定时检测传输进程的健康状态,出现问题拉崩整个系统触发重启 # ================================================================================== @@ -128,11 +128,11 @@ def check_trans_process_loop(self): for device_id in range(self.node_world_size): if not self.kv_trans_processes[device_id].is_trans_process_health(): raise Exception(f"device_id {device_id} kv process is unhealth") - + time.sleep(10.0) except (BaseException, RuntimeError) as e: logger.exception(str(e)) - + for device_id in range(self.node_world_size): self.kv_trans_processes[device_id].killself() @@ -140,9 +140,9 @@ def check_trans_process_loop(self): os.kill(os.getppid(), signal.SIGKILL) os.kill(os.getpid(), signal.SIGKILL) raise e - + # ================================================================================== - # 与推理进程交互接口, _remove_req_refs_from_prompt_cache 和 + # 与推理进程交互接口, _remove_req_refs_from_prompt_cache 和 # _put_mem_manager_to_mem_queue 都是通过 rpyc 与推理进程进行交互的接口 # ================================================================================== @@ -172,7 +172,7 @@ def _put_mem_manager_to_mem_queue(self): async def wait_all_future_finish(self, futures: List[AsyncResult]): await asyncio.gather(*[asyncio.to_thread(future.wait) for future in futures]) return - + # ================================================================================== # 辅助功能接口 # 
================================================================================== @@ -191,18 +191,18 @@ def remove_trans_obj(self, connect_id): trans_obj.set_has_error() logger.error(f"remove tran obj id {trans_obj.decode_node_id}") return - + def __get_trans_obj(self, task: KVMoveTask): self.__remove_dead_trans_obj() # 如果已经存在连接对象,直接返回 for obj in self.connect_id_to_trans_obj.values(): if obj.decode_node_id == task.decode_node.node_id: return obj - + # 如果不存在连接对象,创建新的连接对象 gc.collect() from .prefill_trans_obj import KVTransConnectObj - + trans_obj = KVTransConnectObj() trans_obj.create(task.decode_node.node_id, task.decode_node.ip, task.decode_node.rpyc_port, self) self.connect_id_to_trans_obj[trans_obj.connect_id] = trans_obj @@ -221,6 +221,7 @@ def __remove_dead_trans_obj(self): gc.collect() return + def _init_env(args, info_queue: mp.Queue, mem_queues: List[mp.Queue], event: mp.Event): import lightllm.utils.rpyc_fix_utils as _ diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_obj.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_obj.py index 4ab56f242..789f5715b 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_obj.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_obj.py @@ -26,7 +26,7 @@ class KVTransConnectObj: connect_id: str = None decode_node_id: int = None rpyc_conn: object = None # rpyc_con 的连接对象 - kv_trans_process: 'KVTransProcess' = None + kv_trans_process: "KVTransProcess" = None device_index: int = None # 使用的gpu序号 manager: "PrefillKVMoveManager" = None has_error: bool = False @@ -51,7 +51,7 @@ def create( con = rpyc.connect( host=decode_node_ip, port=decode_node_rpyc_port, config={"allow_pickle": True}, keepalive=True ) - + # 创建 nccl 连接 with self.kv_trans_process.device_lock: self.kv_trans_process.task_in_queue.put( @@ -62,7 +62,7 @@ def create( pd_prefill_nccl_port=self.kv_trans_process.kv_trans_port, decode_id=decode_node_id, decode_device_id=-1, - connect_id=self.connect_id + connect_id=self.connect_id, ) ) @@ -96,8 +96,10 @@ def create( self.kv_trans_thread = threading.Thread(target=self.kv_trans_handle_loop, daemon=True) self.kv_trans_thread.start() - logger.info(f"create KVTransConnectObj success: connect_id : {self.connect_id} prefill_id: {prefill_node_id}" - f" decode_id: {decode_node_id} device_index: {device_index} ") + logger.info( + f"create KVTransConnectObj success: connect_id : {self.connect_id} prefill_id: {prefill_node_id}" + f" decode_id: {decode_node_id} device_index: {device_index} " + ) return def _get_request_tasks(self, datas: List[KVMoveTask]): @@ -114,7 +116,7 @@ def _get_request_tasks(self, datas: List[KVMoveTask]): else: break return ans_list - + # ================================================================================== # 与 decode 节点进行元数据交互,申请锁定资源准备进行kv的传输 # ================================================================================== @@ -162,7 +164,9 @@ def request_kv_trans_loop(self): logger.info(f"prefill node kv move task req_id: {move_task.id()} not send, decode is busy") if ok_trans_list: - self.ready_kv_trans_task_queue.put(ok_trans_list, error_handle_func=self.manager.put_to_release_task_queue) + self.ready_kv_trans_task_queue.put( + ok_trans_list, error_handle_func=self.manager.put_to_release_task_queue + ) except BaseException as e: logger.exception(str(e)) @@ -175,7 +179,7 
@@ def request_kv_trans_loop(self): logger.error(f"{func_name}, {self.to_log_info()} thread quit") return - + # ================================================================================== # 将准备好 kv 传输的请求进行 kv 传输 # ================================================================================== @@ -231,7 +235,7 @@ def kv_trans_handle_loop(self): logger.error(f"trans kv thread, {self.to_log_info()} thread quit") return - + # ================================================================================== # 错误处理检测操作的一些通用函数 # ================================================================================== @@ -247,7 +251,7 @@ def has_error_status(self): return True return False - + def timer_check_status(self, raise_exception=True): if self.timer_checker.has_exceeded(): try: @@ -260,7 +264,7 @@ def timer_check_status(self, raise_exception=True): self.set_has_error() if raise_exception: raise e - + return def set_has_error(self): @@ -271,7 +275,7 @@ def set_has_error(self): if self.request_kv_trans_task_queue is not None: self.request_kv_trans_task_queue.has_error = True - + if self.ready_kv_trans_task_queue is not None: self.ready_kv_trans_task_queue.has_error = True @@ -300,23 +304,24 @@ def __del__(self): # 传输进程清理掉 nccl 连接 if self.connect_id is not None: - self.kv_trans_process.task_in_queue.put(PDTransLeaveInfo(decode_id=self.decode_node_id, - prefill_id=self.prefill_node_id, - connect_id=self.connect_id)) - + self.kv_trans_process.task_in_queue.put( + PDTransLeaveInfo( + decode_id=self.decode_node_id, prefill_id=self.prefill_node_id, connect_id=self.connect_id + ) + ) + except BaseException as e: logger.exception(str(e)) logger.error(f"trans obj deled, info: {self.to_log_info()}") - def to_log_info(self): log = f"connect_id: {self.connect_id} " log += f"decode_node_id: {self.decode_node_id} " log += f"prefill_node_id: {self.prefill_node_id} " log += f"device_index: {self.device_index} " return log - + @dataclass class KVTransProcess: @@ -334,7 +339,7 @@ def init_all(self, device_id: int, manager: "PrefillKVMoveManager"): self.task_in_queue = mp.Queue() self.task_out_queue = mp.Queue() self.kv_trans_port = find_available_port(manager.args.pd_p_allowed_port_min, manager.args.pd_p_allowed_port_max) - + try: from .prefill_trans_process import start_prefill_trans_process @@ -356,7 +361,7 @@ def init_all(self, device_id: int, manager: "PrefillKVMoveManager"): logger.warning(f"Failed start kv trans process for device {device_id}: {e}") logger.exception(str(e)) return False - + def is_trans_process_health(self): try: process = psutil.Process(self.process.pid) @@ -367,6 +372,6 @@ def is_trans_process_health(self): return True except: return False - + def killself(self): self.process.kill() diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py index c9cd9c664..c77e9e493 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py @@ -32,7 +32,9 @@ def _handle_kvmove_task( logger.info(f"trans start: {move_tasks[0].to_prefill_log_info()}") cur_mem = mem_managers[device_index] if kv_trans_use_p2p(): - cur_mem.send_to_decode_node_p2p(move_tasks, mem_managers, dp_size_in_node, connect_id_to_comm[connect_id]) + 
cur_mem.send_to_decode_node_p2p( + move_tasks, mem_managers, dp_size_in_node, connect_id_to_comm[connect_id] + ) else: cur_mem.send_to_decode_node(move_tasks, mem_managers, dp_size_in_node, connect_id_to_comm[connect_id]) logger.info(f"trans finished: {move_tasks[0].to_prefill_log_info()} move len: {total_move_kv_len}") @@ -48,7 +50,10 @@ def _handle_kvmove_task( def _handle_decode_join( - node_info: PDTransJoinInfo, task_out_queue: mp.Queue, connect_id_to_comm: Dict[str, PyNcclCommunicator], store: TCPStore + node_info: PDTransJoinInfo, + task_out_queue: mp.Queue, + connect_id_to_comm: Dict[str, PyNcclCommunicator], + store: TCPStore, ): try: group = StatelessP2PProcessGroup.create(node_info.prefill_id, node_info.decode_id, True, store) @@ -83,7 +88,9 @@ def _init_env( while True: task: Union[KVMoveTaskGroup, PDTransJoinInfo, PDTransLeaveInfo] = task_in_queue.get() if isinstance(task, KVMoveTaskGroup): - _handle_kvmove_task(task.tasks, task_out_queue, mem_managers, connect_id_to_comm, task.connect_id, dp_size_in_node) + _handle_kvmove_task( + task.tasks, task_out_queue, mem_managers, connect_id_to_comm, task.connect_id, dp_size_in_node + ) elif isinstance(task, PDTransJoinInfo): _handle_decode_join(task, task_out_queue, connect_id_to_comm, master_store) elif isinstance(task, PDTransLeaveInfo): diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/utils.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/utils.py index 241cf93e8..dbfd61c53 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/utils.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/utils.py @@ -1,9 +1,10 @@ import threading -def join_if_alive(thread:threading.Thread): + +def join_if_alive(thread: threading.Thread): if thread is not None and thread.is_alive(): try: thread.join() except Exception: pass - return \ No newline at end of file + return From 2189ffe46f6fe196ac73744cb71e9de979a46c57 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Thu, 10 Apr 2025 10:45:21 +0800 Subject: [PATCH 09/20] fix --- .../pd_mode/prefill_node_impl/prefill_trans_process.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py index c77e9e493..7e9dad6f5 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py @@ -20,7 +20,7 @@ def _handle_kvmove_task( move_tasks: List[KVMoveTask], task_out_queue: mp.Queue, mem_managers: List[MemoryManager], - connect_id_to_comm: Dict[int, PyNcclCommunicator], + connect_id_to_comm: Dict[str, PyNcclCommunicator], connect_id: str, dp_size_in_node: int, ): From 89e40680a8177f32a9f86f4b0f9ffe98703edd59 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Thu, 10 Apr 2025 11:07:33 +0800 Subject: [PATCH 10/20] fix. 
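Assign the connection-identifying fields (prefill_node_id / decode_node_id, manager, timer_checker, and on the prefill side rpyc_conn) before the blocking task_out_queue.get(timeout=60) nccl handshake rather than after it. If the handshake raises, __del__ and the error paths then see fully populated fields instead of a half-initialized object.

A minimal sketch of the failure mode this ordering avoids (illustrative names only, not the committed code):

    # Illustrative sketch: why fields are assigned before the blocking handshake.
    class ConnSketch:
        def __init__(self):
            self.decode_node_id = None

        def create(self, decode_node_id, handshake):
            self.decode_node_id = decode_node_id  # assign first: cleanup can use it
            handshake()  # may raise TimeoutError; the object is still fully described

        def __del__(self):
            # __del__ runs even for half-constructed objects, so every field it
            # touches must already be set
            print(f"cleanup for decode_node_id={self.decode_node_id}")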
---
 .../decode_node_impl/decode_trans_obj.py   | 15 +++++++--------
 .../prefill_node_impl/prefill_trans_obj.py | 19 ++++++++-----------
 2 files changed, 15 insertions(+), 19 deletions(-)

diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_obj.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_obj.py
index d2497a77a..e30b255b7 100644
--- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_obj.py
+++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_obj.py
@@ -45,6 +45,13 @@ def create(
         self.device_index = manager.get_next_device_index()
         self.kv_trans_process = manager.kv_trans_processes[self.device_index]
         decode_node_id = manager.args.pd_node_id
+        self.prefill_node_id = prefill_node_id
+        self.decode_node_id = decode_node_id
+        self.pd_prefill_nccl_ip = pd_prefill_nccl_ip
+        self.pd_prefill_nccl_port = pd_prefill_nccl_port
+
+        self.manager = manager
+        self.timer_checker = TimeChecker(3)
 
         with self.kv_trans_process.device_lock:
             self.kv_trans_process.task_in_queue.put(
@@ -60,14 +67,6 @@ def create(
             )
         assert self.kv_trans_process.task_out_queue.get(timeout=60) == "nccl_ok"
 
-        self.prefill_node_id = prefill_node_id
-        self.decode_node_id = decode_node_id
-        self.pd_prefill_nccl_ip = pd_prefill_nccl_ip
-        self.pd_prefill_nccl_port = pd_prefill_nccl_port
-
-        self.manager = manager
-        self.timer_checker = TimeChecker(3)
-
         self.ready_to_move_queue = TaskQueue(
             get_func=lambda datas: datas[0:1], fail_func=self.manager.put_to_fail_release_task_queue
         )
diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_obj.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_obj.py
index 789f5715b..e807688bf 100644
--- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_obj.py
+++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_obj.py
@@ -47,11 +47,18 @@ def create(
         self.kv_trans_process = manager.kv_trans_processes[device_index]
         prefill_node_id = manager.args.pd_node_id
         self.connect_id = str(uuid.uuid4())
+        self.decode_node_id = decode_node_id
+        self.prefill_node_id = prefill_node_id
+        self.device_index = device_index
+        self.manager = manager
+        self.timer_checker = TimeChecker(3)
 
         con = rpyc.connect(
             host=decode_node_ip, port=decode_node_rpyc_port, config={"allow_pickle": True}, keepalive=True
         )
 
+        self.rpyc_conn = con
+
         # create the nccl connection
         with self.kv_trans_process.device_lock:
             self.kv_trans_process.task_in_queue.put(
@@ -79,13 +86,6 @@ def create(
         self.max_kv_trans_token_num = max_kv_trans_token_num
         assert self.kv_trans_process.task_out_queue.get(timeout=60) == "nccl_ok"
 
-        self.decode_node_id = decode_node_id
-        self.prefill_node_id = prefill_node_id
-        self.rpyc_conn = con
-        self.device_index = device_index
-        self.manager = manager
-        self.timer_checker = TimeChecker(3)
-
         self.request_kv_trans_task_queue = TaskQueue(
             get_func=self._get_request_tasks, fail_func=self.manager.put_to_release_task_queue
         )
@@ -96,10 +96,7 @@ def create(
         self.kv_trans_thread = threading.Thread(target=self.kv_trans_handle_loop, daemon=True)
         self.kv_trans_thread.start()
 
-        logger.info(
-            f"create KVTransConnectObj success: connect_id : {self.connect_id} prefill_id: {prefill_node_id}"
-            f" decode_id: {decode_node_id} device_index: {device_index} "
-        )
+        logger.info(f"create KVTransConnectObj success: {self.to_log_info()}")
         return
 
     def _get_request_tasks(self, datas: List[KVMoveTask]):

From a01bd9009c535b1b1521b0a6b6f536fd7f3e999d Mon Sep 17 00:00:00 2001
From: wangzaijun
Date: Thu, 10 Apr 2025 12:16:01 +0800
Subject: [PATCH 11/20] fix

---
 .../pd_mode/decode_node_impl/decode_trans_process.py     | 6 +++---
 .../pd_mode/prefill_node_impl/prefill_kv_move_manager.py | 2 +-
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py
index 2c140e7ee..f82d2c11b 100644
--- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py
+++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py
@@ -59,7 +59,7 @@ def _handle_prefill_join(
             src_id=node_info.prefill_id, dest_id=node_info.decode_id, is_server=False, store=store_client
         )
         comm = PyNcclCommunicator(group, node_info.decode_device_id)
-        connect_id_to_comm[node_info.prefill_id] = comm
+        connect_id_to_comm[node_info.connect_id] = comm
         logger.info(f"{node_info} kv trans connected")
         task_out_queue.put("nccl_ok")
     except Exception as e:
@@ -84,13 +84,13 @@ def _init_env(args, device_id: int, task_in_queue: mp.Queue, task_out_queue: mp.
         task: Union[KVMoveTaskGroup, PDTransJoinInfo, PDTransLeaveInfo] = task_in_queue.get()
         if isinstance(task, KVMoveTaskGroup):
             _handle_kvmove_task(
-                task, task_out_queue, mem_managers, connect_id_to_comm, task.connect_id, dp_size_in_node
+                task.tasks, task_out_queue, mem_managers, connect_id_to_comm, task.connect_id, dp_size_in_node
             )
         elif isinstance(task, PDTransJoinInfo):
             _handle_prefill_join(task, task_out_queue, connect_id_to_comm)
         elif isinstance(task, PDTransLeaveInfo):
             if task.connect_id in connect_id_to_comm:
-                connect_id_to_comm[task.prefill_id].destroy()
+                connect_id_to_comm[task.connect_id].destroy()
                 logger.info(f"destroy {task} nccl communicator.")
             else:
                 logger.info(f"no connect_id {task.connect_id} found in connect_id_to_comm")
diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py
index 6b9b17182..a54b54980 100644
--- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py
+++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py
@@ -189,7 +189,7 @@ def remove_trans_obj(self, connect_id):
         trans_obj = self.connect_id_to_trans_obj.pop(connect_id, None)
         if trans_obj is not None:
             trans_obj.set_has_error()
-            logger.error(f"remove trans obj id {trans_obj.decode_node_id}")
+            logger.error(f"remove trans obj decode_node_id {trans_obj.decode_node_id}")
         return
 
     def __get_trans_obj(self, task: KVMoveTask):

From e41f3652d5374054b89c68549cd50993fa331d6b Mon Sep 17 00:00:00 2001
From: wangzaijun
Date: Thu, 10 Apr 2025 15:54:24 +0800
Subject: [PATCH 12/20] fix

---
 .../decode_node_impl/decode_trans_obj.py      |  2 +-
 .../decode_node_impl/decode_trans_process.py  | 21 +++++++++++++++----
 .../prefill_node_impl/prefill_trans_obj.py    |  2 +-
 .../prefill_trans_process.py                  | 20 ++++++++++++++++--
 4 files changed, 37 insertions(+), 8 deletions(-)

diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_obj.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_obj.py
index e30b255b7..7c05a3c30 100644
--- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_obj.py
+++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_obj.py
@@ -51,7 +51,7 @@ def create(
         self.pd_prefill_nccl_port = pd_prefill_nccl_port
 
         self.manager = manager
-        self.timer_checker = TimeChecker(3)
+        self.timer_checker = TimeChecker(6)
 
         with self.kv_trans_process.device_lock:
             self.kv_trans_process.task_in_queue.put(
diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py
index f82d2c11b..5fe2ba5a9 100644
--- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py
+++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py
@@ -4,6 +4,7 @@
 import inspect
 import torch.multiprocessing as mp
 from torch.distributed import TCPStore
+from datetime import timedelta
 from typing import List, Dict, Union
 from lightllm.utils.log_utils import init_logger
 from lightllm.common.mem_manager import MemoryManager
@@ -52,12 +53,17 @@ def _handle_prefill_join(
     node_info: PDTransJoinInfo, task_out_queue: mp.Queue, connect_id_to_comm: Dict[str, PyNcclCommunicator]
 ):
     try:
+        logger.info(f"connect start {node_info}")
         store_client = TCPStore(
-            host_name=node_info.pd_prefill_nccl_ip, port=node_info.pd_prefill_nccl_port, is_master=False, use_libuv=True
-        )
-        group = StatelessP2PProcessGroup.create(
-            src_id=node_info.prefill_id, dest_id=node_info.decode_id, is_server=False, store=store_client
+            host_name=node_info.pd_prefill_nccl_ip, port=node_info.pd_prefill_nccl_port, is_master=False, use_libuv=True, timeout=timedelta(seconds=30)
         )
+        src_id = node_info.prefill_id
+        dest_id = node_info.connect_id
+        logger.info(f"connect src_id {src_id} dest_id {dest_id}")
+        group = StatelessP2PProcessGroup.create(src_id=src_id,
+                                                dest_id=dest_id,
+                                                is_server=False,
+                                                store=store_client)
         comm = PyNcclCommunicator(group, node_info.decode_device_id)
         connect_id_to_comm[node_info.connect_id] = comm
         logger.info(f"{node_info} kv trans connected")
@@ -68,6 +74,13 @@ def _handle_prefill_join(
 
 
 def _init_env(args, device_id: int, task_in_queue: mp.Queue, task_out_queue: mp.Queue, mem_queues: List[mp.Queue]):
+    import os
+
+    # os.environ["NCCL_DEBUG"] = "INFO"
+    os.environ["NCCL_MAX_NCHANNELS"] = "2"
+    os.environ["NCCL_NSOCKS_PER_CHANNEL"] = "1"
+    os.environ["NCCL_SOCKET_NTHREADS"] = "1"
+
     torch.backends.cudnn.enabled = False
 
     dp_size_in_node = max(1, args.dp // args.nnodes)
diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_obj.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_obj.py
index e807688bf..93639facc 100644
--- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_obj.py
+++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_obj.py
@@ -51,7 +51,7 @@ def create(
         self.prefill_node_id = prefill_node_id
         self.device_index = device_index
         self.manager = manager
-        self.timer_checker = TimeChecker(3)
+        self.timer_checker = TimeChecker(6)
 
         con = rpyc.connect(
             host=decode_node_ip, port=decode_node_rpyc_port, config={"allow_pickle": True}, keepalive=True
diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py
index 7e9dad6f5..07ee17164 100644
--- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py
+++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py
@@ -4,6 +4,7 @@
 import inspect
 import torch.multiprocessing as mp
 from torch.distributed import TCPStore
+from datetime import timedelta
 from typing import List, Dict, Union
 from lightllm.utils.log_utils import init_logger
 from lightllm.common.mem_manager import MemoryManager
@@ -56,7 +57,14 @@ def _handle_decode_join(
     store: TCPStore,
 ):
     try:
-        group = StatelessP2PProcessGroup.create(node_info.prefill_id, node_info.decode_id, True, store)
+        logger.info(f"connect start {node_info}")
+        src_id = node_info.prefill_id
+        dest_id = node_info.connect_id
+        logger.info(f"connect src_id {src_id} dest_id {dest_id}")
+        group = StatelessP2PProcessGroup.create(src_id=src_id,
+                                                dest_id=dest_id,
+                                                is_server=True,
+                                                store=store)
         comm = PyNcclCommunicator(group, node_info.prefill_device_id)
         connect_id_to_comm[node_info.connect_id] = comm
         logger.info(f"{node_info} kv trans connected!")
@@ -75,10 +83,18 @@ def _init_env(
     task_out_queue: mp.Queue,
     mem_queues: List[mp.Queue],
 ):
+    import os
+
+    # os.environ["NCCL_DEBUG"] = "INFO"
+    os.environ["NCCL_MAX_NCHANNELS"] = "2"
+    os.environ["NCCL_NSOCKS_PER_CHANNEL"] = "1"
+    os.environ["NCCL_SOCKET_NTHREADS"] = "1"
+
     torch.backends.cudnn.enabled = False
+
     try:
         torch.cuda.set_device(device_id)
         graceful_registry(inspect.currentframe().f_code.co_name)
-        master_store = TCPStore(host_name=store_ip, port=store_port, is_master=True, use_libuv=True)
+        master_store = TCPStore(host_name=store_ip, port=store_port, is_master=True, use_libuv=True, timeout=timedelta(seconds=30))
         dp_size_in_node = max(1, args.dp // args.nnodes)
         task_out_queue.put("proc_start")
         mem_managers: List[MemoryManager] = [mem_queue.get(timeout=60) for mem_queue in mem_queues]

From 7453390bb9f84cdb29a33f47fb80aaa2ddf8f7bd Mon Sep 17 00:00:00 2001
From: wangzaijun
Date: Thu, 10 Apr 2025 16:07:21 +0800
Subject: [PATCH 13/20] fix

---
 .../pd_mode/decode_node_impl/decode_trans_process.py   | 11 ++++++-----
 .../pd_mode/prefill_node_impl/prefill_trans_process.py |  9 ++++-----
 2 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py
index 5fe2ba5a9..6cdb7edbb 100644
--- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py
+++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py
@@ -55,15 +55,16 @@ def _handle_prefill_join(
     try:
         logger.info(f"connect start {node_info}")
         store_client = TCPStore(
-            host_name=node_info.pd_prefill_nccl_ip, port=node_info.pd_prefill_nccl_port, is_master=False, use_libuv=True, timeout=timedelta(seconds=30)
+            host_name=node_info.pd_prefill_nccl_ip,
+            port=node_info.pd_prefill_nccl_port,
+            is_master=False,
+            use_libuv=True,
+            timeout=timedelta(seconds=30),
         )
         src_id = node_info.prefill_id
         dest_id = node_info.connect_id
         logger.info(f"connect src_id {src_id} dest_id {dest_id}")
-        group = StatelessP2PProcessGroup.create(src_id=src_id,
-                                                dest_id=dest_id,
-                                                is_server=False,
-                                                store=store_client)
+        group = StatelessP2PProcessGroup.create(src_id=src_id, dest_id=dest_id, is_server=False, store=store_client)
         comm = PyNcclCommunicator(group, node_info.decode_device_id)
         connect_id_to_comm[node_info.connect_id] = comm
         logger.info(f"{node_info} kv trans connected")
diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py
index 07ee17164..48df7a12d 100644
--- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py
+++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py
@@ -61,10 +61,7 @@ def _handle_decode_join(
         src_id = node_info.prefill_id
         dest_id = node_info.connect_id
         logger.info(f"connect src_id {src_id} dest_id {dest_id}")
-        group = StatelessP2PProcessGroup.create(src_id=src_id,
-                                                dest_id=dest_id,
-                                                is_server=True,
-                                                store=store)
+        group = StatelessP2PProcessGroup.create(src_id=src_id, dest_id=dest_id, is_server=True, store=store)
         comm = PyNcclCommunicator(group, node_info.prefill_device_id)
         connect_id_to_comm[node_info.connect_id] = comm
         logger.info(f"{node_info} kv trans connected!")
@@ -94,7 +91,9 @@ def _init_env(
     try:
         torch.cuda.set_device(device_id)
         graceful_registry(inspect.currentframe().f_code.co_name)
-        master_store = TCPStore(host_name=store_ip, port=store_port, is_master=True, use_libuv=True, timeout=timedelta(seconds=30))
+        master_store = TCPStore(
+            host_name=store_ip, port=store_port, is_master=True, use_libuv=True, timeout=timedelta(seconds=30)
+        )
         dp_size_in_node = max(1, args.dp // args.nnodes)
         task_out_queue.put("proc_start")
         mem_managers: List[MemoryManager] = [mem_queue.get(timeout=60) for mem_queue in mem_queues]

From 1373a074416119bf8d18eeaeec7acb3d176fa98a Mon Sep 17 00:00:00 2001
From: wangzaijun
Date: Thu, 10 Apr 2025 20:27:44 +0800
Subject: [PATCH 14/20] format.

---
 lightllm/server/pd_io_struct.py | 2 +-
 lightllm/utils/time_utils.py    | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/lightllm/server/pd_io_struct.py b/lightllm/server/pd_io_struct.py
index ad228ff8f..160f2ca18 100644
--- a/lightllm/server/pd_io_struct.py
+++ b/lightllm/server/pd_io_struct.py
@@ -79,7 +79,7 @@ class DecodeNodeInfo:
 class PDTransJoinInfo:
     decode_id: int
     decode_device_id: int
-    prefill_id: int 
+    prefill_id: int
     prefill_device_id: int
     pd_prefill_nccl_ip: str
     pd_prefill_nccl_port: int
diff --git a/lightllm/utils/time_utils.py b/lightllm/utils/time_utils.py
index 6b313c093..f065d0437 100644
--- a/lightllm/utils/time_utils.py
+++ b/lightllm/utils/time_utils.py
@@ -3,7 +3,7 @@
 class TimeChecker:
     def __init__(self, threshold):
         self.threshold = threshold
-        self.last_checked = time.time() 
+        self.last_checked = time.time()
 
     def has_exceeded(self):
         current_time = time.time()

From d9d3f4fba427381fa0f5ee2665afe2f6a5e7c492 Mon Sep 17 00:00:00 2001
From: wangzaijun
Date: Thu, 10 Apr 2025 20:28:44 +0800
Subject: [PATCH 15/20] reformat.

---
 lightllm/server/pd_io_struct.py | 5 +++--
 lightllm/utils/time_utils.py    | 7 ++++---
 2 files changed, 7 insertions(+), 5 deletions(-)

diff --git a/lightllm/server/pd_io_struct.py b/lightllm/server/pd_io_struct.py
index 160f2ca18..22405867c 100644
--- a/lightllm/server/pd_io_struct.py
+++ b/lightllm/server/pd_io_struct.py
@@ -79,7 +79,7 @@ class DecodeNodeInfo:
 class PDTransJoinInfo:
     decode_id: int
     decode_device_id: int
-    prefill_id: int 
+    prefill_id: int
     prefill_device_id: int
     pd_prefill_nccl_ip: str
     pd_prefill_nccl_port: int
@@ -144,7 +144,8 @@ def get_cost_time(self):
         else:
             return 100000000000
 
+
 @dataclass
 class KVMoveTaskGroup:
     tasks: List[KVMoveTask]
-    connect_id: str
\ No newline at end of file
+    connect_id: str
diff --git a/lightllm/utils/time_utils.py b/lightllm/utils/time_utils.py
index f065d0437..648108d2b 100644
--- a/lightllm/utils/time_utils.py
+++ b/lightllm/utils/time_utils.py
@@ -1,9 +1,10 @@
 import time
 
+
 class TimeChecker:
     def __init__(self, threshold):
         self.threshold = threshold
-        self.last_checked = time.time() 
+        self.last_checked = time.time()
 
     def has_exceeded(self):
         current_time = time.time()
@@ -11,6 +12,6 @@ def has_exceeded(self):
             self._reset()
             return True
         return False
-    
+
     def _reset(self):
-        self.last_checked = time.time()
\ No newline at end of file
+        self.last_checked = time.time()

From 18437d2fd9528fe71151e945fdf1bb8e9d093315 Mon Sep 17 00:00:00 2001
From: wangzaijun
Date: Thu, 10 Apr 2025 20:59:56 +0800
Subject: [PATCH 16/20] add async nccl connect.
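TCPStore / StatelessP2PProcessGroup setup blocks until the peer joins, so a vanished peer used to hang the transfer process inside _handle_prefill_join / _handle_decode_join forever. Run the group and communicator creation in a daemon thread bounded by join(timeout=...), which turns a silent hang into an exception that the normal error path reports back through task_out_queue.

A condensed sketch of the pattern (connect_fn stands in for the StatelessP2PProcessGroup.create plus PyNcclCommunicator calls; the timeout value follows this patch):

    import threading

    def connect_with_timeout(connect_fn, timeout_s=50.0):
        result = []

        def runner():
            result.append(connect_fn())  # may block forever if the peer never joins

        t = threading.Thread(target=runner, daemon=True)
        t.start()
        t.join(timeout=timeout_s)
        if t.is_alive():
            # the stuck daemon thread is abandoned; the caller reports the failure
            raise TimeoutError(f"nccl connect did not finish within {timeout_s}s")
        return result[0]

The abandoned thread cannot be cancelled, but the kv-move managers already health-check the transfer processes and kill them when they stop responding.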
---
 .../decode_node_impl/decode_trans_process.py | 21 ++++++++++++++++----
 .../prefill_trans_process.py                 | 20 +++++++++++++++---
 2 files changed, 35 insertions(+), 6 deletions(-)

diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py
index 6cdb7edbb..7a0391df0 100644
--- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py
+++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py
@@ -2,6 +2,7 @@
 import time
 import sys
 import inspect
+import threading
 import torch.multiprocessing as mp
 from torch.distributed import TCPStore
 from datetime import timedelta
@@ -64,9 +65,23 @@ def _handle_prefill_join(
         src_id = node_info.prefill_id
         dest_id = node_info.connect_id
         logger.info(f"connect src_id {src_id} dest_id {dest_id}")
-        group = StatelessP2PProcessGroup.create(src_id=src_id, dest_id=dest_id, is_server=False, store=store_client)
-        comm = PyNcclCommunicator(group, node_info.decode_device_id)
-        connect_id_to_comm[node_info.connect_id] = comm
+
+        result_list = []
+
+        def async_connect():
+            torch.cuda.set_device(node_info.decode_device_id)
+            group = StatelessP2PProcessGroup.create(src_id=src_id, dest_id=dest_id, is_server=False, store=store_client)
+            comm = PyNcclCommunicator(group, node_info.decode_device_id)
+            result_list.append(comm)
+            return
+
+        connect_task = threading.Thread(target=async_connect, daemon=True)
+        connect_task.start()
+        connect_task.join(timeout=50)
+        if connect_task.is_alive():
+            raise Exception(f"{node_info} connect time out")
+
+        connect_id_to_comm[node_info.connect_id] = result_list[0]
         logger.info(f"{node_info} kv trans connected")
         task_out_queue.put("nccl_ok")
     except Exception as e:
diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py
index 48df7a12d..7936b721a 100644
--- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py
+++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py
@@ -2,6 +2,7 @@
 import time
 import sys
 import inspect
+import threading
 import torch.multiprocessing as mp
 from torch.distributed import TCPStore
 from datetime import timedelta
@@ -61,9 +62,22 @@ def _handle_decode_join(
         src_id = node_info.prefill_id
         dest_id = node_info.connect_id
         logger.info(f"connect src_id {src_id} dest_id {dest_id}")
-        group = StatelessP2PProcessGroup.create(src_id=src_id, dest_id=dest_id, is_server=True, store=store)
-        comm = PyNcclCommunicator(group, node_info.prefill_device_id)
-        connect_id_to_comm[node_info.connect_id] = comm
+        result_list = []
+
+        def async_connect():
+            torch.cuda.set_device(node_info.prefill_device_id)
+            group = StatelessP2PProcessGroup.create(src_id=src_id, dest_id=dest_id, is_server=True, store=store)
+            comm = PyNcclCommunicator(group, node_info.prefill_device_id)
+            result_list.append(comm)
+            return
+
+        connect_task = threading.Thread(target=async_connect, daemon=True)
+        connect_task.start()
+        connect_task.join(timeout=50)
+        if connect_task.is_alive():
+            raise Exception(f"{node_info} connect time out")
+
+        connect_id_to_comm[node_info.connect_id] = result_list[0]
         logger.info(f"{node_info} kv trans connected!")
         task_out_queue.put("nccl_ok")
     except Exception as e:

From 5dcc4b8799cf59bffad2ef2ad30aa29ae1f4ce27 Mon Sep 17 00:00:00 2001
From: wangzaijun
Date: Fri, 11 Apr 2025 13:23:13 +0800
Subject: [PATCH 17/20] fix

---
 .../pd_mode/decode_node_impl/decode_trans_obj.py        |  4 +++-
 .../pd_mode/decode_node_impl/decode_trans_process.py    |  2 +-
 .../pd_mode/prefill_node_impl/prefill_trans_obj.py      | 10 ++++++++--
 .../pd_mode/prefill_node_impl/prefill_trans_process.py  |  2 +-
 .../mode_backend/continues_batch/pd_mode/utils.py       | 10 +++++++++-
 5 files changed, 22 insertions(+), 6 deletions(-)

diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_obj.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_obj.py
index 7c05a3c30..fd42b3772 100644
--- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_obj.py
+++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_obj.py
@@ -10,7 +10,7 @@
 from lightllm.utils.device_utils import kv_trans_use_p2p
 from .decode_kv_move_manager import DecodeKVMoveManager
 from lightllm.utils.time_utils import TimeChecker
-from ..utils import join_if_alive
+from ..utils import join_if_alive, clear_queue
 
 logger = init_logger(__name__)
 
@@ -54,6 +54,7 @@ def create(
         self.timer_checker = TimeChecker(6)
 
         with self.kv_trans_process.device_lock:
+            clear_queue(self.kv_trans_process.task_out_queue)
             self.kv_trans_process.task_in_queue.put(
                 PDTransJoinInfo(
                     prefill_id=prefill_node_id,
@@ -86,6 +87,7 @@ def create(
 
     def _transfer_kv(self, move_tasks: List[KVMoveTask]):
         with self.kv_trans_process.device_lock:
+            clear_queue(self.kv_trans_process.task_out_queue)
             kv_move_group = KVMoveTaskGroup(tasks=move_tasks.copy(), connect_id=self.connect_id)
             kv_move_group.connect_id = self.connect_id
             self.kv_trans_process.task_in_queue.put(kv_move_group, timeout=10)
diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py
index 7a0391df0..782c95326 100644
--- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py
+++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py
@@ -77,7 +77,7 @@ def async_connect():
 
         connect_task = threading.Thread(target=async_connect, daemon=True)
         connect_task.start()
-        connect_task.join(timeout=50)
+        connect_task.join(timeout=36)
         if connect_task.is_alive():
             raise Exception(f"{node_info} connect time out")
 
diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_obj.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_obj.py
index 93639facc..eed48e684 100644
--- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_obj.py
+++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_obj.py
@@ -16,7 +16,7 @@
 from lightllm.utils.time_utils import TimeChecker
 from .prefill_kv_move_manager import PrefillKVMoveManager
 from lightllm.utils.net_utils import find_available_port
-from ..utils import join_if_alive
+from ..utils import join_if_alive, clear_queue
 
 logger = init_logger(__name__)
 
@@ -54,13 +54,18 @@ def create(
         self.timer_checker = TimeChecker(6)
 
         con = rpyc.connect(
-            host=decode_node_ip, port=decode_node_rpyc_port, config={"allow_pickle": True}, keepalive=True
+            host=decode_node_ip,
+            port=decode_node_rpyc_port,
+            config={"allow_pickle": True, "sync_request_timeout": 60},
+            keepalive=True
         )
 
         self.rpyc_conn = con
 
         # create the nccl connection
         with self.kv_trans_process.device_lock:
+            clear_queue(self.kv_trans_process.task_out_queue)
+
             self.kv_trans_process.task_in_queue.put(
                 PDTransJoinInfo(
                     prefill_id=prefill_node_id,
@@ -182,6 +187,7 @@ def request_kv_trans_loop(self):
     # ==================================================================================
     def _transfer_kv(self, move_tasks: List[KVMoveTask]):
         with self.kv_trans_process.device_lock:
+            clear_queue(self.kv_trans_process.task_out_queue)
             kv_move_group = KVMoveTaskGroup(tasks=move_tasks.copy(), connect_id=self.connect_id)
             self.kv_trans_process.task_in_queue.put(kv_move_group, timeout=10)
             assert self.kv_trans_process.task_out_queue.get(timeout=60) == "ok"
diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py
index 7936b721a..3e42a532d 100644
--- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py
+++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py
@@ -73,7 +73,7 @@ def async_connect():
 
         connect_task = threading.Thread(target=async_connect, daemon=True)
         connect_task.start()
-        connect_task.join(timeout=50)
+        connect_task.join(timeout=36)
         if connect_task.is_alive():
             raise Exception(f"{node_info} connect time out")
 
diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/utils.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/utils.py
index dbfd61c53..38c1c58d0 100644
--- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/utils.py
+++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/utils.py
@@ -1,5 +1,6 @@
 import threading
-
+import torch.multiprocessing as mp
+from queue import Empty
 
 def join_if_alive(thread: threading.Thread):
     if thread is not None and thread.is_alive():
@@ -8,3 +9,10 @@ def join_if_alive(thread: threading.Thread):
         except Exception:
             pass
     return
+
+def clear_queue(queue: mp.Queue):
+    while not queue.empty():
+        try:
+            queue.get_nowait()
+        except Empty:
+            break

From 923d6e3864b6065c897e1656fe8e6300a3559546 Mon Sep 17 00:00:00 2001
From: wangzaijun
Date: Fri, 11 Apr 2025 13:34:42 +0800
Subject: [PATCH 18/20] fix

---
 .../pd_mode/prefill_node_impl/prefill_trans_obj.py | 8 ++++----
 .../mode_backend/continues_batch/pd_mode/utils.py  | 2 ++
 2 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_obj.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_obj.py
index eed48e684..f53761e09 100644
--- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_obj.py
+++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_obj.py
@@ -54,10 +54,10 @@ def create(
         self.timer_checker = TimeChecker(6)
 
         con = rpyc.connect(
-            host=decode_node_ip,
-            port=decode_node_rpyc_port,
-            config={"allow_pickle": True, "sync_request_timeout": 60},
-            keepalive=True
+            host=decode_node_ip,
+            port=decode_node_rpyc_port,
+            config={"allow_pickle": True, "sync_request_timeout": 60},
+            keepalive=True,
         )
 
         self.rpyc_conn = con
diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/utils.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/utils.py
index 38c1c58d0..cd1360fd0 100644
--- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/utils.py
+++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/utils.py
@@ -2,6 +2,7 @@
 import torch.multiprocessing as mp
 from queue import Empty
 
+
 def join_if_alive(thread: threading.Thread):
     if thread is not None and thread.is_alive():
         try:
@@ -10,6 +11,7 @@ def join_if_alive(thread: threading.Thread):
             pass
     return
 
+
 def clear_queue(queue: mp.Queue):
     while not queue.empty():
         try:

From 66c5d28b99622f654c42a6aa869b607d9f86e9d4 Mon Sep 17 00:00:00 2001
From: wangzaijun
Date: Fri, 11 Apr 2025 17:23:49 +0800
Subject: [PATCH 19/20] fix

---
 lightllm/server/router/manager.py                      |  4 ++++
 .../pd_mode/decode_node_impl/decode_infer_rpyc.py      | 14 ++++++++++++++
 2 files changed, 18 insertions(+)

diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py
index 51411c301..42b6d6453 100644
--- a/lightllm/server/router/manager.py
+++ b/lightllm/server/router/manager.py
@@ -240,10 +240,12 @@ async def loop_for_fwd(
                     ) / self.max_total_token_num
                     d_i = dp_index
                     frozen_token_num = self.shared_token_load.get_frozened_token_count(d_i)
+                    estimated_peak_token_count = self.shared_token_load.get_estimated_peak_token_count(d_i)
                     logger.debug(
                         f"dp_i {d_i} current batch size: {len(self.running_batch.reqs)} \n"
                         f"dp_i {d_i} paused req num: {self.req_queue.get_paused_req_num()} \n"
                         f"dp_i {d_i} frozen token num: {frozen_token_num} \n"
+                        f"dp_i {d_i} estimated_peak_token_count: {estimated_peak_token_count} \n"
                         f"dp_i {d_i} token used ratio: {token_ratio1} not contain prompt cache tree unrefed token\n"
                         f"dp_i {d_i} token used ratio: {token_ratio2} contain prompt cache tree unrefed token"
                     )
@@ -270,7 +272,9 @@ async def loop_for_fwd(
             if log_time_ready("frozen_info", 60):
                 for dp_i in range(self.dp_size_in_node):
                     frozen_token_num = self.shared_token_load.get_frozened_token_count(dp_i)
+                    estimated_peak_token_count = self.shared_token_load.get_estimated_peak_token_count(dp_i)
                     logger.debug(f"dp_i {dp_i} frozen token num: {frozen_token_num} \n")
+                    logger.debug(f"dp_i {dp_i} estimated_peak_token_count: {estimated_peak_token_count} \n")
 
             if self.running_batch is None:
                 await asyncio.sleep(0.01)  # 10ms
diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_infer_rpyc.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_infer_rpyc.py
index 0b161f5de..8f88237ec 100644
--- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_infer_rpyc.py
+++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_infer_rpyc.py
@@ -71,6 +71,20 @@ def recover_frozen_token(self, key_len, max_new_token):
     def _alloc_to_frozen_some_tokens(self, move_task: KVMoveTask):
         is_ok = self.judge_token_is_ok(len(move_task.input_tokens), move_task.decode_node.max_new_tokens)
         if not is_ok:
+            if self.is_master_in_dp:
+                logger.info(f"req_id: {move_task.to_decode_log_info()} alloc token failed")
+                shared_token_load = self.backend.shared_token_load
+                dp_rank = self.dp_rank_in_node
+                frozen_token_num = shared_token_load.get_frozened_token_count(dp_rank)
+                estimated_peak_token_num = shared_token_load.get_estimated_peak_token_count(dp_rank)
+                logger.debug(
+                    f"radix refed token num {self.backend.radix_cache.get_refed_tokens_num()}\n"
+                    f"radix hold token num {self.backend.radix_cache.get_tree_total_tokens_num()}\n"
+                    f"mem manager can alloc token num {self.backend.model.mem_manager.can_use_mem_size}\n"
+                    f"mem manager total size {self.backend.model.mem_manager.size}\n"
+                    f"frozened token num {frozen_token_num}\n"
+                    f"estimated peak token num {estimated_peak_token_num}\n"
+                )
             return None
 
         key = torch.tensor(move_task.input_tokens, dtype=torch.int64, device="cpu")

From 65b04b7bcccd1d3946d10defc4992b7f405fec8a Mon Sep 17 00:00:00 2001
From: wangzaijun
Date: Fri, 11 Apr 2025 21:19:35 +0800
Subject: [PATCH 20/20] fix

---
 lightllm/server/router/manager.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py
index 42b6d6453..5424c12da 100644
--- a/lightllm/server/router/manager.py
+++ b/lightllm/server/router/manager.py
@@ -98,6 +98,7 @@ def __init__(self, args, router_port, detokenization_port, metric_port):
         self.stats_tool = Stats(not args.disable_log_stats, args.log_stats_interval)
         self.metric_client = MetricClient(metric_port)
         self.is_pd_run_mode = self.args.run_mode in ["prefill", "decode"]
+        self.is_pd_decode_mode = self.args.run_mode == "decode"
         # In p/d disaggregated mode, a scheduling lock is needed to synchronize some data
         # operations between the scheduler and the inference side, mainly to prevent
        # scheduling mistakes from causing OOM and similar errors
         self.router_lock = mp.Lock()
@@ -249,7 +250,8 @@ async def loop_for_fwd(
                         f"dp_i {d_i} token used ratio: {token_ratio1} not contain prompt cache tree unrefed token\n"
                         f"dp_i {d_i} token used ratio: {token_ratio2} contain prompt cache tree unrefed token"
                     )
-                    self.req_queue.update_token_load(self.running_batch, force_update=False)
+                    # pd decode mode needs to update token_load more frequently
+                    self.req_queue.update_token_load(self.running_batch, force_update=self.is_pd_decode_mode)
                    self.stats_tool.print_stats()
                     self.metric_client.gauge_set("lightllm_batch_current_size", len(self.running_batch.reqs))
                     self.metric_client.gauge_set("lightllm_batch_pause_size", self.req_queue.get_paused_req_num())
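A closing note on the queue protocol these patches rely on: every exchange with a kv transfer process follows the same drain / request / reply shape under the per-device lock, with string replies ("proc_start", "nccl_ok", "ok", ...) as acknowledgements. A minimal sketch of that handshake (clear_queue and the reply strings come from this series; request() itself is illustrative):

    import torch.multiprocessing as mp
    from queue import Empty

    def clear_queue(queue: mp.Queue):
        # drop stale replies left over from a previous, failed exchange
        while not queue.empty():
            try:
                queue.get_nowait()
            except Empty:
                break

    def request(task_in_queue: mp.Queue, task_out_queue: mp.Queue, task, expect: str):
        clear_queue(task_out_queue)             # 1. drain stale replies
        task_in_queue.put(task, timeout=10)     # 2. send the request
        reply = task_out_queue.get(timeout=60)  # 3. wait for the matching reply
        assert reply == expect, f"unexpected reply: {reply}"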