Skip to content

add mps for multimodal #817

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Apr 9, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions lightllm/server/api_cli.py
Original file line number Diff line number Diff line change
@@ -193,6 +193,9 @@ def make_argument_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--enable_multimodal", action="store_true", help="Whether or not to allow to load additional multimodal models."
)
parser.add_argument(
"--enable_mps", action="store_true", help="Whether to enable nvidia mps for multimodal service."
)
parser.add_argument("--enable_custom_allreduce", action="store_true", help="Whether to disable cutom allreduce.")
parser.add_argument("--enable_custom_allgather", action="store_true", help="Whether to enable cutom allgather.")
parser.add_argument(
7 changes: 7 additions & 0 deletions lightllm/server/api_start.py
Original file line number Diff line number Diff line change
@@ -63,6 +63,13 @@ def signal_handler(sig, frame):
def normal_or_p_d_start(args):
set_unique_server_name(args)

if args.enable_mps:
from lightllm.utils.device_utils import enable_mps, set_gpu_exclusive_mode

for i in range(args.tp):
set_gpu_exclusive_mode(gpu_index=i)
enable_mps()

if args.run_mode not in ["normal", "prefill", "decode"]:
return

6 changes: 4 additions & 2 deletions lightllm/server/visualserver/manager.py
Original file line number Diff line number Diff line change
@@ -55,7 +55,10 @@ async def wait_to_model_ready(self):
for dp_rank_id in range(self.vit_dp):
tp_ports_each_dp = self.visual_model_rpc_ports[dp_rank_id]
for tp_rank_id in range(self.vit_tp):
rpc_model = await start_model_process(port=tp_ports_each_dp[tp_rank_id], vit_tp=self.vit_tp)
device_id = self.args.visual_gpu_ids[dp_rank_id * self.vit_tp + tp_rank_id]
rpc_model = await start_model_process(
port=tp_ports_each_dp[tp_rank_id], vit_tp=self.vit_tp, device_id=device_id
)
self.model_rpcs[dp_rank_id].append(rpc_model)

init_model_ret = []
@@ -159,7 +162,6 @@ def start_visual_process(args, router_port, visual_port, cache_port, model_rpc_p
# 注册graceful 退出的处理
graceful_registry(inspect.currentframe().f_code.co_name)
start_parent_check_thread()

try:
visualserver = VisualManager(args, router_port, visual_port, cache_port, model_rpc_ports)
asyncio.run(visualserver.wait_to_model_ready())
17 changes: 14 additions & 3 deletions lightllm/server/visualserver/model_infer/model_rpc.py
Original file line number Diff line number Diff line change
@@ -20,6 +20,7 @@
from lightllm.utils.infer_utils import calculate_time, mark_start, mark_end
from lightllm.utils.dist_utils import init_vision_distributed_env
from lightllm.utils.graceful_utils import graceful_registry
from lightllm.utils.envs_utils import get_env_start_args


class VisualModelRpcServer(rpyc.Service):
@@ -139,19 +140,29 @@ async def encode(self, images: List[ImageItem]):
return ans


def _init_env(port):
def _init_env(port, device_id):
    """Entry point of one visual (ViT) model worker process.

    Hosts a ``VisualModelRpcServer`` on an rpyc ``ThreadedServer`` bound to
    ``port``. ``device_id`` is the GPU index assigned to this worker; it is
    only consulted when MPS is enabled, to cap the worker's SM share.
    """
    # Register graceful-shutdown handling first, so the worker exits cleanly.
    graceful_registry(inspect.currentframe().f_code.co_name)
    from lightllm.utils.device_utils import set_sm_limit

    if get_env_start_args().enable_mps:
        # NOTE(review): presumably the limit must be applied before any CUDA
        # context is created in this process — confirm against MPS docs.
        set_sm_limit(60, device_id)  # the visual server can take up to 60% of the sm

    # allow_pickle lets callers pass arbitrary Python objects over rpyc;
    # NOTE(review): pickle implies the peer is trusted — verify deployment.
    t = ThreadedServer(VisualModelRpcServer(), port=port, protocol_config={"allow_pickle": True})
    t.start()  # blocks; serves requests until the process is terminated
    return


async def start_model_process(port, vit_tp):
async def start_model_process(port, vit_tp, device_id):
import multiprocessing

proc = multiprocessing.Process(target=_init_env, args=(port,))
proc = multiprocessing.Process(
target=_init_env,
args=(
port,
device_id,
),
)
proc.start()
await asyncio.sleep(2)
repeat_count = 0
114 changes: 113 additions & 1 deletion lightllm/utils/device_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import os
from functools import lru_cache
import time
import shutil
import subprocess
from functools import lru_cache
from lightllm.utils.log_utils import init_logger

logger = init_logger(__name__)


@lru_cache(maxsize=None)
@@ -99,3 +104,110 @@ def has_nvlink():
except subprocess.CalledProcessError:
# If there's an error (e.g., nvidia-smi is not installed or another issue), assume no NVLink
return False


def is_mps_running(verbose=False):
    """Check whether the NVIDIA MPS control daemon is running.

    Uses a shell pipeline; the ``[n]`` bracket trick prevents the grep
    process from matching itself. ``grep`` exits 0 only when at least one
    matching process line was found, so the pipeline's return code tells us
    whether the daemon is present.

    Args:
        verbose: When True, log the matching process lines. (Previously this
            flag was accepted but silently ignored.)

    Returns:
        bool: True if an ``nvidia-cuda-mps-control`` process exists.
    """
    result = subprocess.run(
        "ps -ef | grep '[n]vidia-cuda-mps-control'",
        shell=True,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
    )
    running = result.returncode == 0
    if verbose:
        # Surface the raw match output to help debug MPS startup issues.
        logger.info("MPS process check (running=%s): %s", running, result.stdout.strip())
    return running


def stop_mps():
    """Shut down the NVIDIA MPS control daemon if it is currently running."""
    if not is_mps_running():
        logger.info("MPS is not running, no need to stop.")
        return

    # The control daemon reads commands on stdin; "quit" asks it to exit.
    result = subprocess.run("echo quit | nvidia-cuda-mps-control", shell=True)
    logger.info("Stopping MPS...")
    if result.returncode == 0:
        logger.info("MPS stopped successfully.")
    else:
        logger.warning("Failed to stop MPS.")


def enable_mps():
    """Start the NVIDIA MPS control daemon (idempotent).

    Launches ``nvidia-cuda-mps-control -d`` and then polls the process table
    for up to 10 seconds until the daemon appears.
    """
    if is_mps_running():
        logger.info("MPS is already running, no need to start.")
        return

    ret = os.system("nvidia-cuda-mps-control -d")
    # Check the launch result *before* waiting: the previous version slept a
    # fixed 10 seconds even when the command had already failed.
    if ret != 0:
        logger.warning("Failed to start MPS.")
        return

    # Poll instead of a fixed sleep, so a fast daemon start does not stall
    # server startup for the full 10 seconds.
    deadline = time.time() + 10
    while time.time() < deadline:
        if is_mps_running():
            logger.info("MPS started successfully.")
            return
        time.sleep(0.5)
    # Previously this case (launch ok, daemon never appeared) was silent.
    logger.warning("MPS daemon did not appear within 10 seconds.")
    return


def get_gpu_compute_mode(gpu_index=0):
    """Query the compute mode of one GPU via nvidia-smi.

    Returns the mode string as reported by the driver (for example
    "Default" or "Exclusive_Process"), or None when it cannot be determined
    (nvidia-smi missing, query failure, or an unexpected exception).
    """
    try:
        if not shutil.which("nvidia-smi"):
            logger.warning("nvidia-smi not found in PATH.")
            return None

        query = ["nvidia-smi", "-i", str(gpu_index), "--query-gpu=compute_mode", "--format=csv,noheader"]
        completed = subprocess.run(query, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)

        if completed.returncode != 0:
            logger.warning(f"Failed to query compute mode: {completed.stderr.strip()}")
            return None

        return completed.stdout.strip()

    except Exception as e:
        logger.warning(f"Exception occurred while checking GPU compute mode: {e}")
        return None


def _set_gpu_compute_mode(gpu_index, mode):
    """Switch GPU ``gpu_index`` to the given nvidia-smi compute ``mode``.

    Shared implementation for :func:`set_gpu_exclusive_mode` and
    :func:`set_gpu_default_mode`, which previously duplicated this
    subprocess invocation verbatim. Log messages are identical to the
    originals because ``mode`` is interpolated into the same templates.

    Returns:
        bool: True when nvidia-smi reported success.
    """
    logger.info(f"Setting GPU {gpu_index} to {mode} mode...")
    result = subprocess.run(
        ["nvidia-smi", "-i", str(gpu_index), "-c", mode],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
    )
    if result.returncode == 0:
        logger.info(f"GPU {gpu_index} set to {mode} mode.")
        return True
    logger.warning(f"Failed to set {mode} mode: {result.stderr.strip()}")
    return False


def set_gpu_exclusive_mode(gpu_index=0):
    """Put the GPU into EXCLUSIVE_PROCESS compute mode (used with MPS).

    Returns:
        bool: True on success.
    """
    return _set_gpu_compute_mode(gpu_index, "EXCLUSIVE_PROCESS")


def set_gpu_default_mode(gpu_index=0):
    """Restore the GPU to the DEFAULT (shared) compute mode.

    Called during shutdown to undo :func:`set_gpu_exclusive_mode`.

    Returns:
        bool: True on success.
    """
    return _set_gpu_compute_mode(gpu_index, "DEFAULT")


def set_sm_limit(percent: int, gpu_index: int = 0) -> bool:
    """
    Sets CUDA_MPS_ACTIVE_THREAD_PERCENTAGE to the given value if the GPU is in EXCLUSIVE_PROCESS mode.

    Args:
        percent: SM usage cap, an integer in [1, 100].
        gpu_index: GPU whose compute mode is checked before applying the cap.

    Returns:
        bool: True when the environment variable was set.
    """
    if percent < 1 or percent > 100:
        logger.error("SM usage percentage must be between 1 and 100.")
        return False

    current_mode = get_gpu_compute_mode(gpu_index)
    if current_mode != "Exclusive_Process":
        logger.warning(f"Cannot set SM limit. GPU {gpu_index} is in '{current_mode}' mode, not 'Exclusive_Process'.")
        return False

    # The cap is communicated to the MPS server through the environment of
    # this process (inherited by any workers it spawns).
    os.environ["CUDA_MPS_ACTIVE_THREAD_PERCENTAGE"] = str(percent)
    logger.info(f"Set CUDA_MPS_ACTIVE_THREAD_PERCENTAGE to {percent}% for GPU {gpu_index}.")
    return True
12 changes: 12 additions & 0 deletions lightllm/utils/start_utils.py
Original file line number Diff line number Diff line change
@@ -41,6 +41,8 @@ def start_submodule_processes(self, start_funcs=[], start_args=[]):
return

def terminate_all_processes(self):
from lightllm.utils.envs_utils import get_env_start_args

def kill_recursive(proc):
try:
parent = psutil.Process(proc.pid)
@@ -57,6 +59,16 @@ def kill_recursive(proc):
if proc.is_alive():
kill_recursive(proc)
proc.join()

# recover the gpu compute mode
is_enable_mps = get_env_start_args().enable_mps
world_size = get_env_start_args().tp
if is_enable_mps:
from lightllm.utils.device_utils import stop_mps, set_gpu_default_mode

stop_mps()
for i in range(world_size):
set_gpu_default_mode(gpu_index=i)
logger.info("All processes terminated gracefully.")