diff --git a/lmdeploy/serve/openai/api_server.py b/lmdeploy/serve/openai/api_server.py index c4d74ea5b4..7b5322ba02 100644 --- a/lmdeploy/serve/openai/api_server.py +++ b/lmdeploy/serve/openai/api_server.py @@ -52,7 +52,7 @@ logger = get_logger('lmdeploy') -class VariableInterface: +class ServerContext: """A IO interface maintaining variables.""" async_engine: AsyncEngine = None api_keys: Optional[List[str]] = None @@ -69,7 +69,7 @@ class VariableInterface: @staticmethod def get_session(session_id: int) -> int: - session_mgr = VariableInterface.get_session_manager() + session_mgr = ServerContext.get_session_manager() if session_id == -1: return session_mgr.get() else: @@ -77,16 +77,16 @@ def get_session(session_id: int) -> int: @staticmethod def get_session_manager(): - return VariableInterface.async_engine.session_mgr + return ServerContext.async_engine.session_mgr @staticmethod def get_engine_config(): - return VariableInterface.async_engine.backend_config + return ServerContext.async_engine.backend_config router = APIRouter() get_bearer_token = HTTPBearer(auto_error=False) -server_context = VariableInterface() +server_context = ServerContext() async def check_api_key(auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token), ) -> str: @@ -94,8 +94,8 @@ async def check_api_key(auth: Optional[HTTPAuthorizationCredentials] = Depends(g Adopted from https://github.com/lm-sys/FastChat/blob/v0.2.35/fastchat/serve/openai_api_server.py#L108-L127 """ # noqa - if VariableInterface.api_keys: - if auth is None or (token := auth.credentials) not in VariableInterface.api_keys: + if ServerContext.api_keys: + if auth is None or (token := auth.credentials) not in ServerContext.api_keys: raise HTTPException( status_code=401, detail={ @@ -118,8 +118,8 @@ def get_model_list(): If it is a slora serving. The model list would be [model_name, adapter_name1, adapter_name2, ...] """ - model_names = [VariableInterface.async_engine.model_name] - cfg = VariableInterface.async_engine.backend_config + model_names = [ServerContext.async_engine.model_name] + cfg = ServerContext.async_engine.backend_config model_names += getattr(cfg, 'adapters', None) or [] return model_names @@ -279,7 +279,7 @@ async def terminate(): """Terminate server.""" import signal - if not VariableInterface.allow_terminate_by_client: + if not ServerContext.allow_terminate_by_client: return create_error_response( HTTPStatus.BAD_REQUEST, 'The server can not be terminated. Please add --allow-terminate-by-client when start the server.') @@ -402,7 +402,7 @@ async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Reque error_check_ret = check_request(request) if error_check_ret is not None: return error_check_ret - session = VariableInterface.get_session(request.session_id) + session = ServerContext.get_session(request.session_id) json_request = await raw_request.json() migration_request = json_request.pop('migration_request', None) @@ -413,12 +413,12 @@ async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Reque model_name = request.model adapter_name = None - if model_name != VariableInterface.async_engine.model_name: + if model_name != ServerContext.async_engine.model_name: adapter_name = model_name # got a adapter name request_id = str(session.session_id) created_time = int(time.time()) gpt_oss_parser = None - if VariableInterface.async_engine.arch == 'GptOssForCausalLM': + if ServerContext.async_engine.arch == 'GptOssForCausalLM': gpt_oss_parser = GptOssChatParser() if isinstance(request.stop, str): @@ -434,7 +434,7 @@ async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Reque if request.logit_bias is not None: try: logits_processors = [ - logit_bias_logits_processor(request.logit_bias, VariableInterface.async_engine.tokenizer.model) + logit_bias_logits_processor(request.logit_bias, ServerContext.async_engine.tokenizer.model) ] except Exception as e: return create_error_response(HTTPStatus.BAD_REQUEST, str(e)) @@ -496,7 +496,7 @@ async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Reque else: logger.warning('`enable_thinking` in `chat_template_kwargs` will override the value in request.') enable_thinking = chat_template_kwargs.get('enable_thinking', None) - result_generator = VariableInterface.async_engine.generate( + result_generator = ServerContext.async_engine.generate( request.messages, session, gen_config=gen_config, @@ -536,12 +536,12 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: previous_token_ids = [] current_token_ids = [] delta_token_ids = [] - has_parser = VariableInterface.tool_parser is not None or VariableInterface.reasoning_parser is not None + has_parser = ServerContext.tool_parser is not None or ServerContext.reasoning_parser is not None streaming_tools = False async for res in result_generator: logprobs, usage = None, None if gen_logprobs and res.logprobs: - logprobs = _create_chat_completion_logprobs(VariableInterface.async_engine.tokenizer, res.token_ids, + logprobs = _create_chat_completion_logprobs(ServerContext.async_engine.tokenizer, res.token_ids, res.logprobs) # Only stream chunk `usage` in the final chunk according to OpenAI API spec if (res.finish_reason and request.stream_options and request.stream_options.include_usage): @@ -562,10 +562,10 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: if has_parser: current_text = current_text + res.response current_token_ids = current_token_ids + delta_token_ids - if request.tool_choice != 'none' and VariableInterface.tool_parser is not None: + if request.tool_choice != 'none' and ServerContext.tool_parser is not None: if res.finish_reason == 'stop' and streaming_tools is True: res.finish_reason = 'tool_calls' - tool_delta = VariableInterface.tool_parser.extract_tool_calls_streaming( + tool_delta = ServerContext.tool_parser.extract_tool_calls_streaming( previous_text=previous_text, current_text=current_text, delta_text=delta_message.content, @@ -579,10 +579,10 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: if isinstance(tool_delta.tool_calls, List) and len(tool_delta.tool_calls): streaming_tools = True elif (request.tool_choice != 'none' and request.tools is not None - and VariableInterface.tool_parser is None): + and ServerContext.tool_parser is None): logger.error('Please launch the api_server with --tool-call-parser if you want to use tool.') - if VariableInterface.reasoning_parser is not None and enable_thinking is not False: - reasoning_delta = VariableInterface.reasoning_parser.extract_reasoning_content_streaming( + if ServerContext.reasoning_parser is not None and enable_thinking is not False: + reasoning_delta = ServerContext.reasoning_parser.extract_reasoning_content_streaming( previous_text=previous_text, current_text=current_text, delta_text=delta_message.content or '', @@ -640,9 +640,9 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: else: tool_calls = None reasoning_content = None - if request.tool_choice != 'none' and VariableInterface.tool_parser is not None: + if request.tool_choice != 'none' and ServerContext.tool_parser is not None: try: - tool_call_info = VariableInterface.tool_parser.extract_tool_calls(text, request=request) + tool_call_info = ServerContext.tool_parser.extract_tool_calls(text, request=request) text, tool_calls = tool_call_info.content, tool_call_info.tool_calls if isinstance(tool_calls, List) and len(tool_calls): if final_res.finish_reason == 'stop': @@ -651,11 +651,11 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: except Exception as e: logger.error(f'Failed to parse {text}. Exception: {e}.') return create_error_response(HTTPStatus.BAD_REQUEST, 'Failed to parse fc related info to json format!') - elif request.tool_choice != 'none' and request.tools is not None and VariableInterface.tool_parser is None: + elif request.tool_choice != 'none' and request.tools is not None and ServerContext.tool_parser is None: logger.error('Please launch the api_server with --tool-call-parser if you want to use tool.') - if VariableInterface.reasoning_parser is not None and enable_thinking is not False: - reasoning_content, text = VariableInterface.reasoning_parser.extract_reasoning_content(text, request) + if ServerContext.reasoning_parser is not None and enable_thinking is not False: + reasoning_content, text = ServerContext.reasoning_parser.extract_reasoning_content(text, request) message = ChatMessage(role='assistant', content=text, @@ -664,7 +664,7 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: logprobs = None if gen_logprobs and len(final_logprobs): - logprobs = _create_chat_completion_logprobs(VariableInterface.async_engine.tokenizer, final_token_ids, + logprobs = _create_chat_completion_logprobs(ServerContext.async_engine.tokenizer, final_token_ids, final_logprobs) assert final_res is not None @@ -769,17 +769,17 @@ async def completions_v1(request: CompletionRequest, raw_request: Request = None model_name = request.model adapter_name = None - if model_name != VariableInterface.async_engine.model_name: + if model_name != ServerContext.async_engine.model_name: adapter_name = model_name # got a adapter name request_id = str(request.session_id) created_time = int(time.time()) sessions = [] if isinstance(request.prompt, str): request.prompt = [request.prompt] - sessions.append(VariableInterface.get_session(request.session_id)) + sessions.append(ServerContext.get_session(request.session_id)) elif isinstance(request.prompt, list): for i in range(len(request.prompt)): - sessions.append(VariableInterface.get_session(i + 1)) + sessions.append(ServerContext.get_session(i+1)) if isinstance(request.stop, str): request.stop = [request.stop] random_seed = request.seed if request.seed else None @@ -805,7 +805,7 @@ async def completions_v1(request: CompletionRequest, raw_request: Request = None ) generators = [] for prompt, session in zip(request.prompt, sessions): - result_generator = VariableInterface.async_engine.generate( + result_generator = ServerContext.async_engine.generate( prompt, session, gen_config=gen_config, @@ -848,7 +848,7 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: usage = None if request.logprobs and res.logprobs: logprobs, offset, all_token_ids, state = _create_completion_logprobs( # noqa E501 - VariableInterface.async_engine.tokenizer, res.token_ids, res.logprobs, + ServerContext.async_engine.tokenizer, res.token_ids, res.logprobs, gen_config.skip_special_tokens, offset, all_token_ids, state, gen_config.spaces_between_special_tokens) # Only stream chunk `usage` in the final chunk according to OpenAI API spec @@ -894,7 +894,7 @@ async def _inner_call(i, generator): async for res in generator: if await raw_request.is_disconnected(): # Abort the request if the client disconnects. - await VariableInterface.async_engine.stop_session(request.session_id) + await ServerContext.async_engine.stop_session(request.session_id) return create_error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected') final_res = res text += res.response @@ -908,7 +908,7 @@ async def _inner_call(i, generator): logprobs = None if request.logprobs and len(final_logprobs): logprobs, _, _, _ = _create_completion_logprobs( - VariableInterface.async_engine.tokenizer, + ServerContext.async_engine.tokenizer, final_token_ids, final_logprobs, gen_config.skip_special_tokens, @@ -953,7 +953,7 @@ async def generate(request: GenerateReqInput, raw_request: Request = None): error_check_ret = check_request(request) if error_check_ret is not None: return error_check_ret - session = VariableInterface.get_session(request.session_id) + session = ServerContext.get_session(request.session_id) prompt = request.prompt input_ids = request.input_ids @@ -990,7 +990,7 @@ async def generate(request: GenerateReqInput, raw_request: Request = None): return_routed_experts=request.return_routed_experts, ) - result_generator = VariableInterface.async_engine.generate( + result_generator = ServerContext.async_engine.generate( messages=prompt, session_id=session, input_ids=input_ids, @@ -1082,8 +1082,8 @@ async def encode(request: EncodeRequest, raw_request: Request = None): def encode(prompt: str, do_preprocess: bool, add_bos: bool): if do_preprocess: - prompt = VariableInterface.async_engine.chat_template.get_prompt(prompt, sequence_start=add_bos) - input_ids = VariableInterface.async_engine.tokenizer.encode(prompt, add_bos=add_bos) + prompt = ServerContext.async_engine.chat_template.get_prompt(prompt, sequence_start=add_bos) + input_ids = ServerContext.async_engine.tokenizer.encode(prompt, add_bos=add_bos) return input_ids if isinstance(request.input, str): @@ -1114,7 +1114,7 @@ async def pooling(request: PoolingRequest, raw_request: Request = None): - **input** (List[int] | List[List[int]] | str | List[str]): input text to be embed """ - async_engine = VariableInterface.async_engine + async_engine = ServerContext.async_engine request_input = request.input model_name = request.model or async_engine.model_name @@ -1155,14 +1155,14 @@ async def pooling(request: PoolingRequest, raw_request: Request = None): @router.post('/update_weights', dependencies=[Depends(check_api_key)]) def update_params(request: UpdateParamsRequest, raw_request: Request = None): """Update weights for the model.""" - VariableInterface.async_engine.engine.update_params(request) + ServerContext.async_engine.engine.update_params(request) return JSONResponse(content=None) @router.post('/sleep', dependencies=[Depends(check_api_key)]) async def sleep(raw_request: Request = None): level = raw_request.query_params.get('level', '1') - VariableInterface.async_engine.sleep(int(level)) + ServerContext.async_engine.sleep(int(level)) return Response(status_code=200) @@ -1170,13 +1170,13 @@ async def sleep(raw_request: Request = None): async def wakeup(raw_request: Request = None): tags = raw_request.query_params.getlist('tags') tags = tags or None - VariableInterface.async_engine.wakeup(tags) + ServerContext.async_engine.wakeup(tags) return Response(status_code=200) @router.get('/is_sleeping', dependencies=[Depends(check_api_key)]) async def is_sleeping(raw_request: Request = None): - is_sleeping = VariableInterface.async_engine.is_sleeping + is_sleeping = ServerContext.async_engine.is_sleeping return JSONResponse(content={'is_sleeping': is_sleeping}) @@ -1185,7 +1185,7 @@ async def is_sleeping(raw_request: Request = None): @router.get('/distserve/engine_info') async def engine_info(): - engine_config = VariableInterface.async_engine.backend_config + engine_config = ServerContext.async_engine.backend_config response = DistServeEngineConfig(tp_size=engine_config.tp, dp_size=engine_config.dp, @@ -1201,23 +1201,23 @@ async def engine_info(): @router.post('/distserve/p2p_initialize') async def p2p_initialize(init_request: DistServeInitRequest): - return VariableInterface.async_engine.p2p_initialize(init_request) + return ServerContext.async_engine.p2p_initialize(init_request) @router.post('/distserve/p2p_connect') async def p2p_connect(conn_request: DistServeConnectionRequest): - return VariableInterface.async_engine.p2p_connect(conn_request) + return ServerContext.async_engine.p2p_connect(conn_request) @router.post('/distserve/p2p_drop_connect') async def p2p_drop_connect(drop_conn_request: DistServeDropConnectionRequest): - return VariableInterface.async_engine.p2p_drop_connect(drop_conn_request) + return ServerContext.async_engine.p2p_drop_connect(drop_conn_request) @router.post('/distserve/free_cache') async def free_cache(cache_free_request: DistServeCacheFreeRequest) -> JSONResponse: session_id = cache_free_request.remote_session_id - VariableInterface.async_engine.free_cache(session_id) + ServerContext.async_engine.free_cache(session_id) return {'status': 'SUCCESS'} @@ -1227,15 +1227,15 @@ async def free_cache(cache_free_request: DistServeCacheFreeRequest) -> JSONRespo @router.post('/abort_request') async def abort_request(request: AbortRequest, raw_request: Request = None): """Abort an ongoing request.""" - if not VariableInterface.enable_abort_handling: + if not ServerContext.enable_abort_handling: return Response( status_code=501, content='This server does not support abort requests. Enable with --enable-abort-handling flag.') if request.abort_all: - await VariableInterface.async_engine.stop_all_session() + await ServerContext.async_engine.stop_all_session() else: - session = VariableInterface.get_session(request.session_id) + session = ServerContext.get_session(request.session_id) await session.async_abort() return Response(status_code=200) @@ -1262,20 +1262,20 @@ def dummy_get_device_id(): @router.on_event('startup') async def startup_event(): - async_engine = VariableInterface.async_engine + async_engine = ServerContext.async_engine async_engine.start_loop(asyncio.get_running_loop(), use_async_api=True) - if VariableInterface.proxy_url is None: + if ServerContext.proxy_url is None: return elif getattr(async_engine.engine, 'is_dummy', False): logger.info('Dummy node started') return try: import requests - engine_config = VariableInterface.async_engine.backend_config + engine_config = ServerContext.async_engine.backend_config engine_role = engine_config.role.value if hasattr(engine_config, 'role') else 1 - url = f'{VariableInterface.proxy_url}/nodes/add' - data = {'url': VariableInterface.api_server_url, 'status': {'models': get_model_list(), 'role': engine_role}} + url = f'{ServerContext.proxy_url}/nodes/add' + data = {'url': ServerContext.api_server_url, 'status': {'models': get_model_list(), 'role': engine_role}} headers = {'accept': 'application/json', 'Content-Type': 'application/json'} response = requests.post(url, headers=headers, json=data) @@ -1287,7 +1287,7 @@ async def startup_event(): @router.on_event('shutdown') async def shutdown_event(): - async_engine = VariableInterface.async_engine + async_engine = ServerContext.async_engine if async_engine is not None: async_engine.close() @@ -1320,8 +1320,8 @@ def set_parsers(reasoning_parser: Optional[str] = None, tool_parser: Optional[st # set reasoning parser if reasoning_parser is not None: if reasoning_parser in ReasoningParserManager.module_dict: - tokenizer = VariableInterface.async_engine.tokenizer - VariableInterface.reasoning_parser = ReasoningParserManager.get(reasoning_parser)(tokenizer) + tokenizer = ServerContext.async_engine.tokenizer + ServerContext.reasoning_parser = ReasoningParserManager.get(reasoning_parser)(tokenizer) else: raise ValueError( f'The reasoning parser {reasoning_parser} is not in the parser list: {ReasoningParserManager.module_dict.keys()}' # noqa @@ -1329,8 +1329,8 @@ def set_parsers(reasoning_parser: Optional[str] = None, tool_parser: Optional[st # set tool parsers if tool_parser is not None: if tool_parser in ToolParserManager.module_dict: - tokenizer = VariableInterface.async_engine.tokenizer - VariableInterface.tool_parser = ToolParserManager.get(tool_parser)(tokenizer) + tokenizer = ServerContext.async_engine.tokenizer + ServerContext.tool_parser = ToolParserManager.get(tool_parser)(tokenizer) else: raise ValueError( f'The reasoning parser {tool_parser} is not in the parser list: {ToolParserManager.module_dict.keys()}' # noqa @@ -1465,12 +1465,12 @@ def serve(model_path: str, os.environ['TM_LOG_LEVEL'] = log_level logger.setLevel(log_level) - VariableInterface.allow_terminate_by_client = allow_terminate_by_client - VariableInterface.enable_abort_handling = enable_abort_handling + ServerContext.allow_terminate_by_client = allow_terminate_by_client + ServerContext.enable_abort_handling = enable_abort_handling if api_keys is not None: if isinstance(api_keys, str): api_keys = api_keys.split(',') - VariableInterface.api_keys = api_keys + ServerContext.api_keys = api_keys ssl_keyfile, ssl_certfile, http_or_https = None, None, 'http' if ssl: ssl_keyfile = os.environ['SSL_KEYFILE'] @@ -1484,19 +1484,19 @@ def serve(model_path: str, # router replay if backend_config.enable_return_routed_experts: backend_config.enable_transfer_obj_ref = True - VariableInterface.async_engine = pipeline_class(model_path=model_path, - model_name=model_name, - backend=backend, - backend_config=backend_config, - chat_template_config=chat_template_config, - max_log_len=max_log_len, - speculative_config=speculative_config, - **kwargs) + ServerContext.async_engine = pipeline_class(model_path=model_path, + model_name=model_name, + backend=backend, + backend_config=backend_config, + chat_template_config=chat_template_config, + max_log_len=max_log_len, + speculative_config=speculative_config, + **kwargs) # set reasoning parser and tool parser set_parsers(reasoning_parser, tool_call_parser) # create FastAPI lifespan events - lifespan = create_lifespan_handler(backend_config, VariableInterface.async_engine) + lifespan = create_lifespan_handler(backend_config, ServerContext.async_engine) if disable_fastapi_docs: app = FastAPI(docs_url=None, redoc_url=None, openapi_url=None, lifespan=lifespan) @@ -1521,8 +1521,8 @@ def serve(model_path: str, app.add_middleware(ConcurrencyLimitMiddleware, max_concurrent_requests=max_concurrent_requests) if proxy_url is not None: - VariableInterface.proxy_url = proxy_url - VariableInterface.api_server_url = f'{http_or_https}://{server_name}:{server_port}' # noqa + ServerContext.proxy_url = proxy_url + ServerContext.api_server_url = f'{http_or_https}://{server_name}:{server_port}' # noqa for i in range(3): print(f'HINT: Please open \033[93m\033[1m{http_or_https}://' f'{server_name}:{server_port}\033[0m in a browser for detailed api' diff --git a/lmdeploy/serve/openai/serving_chat_completion.py b/lmdeploy/serve/openai/serving_chat_completion.py index 415b2afeb7..8d77c98cb6 100644 --- a/lmdeploy/serve/openai/serving_chat_completion.py +++ b/lmdeploy/serve/openai/serving_chat_completion.py @@ -4,10 +4,10 @@ from .protocol import ChatCompletionRequest if TYPE_CHECKING: - from .api_server import VariableInterface + from .api_server import ServerContext -def check_request(request: ChatCompletionRequest, server_context: 'VariableInterface') -> str: +def check_request(request: ChatCompletionRequest, server_context: 'ServerContext') -> str: engine_config = server_context.get_engine_config() session_manager = server_context.get_session_manager() try: diff --git a/lmdeploy/serve/openai/serving_completion.py b/lmdeploy/serve/openai/serving_completion.py index 759972db36..86ad7fa348 100644 --- a/lmdeploy/serve/openai/serving_completion.py +++ b/lmdeploy/serve/openai/serving_completion.py @@ -4,10 +4,10 @@ from .protocol import CompletionRequest if TYPE_CHECKING: - from .api_server import VariableInterface + from .api_server import ServerContext -def check_request(request: CompletionRequest, server_context: 'VariableInterface') -> str: +def check_request(request: CompletionRequest, server_context: 'ServerContext') -> str: engine_config = server_context.get_engine_config() session_manager = server_context.get_session_manager() try: diff --git a/lmdeploy/serve/openai/serving_generate.py b/lmdeploy/serve/openai/serving_generate.py index 4615d9caea..7d6d483bfc 100644 --- a/lmdeploy/serve/openai/serving_generate.py +++ b/lmdeploy/serve/openai/serving_generate.py @@ -4,10 +4,10 @@ from .protocol import GenerateReqInput if TYPE_CHECKING: - from .api_server import VariableInterface + from .api_server import ServerContext -def check_request(request: GenerateReqInput, server_context: 'VariableInterface') -> str: +def check_request(request: GenerateReqInput, server_context: 'ServerContext') -> str: engine_config = server_context.get_engine_config() session_manager = server_context.get_session_manager() try: diff --git a/lmdeploy/serve/proxy/node_manager.py b/lmdeploy/serve/proxy/node_manager.py new file mode 100644 index 0000000000..241157a87e --- /dev/null +++ b/lmdeploy/serve/proxy/node_manager.py @@ -0,0 +1,306 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +import asyncio +import json +import os +import random +import threading +import time +from typing import TYPE_CHECKING, Dict, Optional + +import aiohttp +import numpy as np +import requests + +from lmdeploy.pytorch.disagg.config import EngineRole +from lmdeploy.serve.proxy.constants import AIOHTTP_TIMEOUT, ErrorCodes, RoutingStrategy, err_msg +from lmdeploy.utils import get_logger + +if TYPE_CHECKING: + from .proxy import Status + +logger = get_logger('lmdeploy') + + +class Connector: + """Connector class responsible for creating and managing aiohttp + ClientSession.""" + + def __init__(self): + self.limits = int(os.getenv('LMDEPLOY_AIOHTTP_LIMITS', 1024)) + self.limits_per_host = int(os.getenv('LMDEPLOY_AIOHTTP_LIMITS_PER_HOST', 128)) + self._session = None + + async def get_session(self) -> aiohttp.ClientSession: + """Get the shared session.""" + if self._session is None or self._session.closed: + connector = aiohttp.TCPConnector( + limit=self.limits, + limit_per_host=self.limits_per_host, + force_close=False, # Keep connections alive + ) + self._session = aiohttp.ClientSession( + connector=connector, + timeout=aiohttp.ClientTimeout(total=AIOHTTP_TIMEOUT), + ) + return self._session + + async def cleanup(self): + """Cleanup resources, close session.""" + if self._session and not self._session.closed: + await self._session.close() + + async def update(self, num_hosts: int): + """Update the limits based on number of hosts.""" + new_limit = self.limits_per_host * num_hosts + # Only update if the limit changed significantly + if abs(new_limit - self.limits) > self.limits_per_host: + self.limits = new_limit + await self.cleanup() + await self.get_session() + + +connector = Connector() + + +class Node: + """Node class responsible for sending requests and receiving responses. + + A Node represents an API server and can handle concurrent requests to that server. All nodes share a common + ClientSession managed by NodeManager for efficient connection pooling and reuse. + """ + + def __init__(self, url: str, status: 'Status'): + """Initialize a Node. + + Args: + url (str): The node URL. + status (Status, optional): The node status. + """ + self.url = url + self.status = status + + async def _make_request(self, request: Dict, endpoint: str): + """Make HTTP POST request to the node.""" + session = await connector.get_session() + return await session.post(self.url + endpoint, json=request) + + async def stream_generate(self, request: Dict, endpoint: str): + """Return a generator to handle the input request.""" + try: + async with await self._make_request(request, endpoint) as response: + async for line in response.content: + if line.strip(): + yield line + b'\n\n' + except (Exception, GeneratorExit, aiohttp.ClientError) as e: + logger.error(f'Exception in stream_generate: {e}') + yield self.handle_api_timeout() + + async def generate(self, request: Dict, endpoint: str): + """Return the response of the input request.""" + try: + async with await self._make_request(request, endpoint) as response: + return await response.text() + except Exception as e: + logger.error(f'Exception in generate: {e}') + return self.handle_api_timeout() + + def pre_call(self): + """Preprocess before the request get processed.""" + self.status.unfinished += 1 + return time.time() + + def post_call(self, start: float): + """Post process after the response finished.""" + self.status.unfinished -= 1 + self.status.latency.append(time.time() - start) + + def handle_api_timeout(self): + """Handle the api time out.""" + logger.warning(f'api timeout: {self.url}') + return json.dumps({ + 'error_code': ErrorCodes.API_TIMEOUT.value, + 'text': err_msg[ErrorCodes.API_TIMEOUT], + }).encode() + b'\n' + + +CONTROLLER_HEART_BEAT_EXPIRATION = int(os.getenv('LMDEPLOY_CONTROLLER_HEART_BEAT_EXPIRATION', 90)) + + +def heart_beat_controller(proxy_controller): + while True: + time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION) + logger.info('Start heart beat check') + proxy_controller.remove_stale_nodes_by_expiration() + + +class NodeManager: + """Manage all the api_servers, each of which is defined as a Node + object.""" + + def __init__(self) -> None: + self.nodes = {} + self.routing_strategy = RoutingStrategy.MIN_EXPECTED_LATENCY + self._nodes_cache: Dict[EngineRole, Dict[str, Node]] = {} + self._nodes_cache_dirty = True + + self.heart_beat_thread = threading.Thread(target=heart_beat_controller, args=(self, ), daemon=True) + self.heart_beat_thread.start() + + def _invalidate_nodes_cache(self): + """Mark node cache as invalid.""" + self._nodes_cache_dirty = True + + def get_nodes(self, role: EngineRole) -> Dict[str, Node]: + """Get nodes for the specified role, using cache.""" + if self._nodes_cache_dirty or role not in self._nodes_cache: + self._nodes_cache = {} + for node_url, node_status in self.nodes.items(): + node_role = node_status.role + if node_role not in self._nodes_cache: + self._nodes_cache[node_role] = {} + self._nodes_cache[node_role][node_url] = Node(url=node_url, status=node_status) + self._nodes_cache_dirty = False + return self._nodes_cache.get(role, {}) + + @property + def hybrid_nodes(self): + return self.get_nodes(EngineRole.Hybrid) + + @property + def prefill_nodes(self): + return self.get_nodes(EngineRole.Prefill) + + @property + def decode_nodes(self): + return self.get_nodes(EngineRole.Decode) + + async def add(self, node_url: str, status: 'Status'): + """Add a node.""" + self.nodes[node_url] = status + self._invalidate_nodes_cache() + await connector.update(len(self.nodes)) + + async def remove(self, node_url: str): + """Remove a node.""" + if node_url not in self.nodes: + raise ValueError(f'Node {node_url} does not exist') + + self.nodes.pop(node_url) + self._invalidate_nodes_cache() + await connector.update(len(self.nodes)) + + async def terminate_node(self, node_url: str): + """Terminate a node.""" + if node_url not in self.nodes: + raise KeyError(f'Node {node_url} does not exist') + + self.nodes.pop(node_url) + self._invalidate_nodes_cache() + + session = await connector.get_session() + async with session.get(f'{node_url}/terminate', headers={'accept': 'application/json'}) as response: + if response.status != 200: + text = await response.text() + raise RuntimeError(f'Failed to terminate node {node_url}, status={response.status}, msg={text}') + + async def terminate_all_nodes(self): + """Terminate all nodes. + + Raises: + RuntimeError: If any node termination fails. + """ + if not self.nodes: + return + + node_urls = list(self.nodes.keys()) + results = await asyncio.gather(*[self.terminate_node(url) for url in node_urls], return_exceptions=True) + + # Check for failures + failures = [r for r in results if isinstance(r, Exception)] + if failures: + failed_count = len(failures) + total_count = len(node_urls) + error_msg = f'Failed to terminate {failed_count}/{total_count} nodes' + logger.error(f'{error_msg}: {[str(f) for f in failures]}') + raise RuntimeError(error_msg) + + def remove_stale_nodes_by_expiration(self): + """Remove stale nodes.""" + headers = {'accept': 'application/json'} + to_be_deleted = [url for url in self.nodes.keys() if not self._check_node_health(url, headers)] + for node_url in to_be_deleted: + # Note: remove is async but we can't await here in sync method + # The node will be removed from dict, but async cleanup won't happen + if node_url in self.nodes: + self.nodes.pop(node_url) + self._invalidate_nodes_cache() + logger.info(f'Removed node {node_url} due to heart beat expiration') + + def _check_node_health(self, node_url: str, headers: Dict) -> bool: + """Check if a node is healthy.""" + try: + response = requests.get(f'{node_url}/health', headers=headers) + return response.status_code == 200 + except Exception: + return False + + @property + def model_list(self): + """Supported model list.""" + return [model for status in self.nodes.values() for model in status.models] + + def _get_matched_nodes(self, model_name: str, role: EngineRole): + """Get matched nodes and their speeds for the model.""" + nodes_with_speeds, speeds, nodes_without_speeds = [], [], [] + for node in self.get_nodes(role).values(): + if model_name in node.status.models: + if node.status.speed is not None: + nodes_with_speeds.append(node) + speeds.append(node.status.speed) + else: + nodes_without_speeds.append(node) + + if not nodes_with_speeds and not nodes_without_speeds: + return None, None + + all_nodes = nodes_with_speeds + nodes_without_speeds + avg_speed = sum(speeds) / len(speeds) if speeds else 1 + all_speeds = speeds + [avg_speed] * len(nodes_without_speeds) + return all_nodes, all_speeds + + def get_node(self, model_name: str, role: EngineRole = EngineRole.Hybrid) -> Optional[Node]: + """Get a node for the specified model and role.""" + if self.routing_strategy == RoutingStrategy.RANDOM: + nodes, speeds = self._get_matched_nodes(model_name, role) + if not nodes: + return None + weights = [s / sum(speeds) for s in speeds] + return random.choices(nodes, weights=weights)[0] + + elif self.routing_strategy == RoutingStrategy.MIN_EXPECTED_LATENCY: + nodes, speeds = self._get_matched_nodes(model_name, role) + if not nodes: + return None + indexes = list(range(len(nodes))) + random.shuffle(indexes) + min_index = min(indexes, key=lambda i: nodes[i].status.unfinished / speeds[i]) + return nodes[min_index] + + elif self.routing_strategy == RoutingStrategy.MIN_OBSERVED_LATENCY: + nodes, latencies = [], [] + for node in self.get_nodes(role).values(): + if model_name in node.status.models: + nodes.append(node) + latencies.append(np.mean(node.status.latency) if node.status.latency else float('inf')) + if not nodes: + return None + return nodes[np.argmin(latencies)] + + else: + raise ValueError(f'Invalid strategy: {self.routing_strategy}') + + def get_node_url(self, model_name: str, role: EngineRole = EngineRole.Hybrid) -> Optional[str]: + """Get node URL.""" + node = self.get_node(model_name, role) + return node.url if node else None diff --git a/lmdeploy/serve/proxy/proxy.py b/lmdeploy/serve/proxy/proxy.py index a1852b2627..ad2981c01d 100644 --- a/lmdeploy/serve/proxy/proxy.py +++ b/lmdeploy/serve/proxy/proxy.py @@ -4,19 +4,15 @@ import copy import json import os -import os.path as osp -import random -import threading -import time from collections import deque -from http import HTTPStatus -from typing import Deque, Dict, List, Literal, Optional, Union +from dataclasses import dataclass +from typing import TYPE_CHECKING, Deque, Literal, Optional, Tuple, Union + +if TYPE_CHECKING: + from .node_manager import Node -import aiohttp -import numpy as np -import requests import uvicorn -from fastapi import BackgroundTasks, Depends, FastAPI, Request +from fastapi import BackgroundTasks, Depends, FastAPI from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, StreamingResponse from pydantic import BaseModel, Field @@ -25,383 +21,43 @@ from lmdeploy.pytorch.disagg.conn.protocol import MigrationProtocol, MigrationRequest from lmdeploy.pytorch.disagg.conn.proxy_conn import PDConnectionPool from lmdeploy.pytorch.disagg.messages import PDConnectionMessage -from lmdeploy.serve.openai.api_server import check_api_key, create_error_response +from lmdeploy.serve.openai.api_server import check_api_key from lmdeploy.serve.openai.protocol import ModelCard # noqa: E501 from lmdeploy.serve.openai.protocol import ChatCompletionRequest, CompletionRequest, ModelList, ModelPermission -from lmdeploy.serve.proxy.constants import AIOHTTP_TIMEOUT, LATENCY_DEQUE_LEN, ErrorCodes, RoutingStrategy, err_msg +from lmdeploy.serve.proxy.constants import LATENCY_DEQUE_LEN, RoutingStrategy +from lmdeploy.serve.proxy.node_manager import NodeManager, connector from lmdeploy.utils import get_logger -logger = get_logger('lmdeploy') +from .constants import ErrorCodes class Status(BaseModel): """Status protocol consists of models' information.""" role: EngineRole = EngineRole.Hybrid - models: Optional[List[str]] = Field(default=[], examples=[[]]) + models: Optional[list[str]] = Field(default_factory=list, examples=[[]]) unfinished: int = 0 - latency: Deque = Field(default=deque(maxlen=LATENCY_DEQUE_LEN), examples=[[]]) + latency: Deque = Field(default_factory=lambda: deque(maxlen=LATENCY_DEQUE_LEN), examples=[[]]) speed: Optional[int] = Field(default=None, examples=[None]) -class Node(BaseModel): - """Node protocol consists of url and status.""" +class NodeModel(BaseModel): + """Node protocol for API requests (Pydantic model).""" url: str status: Optional[Status] = None -CONTROLLER_HEART_BEAT_EXPIRATION = int(os.getenv('LMDEPLOY_CONTROLLER_HEART_BEAT_EXPIRATION', 90)) - - -def heart_beat_controller(proxy_controller): - while True: - time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION) - logger.info('Start heart beat check') - proxy_controller.remove_stale_nodes_by_expiration() - - -class NodeManager: - """Manage all the sub nodes. - - Args: - config_path (str): the path of the config file. - strategy (str): the strategy to dispatch node to handle the requests. - - **random**: not fully radom, but decided by the speed of nodes. - - **min_expected_latency**: will compute the expected latency to - process the requests. The sooner of the node, the more requests - will be dispatched to it. - - **min_observed_latency**: Based on previous finished requests. The - sooner they get processed, the more requests will be dispatched - to. - """ - - def __init__(self, - config_path: Optional[str] = None, - serving_strategy: str = 'Hybrid', - routing_strategy: str = 'min_expected_latency', - migration_protocol: str = 'RDMA', - link_type: str = 'RoCE', - with_gdr: bool = True, - cache_status: Optional[bool] = True) -> None: - self.nodes = dict() - self.serving_strategy = ServingStrategy[serving_strategy] - self.routing_strategy = RoutingStrategy.from_str(routing_strategy) - - self.cache_status = cache_status - self.latencies = dict() - self.config_path = osp.join(osp.dirname(osp.realpath(__file__)), 'proxy_config.json') - if config_path is not None: - self.config_path = config_path - if osp.exists(self.config_path) and self.cache_status: - with open(self.config_path, 'r') as config_file: - if os.path.getsize(self.config_path) > 0: - logger.info(f'loading node configuration: {self.config_path}') - config = json.load(config_file) - self.nodes = { - node_url: Status.model_validate_json(node_status) - for node_url, node_status in config.items() - } - self.heart_beat_thread = threading.Thread(target=heart_beat_controller, args=(self, ), daemon=True) - self.heart_beat_thread.start() - self.aiotimeout = aiohttp.ClientTimeout(total=AIOHTTP_TIMEOUT) - - # For PD Disaggregation - self.migration_protocol = MigrationProtocol[migration_protocol] - self.rdma_config = DistServeRDMAConfig(with_gdr=with_gdr, link_type=RDMALinkType[link_type]) - self.pd_connection_pool = PDConnectionPool() - self.dummy_prefill = False - - def get_nodes(self, role: EngineRole) -> Dict: - items = list(self.nodes.items()) - return {node_url: node_status for (node_url, node_status) in items if node_status.role == role} - - @property - def hybrid_nodes(self): - return self.get_nodes(EngineRole.Hybrid) - - @property - def prefill_nodes(self): - return self.get_nodes(EngineRole.Prefill) - - @property - def decode_nodes(self): - return self.get_nodes(EngineRole.Decode) - - def update_config_file(self): - """Update the config file.""" - nodes = copy.deepcopy(self.nodes) - for _, status in nodes.items(): - status.latency = deque(list(status.latency)[-LATENCY_DEQUE_LEN:]) - if self.cache_status: - with open(self.config_path, 'w') as config_file: # update cfg yml - json.dump({ - node_url: node_status.model_dump_json() - for node_url, node_status in nodes.items() - }, - config_file, - indent=2) - - def add(self, node_url: str, status: Optional[Status] = None): - """Add a node to the manager. - - Args: - node_url (str): A http url. Can be the url generated by - `lmdeploy serve api_server`. - description (Dict): The description of the node. An example: - {'http://0.0.0.0:23333': {models: ['internlm-chat-7b]}, - speed: -1}. The speed here can be RPM or other metric. All the - values of nodes should be the same metric. - """ - if status is None: - status = self.nodes.get(node_url, Status()) - if status.models != []: # force register directly - self.remove(node_url) - self.nodes[node_url] = status - self.update_config_file() - return - try: - from lmdeploy.serve.openai.api_client import APIClient - client = APIClient(api_server_url=node_url) - status.models = client.available_models - self.nodes[node_url] = status - except requests.exceptions.RequestException as e: # noqa - logger.error(f'exception happened when adding node {node_url}, {e}') - return self.handle_api_timeout(node_url) - self.update_config_file() - - def remove(self, node_url: str): - """Remove a node.""" - if node_url in self.nodes.keys(): - self.nodes.pop(node_url) - self.update_config_file() - self.pd_connection_pool.dereg_instance(node_url) - - def terminate_node(self, node_url: str): - """Terminate a node.""" - success = True - if node_url in self.nodes: - self.nodes.pop(node_url) - headers = {'accept': 'application/json'} - try: - response = requests.get(f'{node_url}/terminate', headers=headers) - if response.status_code != 200: - success = False - logger.error(f'Failed to terminate node {node_url}, ' - f'error_code={response.status_code}, ' - f'error_msg={response.text}') - except Exception as e: # noqa - logger.error(f'exception happened when terminating node {node_url}, {e}') - success = False - else: - logger.error(f'terminating node {node_url} failed since it does not exist. ' - 'May try /nodes/status to check the node list') - success = False - self.update_config_file() - return success - - def terminate_all_nodes(self): - """Terminate all nodes.""" - node_url_li = list(self.nodes.keys()) - all_success = True - for node_url in node_url_li: - if not self.terminate_node(node_url): - all_success = False - return all_success - - def remove_stale_nodes_by_expiration(self): - """Remove stale nodes.""" - to_be_deleted = [] - node_urls = list(self.nodes.keys()) - for node_url in node_urls: - url = f'{node_url}/health' - headers = {'accept': 'application/json'} - try: - response = requests.get(url, headers=headers) - if response.status_code != 200: - to_be_deleted.append(node_url) - except: # noqa - to_be_deleted.append(node_url) - for node_url in to_be_deleted: - self.remove(node_url) - logger.info(f'Removed node_url: {node_url} ' - 'due to heart beat expiration') - - @property - def model_list(self): - """Supported model list.""" - model_names = [] - items = list(self.nodes.items()) - for _, status in items: - model_names.extend(status.models) - return model_names - - @property - def status(self): - """Return the status.""" - return self.nodes - - def get_node_url(self, model_name: str, role: EngineRole = EngineRole.Hybrid): - """Add a node to the manager. - - Args: - model_name (str): A http url. Can be the url generated by - `lmdeploy serve api_server`. - Return: - A node url or None. - """ - - def get_matched_urls(): - urls_with_speeds, speeds, urls_without_speeds = [], [], [] - for node_url, status in self.get_nodes(role).items(): - if model_name in status.models: - if status.speed is not None: - urls_with_speeds.append(node_url) - speeds.append(status.speed) - else: - urls_without_speeds.append(node_url) - all_matched_urls = urls_with_speeds + urls_without_speeds - if len(all_matched_urls) == 0: - return None - # some nodes does not contain speed - # we can set them the average speed value - average_speed = sum(speeds) / len(speeds) if len(speeds) else 1 - all_the_speeds = speeds + [average_speed] * len(urls_without_speeds) - return all_matched_urls, all_the_speeds - - if self.routing_strategy == RoutingStrategy.RANDOM: - all_matched_urls, all_the_speeds = get_matched_urls() - if len(all_matched_urls) == 0: - return None - speed_sum = sum(all_the_speeds) - weights = [speed / speed_sum for speed in all_the_speeds] - index = random.choices(range(len(all_matched_urls)), weights=weights)[0] - url = all_matched_urls[index] - return url - elif self.routing_strategy == RoutingStrategy.MIN_EXPECTED_LATENCY: - all_matched_urls, all_the_speeds = get_matched_urls() - if len(all_matched_urls) == 0: - return None - min_latency = float('inf') - min_index = 0 - # random traverse nodes for low concurrency situation - all_indexes = [i for i in range(len(all_the_speeds))] - random.shuffle(all_indexes) - for index in all_indexes: - latency = self.get_nodes(role)[all_matched_urls[index]].unfinished / all_the_speeds[index] - if min_latency > latency: - min_latency = latency - min_index = index - url = all_matched_urls[min_index] - return url - elif self.routing_strategy == RoutingStrategy.MIN_OBSERVED_LATENCY: - all_matched_urls, latencies = [], [] - for node_url, node_status in self.get_nodes(role).items(): - if model_name in node_status.models: - if len(node_status.latency): - latencies.append(np.mean(np.array(node_status.latency))) - else: - latencies.append(float('inf')) - all_matched_urls.append(node_url) - if len(all_matched_urls) == 0: - return None - index = np.argmin(np.array(latencies)) - return all_matched_urls[index] - else: - raise ValueError(f'Invalid strategy: {self.routing_strategy}') - - async def check_request_model(self, model_name) -> Optional[JSONResponse]: - """Check if a request is valid.""" - if model_name in self.model_list: - return - ret = create_error_response(HTTPStatus.NOT_FOUND, f'The model {model_name!r} does not exist.') - return ret - - def handle_unavailable_model(self, model_name): - """Handle unavailable model. - - Args: - model_name (str): the model in the request. - """ - logger.warning(f'no model name: {model_name}') - ret = { - 'error_code': ErrorCodes.MODEL_NOT_FOUND, - 'text': err_msg[ErrorCodes.MODEL_NOT_FOUND], - } - return json.dumps(ret).encode() + b'\n' - - def handle_api_timeout(self, node_url): - """Handle the api time out.""" - logger.warning(f'api timeout: {node_url}') - ret = { - 'error_code': ErrorCodes.API_TIMEOUT.value, - 'text': err_msg[ErrorCodes.API_TIMEOUT], - } - return json.dumps(ret).encode() + b'\n' - - async def stream_generate(self, request: Dict, node_url: str, endpoint: str): - """Return a generator to handle the input request. - - Args: - request (Dict): the input request. - node_url (str): the node url. - endpoint (str): the endpoint. Such as `/v1/chat/completions`. - """ - try: - async with aiohttp.ClientSession() as session: - async with session.post(node_url + endpoint, json=request, timeout=self.aiotimeout) as response: - async for line in response.content: - if line.strip(): - yield line + b'\n\n' - except (Exception, GeneratorExit, aiohttp.ClientError) as e: # noqa - logger.error(f'catched an exception: {e}') - # exception happened, reduce unfinished num - yield self.handle_api_timeout(node_url) - - async def generate(self, request: Dict, node_url: str, endpoint: str): - """Return a the response of the input request. - - Args: - request (Dict): the input request. - node_url (str): the node url. - endpoint (str): the endpoint. Such as `/v1/chat/completions`. - """ - try: - async with aiohttp.ClientSession() as session: - async with session.post(node_url + endpoint, json=request, timeout=self.aiotimeout) as response: - return await response.text() - except (Exception, GeneratorExit, aiohttp.ClientError, asyncio.CancelledError) as e: # noqa # yapf: disable - logger.error(f'catched an exception: {e}') - return self.handle_api_timeout(node_url) - - def pre_call(self, node_url): - """Preprocess before the request get processed. - - Args: - node_url (str): the node url. - """ - self.nodes[node_url].unfinished += 1 - return time.time() - - def post_call(self, node_url: str, start: int): - """Post process after the response finished. - - Args: - node_url (str): the node url. - start (int): the start time point. time.time() - """ - if node_url in self.nodes: - self.nodes[node_url].unfinished -= 1 - self.nodes[node_url].latency.append(time.time() - start) - - def create_background_tasks(self, url: str, start: int): - """To create a background task. - - Args: - node_url (str): the node url. - start (int): the start time point. time.time() - """ - background_tasks = BackgroundTasks() - background_tasks.add_task(self.post_call, url, start) - return background_tasks +@dataclass +class AppSettings: + serving_strategy: ServingStrategy = ServingStrategy.Hybrid + routing_strategy: RoutingStrategy = RoutingStrategy.MIN_EXPECTED_LATENCY + migration_protocol: MigrationProtocol = MigrationProtocol.RDMA + dummy_prefill: bool = False + api_keys: Optional[Union[list[str], str]] = None + rdma_config: DistServeRDMAConfig = DistServeRDMAConfig( + link_type=RDMALinkType.RoCE, + with_gdr=True, + ) + pd_connection_pool: PDConnectionPool = PDConnectionPool() app = FastAPI(docs_url='/') @@ -412,98 +68,268 @@ def create_background_tasks(self, url: str, start: int): allow_methods=['*'], allow_headers=['*'], ) +app_settings = AppSettings() node_manager = NodeManager() +logger = get_logger('lmdeploy') + + +def report_model_not_found(content: str): + """Report model not found error.""" + return JSONResponse(status_code=404, content={'error_code': ErrorCodes.MODEL_NOT_FOUND.value, 'error_msg': content}) + + +async def _generate_response(node: 'Node', request_dict: dict, endpoint: str, + stream: bool) -> Union[StreamingResponse, JSONResponse]: + """Generate streaming or non-streaming response. + + Args: + node: The node to handle the request. + request_dict: The request dictionary. + endpoint: The API endpoint. + stream: Whether to stream the response. + + Returns: + StreamingResponse or JSONResponse. + """ + start = node.pre_call() + if stream: + response = node.stream_generate(request_dict, endpoint) + background_task = BackgroundTasks() + background_task.add_task(node.post_call, start) + return StreamingResponse(response, background=background_task, media_type='text/event-stream') + else: + response = await node.generate(request_dict, endpoint) + node.post_call(start) + return JSONResponse(json.loads(response)) + + +async def _handle_hybrid_request(request: Union[ChatCompletionRequest, CompletionRequest], + endpoint: str) -> Union[StreamingResponse, JSONResponse]: + """Handle request with Hybrid serving strategy. + + Args: + request: The request object (ChatCompletionRequest or CompletionRequest). + endpoint: The API endpoint. + + Returns: + StreamingResponse or JSONResponse. + """ + node = node_manager.get_node(request.model, EngineRole.Hybrid) + if not node: + return report_model_not_found(f'The model {request.model} is not available.') + + logger.info(f'A request is dispatched to {node.url}') + request_dict = request.model_dump() + return await _generate_response(node, request_dict, endpoint, request.stream) + + +async def _handle_distserve_prefill(request: Union[ChatCompletionRequest, CompletionRequest], + endpoint: str) -> Tuple[dict, str, Optional['Node']]: + """Handle DistServe prefill phase. + + Returns: + Tuple of (prefill_info, p_url, p_node). Returns (None, None, None) if no prefill node. + """ + if app_settings.dummy_prefill: + return {}, 'dummy:dummy', None + + p_node = node_manager.get_node(request.model, EngineRole.Prefill) + if not p_node: + return None, None, None + + prefill_request_dict = copy.deepcopy(request.model_dump()) + prefill_request_dict.update({'max_tokens': 1, 'stream': False, 'with_cache': True, 'preserve_cache': True}) + if endpoint == '/v1/chat/completions': + prefill_request_dict['max_completion_tokens'] = 1 + + logger.info(f'A Prefill request is dispatched to {p_node.url}') + start = p_node.pre_call() + response_text = await p_node.generate(prefill_request_dict, endpoint) + prefill_info = json.loads(response_text) + p_node.post_call(start) + + return prefill_info, p_node.url, p_node + + +async def _ensure_pd_connection(p_url: str, d_url: str, with_error_handling: bool = False) -> Optional[JSONResponse]: + """Ensure PD connection is established. + + Returns: + None if successful, JSONResponse if error and with_error_handling is True. + Raises exception if error and with_error_handling is False. + """ + if app_settings.dummy_prefill or app_settings.pd_connection_pool.is_connected(p_url, d_url): + return None + + try: + await app_settings.pd_connection_pool.connect( + PDConnectionMessage( + p_url=p_url, + d_url=d_url, + protocol=app_settings.migration_protocol, + rdma_config=app_settings.rdma_config, + )) + except Exception as e: + if with_error_handling: + logger.error(f'Connection error: {e}') + return JSONResponse(status_code=500, + content={ + 'error': 'Connection error', + 'message': f'Cannot establish connection {(p_url, d_url)}' + }) + raise + return None + + +def _build_migration_request(prefill_info: dict, p_url: str): + """Build migration request from prefill info. + + Args: + prefill_info: Prefill response information. + p_url: Prefill node URL. + + Returns: + MigrationRequest dictionary. + """ + remote_session_id = int(prefill_info.get('id')) if prefill_info.get('id') else 0 + remote_block_ids = prefill_info.get('cache_block_ids') or [] + remote_token_ids = prefill_info.get('remote_token_ids', []) + remote_token_id = remote_token_ids[-1] if remote_token_ids else 0 + + return MigrationRequest(protocol=app_settings.migration_protocol, + remote_engine_id=p_url, + remote_session_id=remote_session_id, + remote_block_ids=remote_block_ids, + remote_token_id=remote_token_id, + is_dummy_prefill=app_settings.dummy_prefill).model_dump(mode='json') + + +async def _handle_distserve_decode(request: Union[ChatCompletionRequest, CompletionRequest], + endpoint: str, + prefill_info: dict, + p_url: str, + ensure_connection: bool = False, + handle_shelf: bool = False) -> Union[StreamingResponse, JSONResponse]: + """Handle DistServe decode phase.""" + d_node = node_manager.get_node(request.model, EngineRole.Decode) + if not d_node: + return report_model_not_found(f'The decode node for model {request.model} is not available.') + + logger.info(f'A Decode request is dispatched to {d_node.url}') + + # Ensure connection if needed + if ensure_connection: + conn_error = await _ensure_pd_connection(p_url, d_node.url, with_error_handling=True) + if conn_error: + return conn_error + if handle_shelf and not app_settings.dummy_prefill: + prefill_id = prefill_info.get('id') + if prefill_id: + app_settings.pd_connection_pool.shelf_prefill_session((p_url, d_node.url), prefill_id) + + request_dict = request.model_dump() + request_dict['migration_request'] = _build_migration_request(prefill_info, p_url) + resp = await _generate_response(d_node, request_dict, endpoint, request.stream) + + # Cleanup + if not app_settings.dummy_prefill: + prefill_id = prefill_info.get('id') + if prefill_id: + app_settings.pd_connection_pool.unshelf_prefill_session((p_url, d_node.url), prefill_id) + + return resp + + +@app.on_event('startup') +async def startup_event(): + """Initialize session when application starts.""" + await connector.get_session() + + +@app.on_event('shutdown') +async def shutdown_event(): + """Cleanup resources when application shuts down.""" + await connector.cleanup() @app.get('/v1/models', dependencies=[Depends(check_api_key)]) def available_models(): """Show available models.""" - model_cards = [] - for model_name in node_manager.model_list: - model_cards.append(ModelCard(id=model_name, root=model_name, permission=[ModelPermission()])) - return ModelList(data=model_cards) + return ModelList( + data=[ModelCard(id=name, root=name, permission=[ModelPermission()]) for name in node_manager.model_list]) @app.get('/nodes/status', dependencies=[Depends(check_api_key)]) def node_status(): """Show nodes status.""" try: - return node_manager.status - except: # noqa - return False + return node_manager.nodes + except Exception as e: + logger.error(f'Failed to get node status: {e}') + return JSONResponse(status_code=500, content={'error': 'Failed to get node status', 'message': str(e)}) @app.post('/nodes/add', dependencies=[Depends(check_api_key)]) -def add_node(node: Node, raw_request: Request = None): - """Add a node to the manager. - - - **url** (str): A http url. Can be the url generated by - `lmdeploy serve api_server`. - - **status** (Dict): The description of the node. An example: - ``{models: ['internlm-chat-7b], speed: 1}``. The speed here can be - RPM or other metric. All the values of nodes should be the same metric. - """ +async def add_node(node: NodeModel): + """Add a node to the manager.""" try: - res = node_manager.add(node.url, node.status) - if res is not None: - logger.error(f'add node {node.url} failed, {res}') - return res + if node.status is None: + from lmdeploy.serve.openai.api_client import APIClient + node.status = Status(models=APIClient(api_server_url=node.url).available_models) + await node_manager.add(node.url, node.status) logger.info(f'add node {node.url} successfully') - return 'Added successfully' - except: # noqa - return 'Failed to add, please check the input url.' + return JSONResponse(status_code=200, content={'message': 'Added successfully', 'url': node.url}) + except Exception as e: + logger.error(f'add node {node.url} failed: {e}') + return JSONResponse(status_code=500, content={'error': 'Failed to add node', 'message': str(e)}) @app.post('/nodes/remove', dependencies=[Depends(check_api_key)]) -def remove_node(node: Node): - """Show available models.""" +async def remove_node(node: NodeModel): + """Remove a node.""" try: - node_url = node.url - node_manager.remove(node_url) - logger.info(f'delete node {node_url} successfully') - return 'Deleted successfully' - except: # noqa - logger.error(f'delete node {node.url} failed.') - return 'Failed to delete, please check the input url.' + await node_manager.remove(node.url) + app_settings.pd_connection_pool.dereg_instance(node.url) + logger.info(f'removed node {node.url} successfully') + return JSONResponse(status_code=200, content={'message': 'Removed successfully', 'url': node.url}) + except Exception as e: + logger.error(f'remove node {node.url} failed: {e}') + return JSONResponse(status_code=500, content={'error': 'Failed to remove node', 'message': str(e)}) @app.post('/nodes/terminate', dependencies=[Depends(check_api_key)]) -def terminate_node(node: Node): - """Terminate nodes.""" +async def terminate_node(node: NodeModel): + """Terminate a node.""" try: - node_url = node.url - success = node_manager.terminate_node(node_url) - if not success: - return f'Failed to terminate node {node_url}' - return 'Terminated successfully' - except: # noqa - logger.error(f'Terminate node {node_url} failed.') - return 'Failed to terminate node {node_url}, please check the input url.' + await node_manager.terminate_node(node.url) + logger.info(f'Terminated node {node.url} successfully') + return JSONResponse(status_code=200, content={'message': 'Terminated successfully', 'url': node.url}) + except Exception as e: + logger.error(f'Failed to terminate node {node.url}: {e}') + return JSONResponse(status_code=500, content={'error': 'Failed to terminate node', 'message': str(e)}) @app.get('/nodes/terminate_all', dependencies=[Depends(check_api_key)]) -def terminate_node_all(): - """Terminate nodes.""" +async def terminate_node_all(): + """Terminate all nodes.""" try: - success = node_manager.terminate_all_nodes() - if not success: - return 'Failed to terminate all nodes' - return 'All nodes terminated successfully' - except: # noqa - logger.error('Failed to terminate all nodes') - return 'Failed to terminate all nodes.' + await node_manager.terminate_all_nodes() + return JSONResponse(status_code=200, content={'message': 'All nodes terminated successfully'}) + except Exception as e: + logger.error(f'Failed to terminate all nodes: {e}') + return JSONResponse(status_code=500, content={'error': 'Failed to terminate all nodes', 'message': str(e)}) @app.post('/distserve/connection_warmup') async def connection_warmup(): await asyncio.gather(*[ - node_manager.pd_connection_pool.connect( + app_settings.pd_connection_pool.connect( PDConnectionMessage( p_url=p_url, d_url=d_url, - protocol=node_manager.migration_protocol, - rdma_config=node_manager.rdma_config, - )) for p_url in node_manager.prefill_nodes for d_url in node_manager.decode_nodes + protocol=app_settings.migration_protocol, + rdma_config=app_settings.rdma_config, + )) for p_url in node_manager.prefill_nodes.keys() for d_url in node_manager.decode_nodes.keys() ]) return JSONResponse({'SUCCESS': True}) @@ -514,327 +340,58 @@ async def cache_block_gc_to_be_migrated(): raise NotImplementedError +async def _handle_request(request: Union[ChatCompletionRequest, CompletionRequest], + endpoint: str, + is_chat: bool = False) -> Union[StreamingResponse, JSONResponse]: + """Handle completion request (unified for chat and completions).""" + if app_settings.serving_strategy == ServingStrategy.Hybrid: + return await _handle_hybrid_request(request, endpoint) + elif app_settings.serving_strategy == ServingStrategy.DistServe: + prefill_info, p_url, p_node = await _handle_distserve_prefill(request, endpoint) + if p_node is None and not app_settings.dummy_prefill: + return report_model_not_found(f'The prefill node for model {request.model} is not available.') + + # For chat_completions, ensure connection without error handling + if is_chat and not app_settings.dummy_prefill: + d_node = node_manager.get_node(request.model, EngineRole.Decode) + if d_node: + await _ensure_pd_connection(p_url, d_node.url, with_error_handling=False) + + return await _handle_distserve_decode(request, + endpoint, + prefill_info, + p_url, + ensure_connection=not is_chat, + handle_shelf=not is_chat) + else: + raise ValueError(f'No serving strategy named {app_settings.serving_strategy}') + + @app.post('/v1/chat/completions', dependencies=[Depends(check_api_key)]) -async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Request = None): +async def chat_completions_v1(request: ChatCompletionRequest): """Completion API similar to OpenAI's API. - Refer to https://platform.openai.com/docs/api-reference/chat/create - for the API specification. - - The request should be a JSON object with the following fields: - - - **model**: model name. Available from /v1/models. - - **messages**: string prompt or chat history in OpenAI format. Chat history - example: `[{"role": "user", "content": "hi"}]`. - - **temperature** (float): to modulate the next token probability - - **top_p** (float): If set to float < 1, only the smallest set of most - probable tokens with probabilities that add up to top_p or higher - are kept for generation. - - **n** (int): How many chat completion choices to generate for each input - message. **Only support one here**. - - **stream**: whether to stream the results or not. Default to false. - - **max_completion_tokens** (int | None): output token nums. Default to None. - - **max_tokens** (int | None): output token nums. Default to None. - Deprecated: Use max_completion_tokens instead. - - **repetition_penalty** (float): The parameter for repetition penalty. - 1.0 means no penalty - - **stop** (str | List[str] | None): To stop generating further - tokens. Only accept stop words that's encoded to one token idex. - - **response_format** (Dict | None): To generate response according to given - schema. Examples: - - .. code-block:: json - - { - "type": "json_schema", - "json_schema":{ - "name": "test", - "schema":{ - "properties":{ - "name":{"type":"string"} - }, - "required":["name"], - "type":"object" - } - } - } - - or - ``{"type": "regex_schema", "regex_schema": "call me [A-Za-z]{1,10}"}`` - - **logit_bias** (Dict): Bias to logits. Only supported in pytorch engine. - - **tools** (List): A list of tools the model may call. Currently, only - internlm2 functions are supported as a tool. Use this to specify a - list of functions for which the model can generate JSON inputs. - - **tool_choice** (str | object): Controls which (if any) tool is called by - the model. `none` means the model will not call any tool and instead - generates a message. Specifying a particular tool via - ``{"type": "function", "function": {"name": "my_function"}}`` - forces the model to call that tool. `auto` or `required` will put all - the tools information to the model. - - Additional arguments supported by LMDeploy: - - - **top_k** (int): The number of the highest probability vocabulary - tokens to keep for top-k-filtering - - **ignore_eos** (bool): indicator for ignoring eos - - **skip_special_tokens** (bool): Whether or not to remove special tokens - in the decoding. Default to be True. - - **min_new_tokens** (int): To generate at least numbers of tokens. - - **min_p** (float): Minimum token probability, which will be scaled by the - probability of the most likely token. It must be a value between - 0 and 1. Typical values are in the 0.01-0.2 range, comparably - selective as setting `top_p` in the 0.99-0.8 range (use the - opposite of normal `top_p` values) - - Currently we do not support the following features: - - - **presence_penalty** (replaced with repetition_penalty) - - **frequency_penalty** (replaced with repetition_penalty) + See https://platform.openai.com/docs/api-reference/chat/create """ - check_response = await node_manager.check_request_model(request.model) - if check_response is not None: - return check_response - - if node_manager.serving_strategy == ServingStrategy.Hybrid: - node_url = node_manager.get_node_url(request.model) - if not node_url: - return node_manager.handle_unavailable_model(request.model) - - logger.info(f'A request is dispatched to {node_url}') - request_dict = request.model_dump() - start = node_manager.pre_call(node_url) - if request.stream is True: - response = node_manager.stream_generate(request_dict, node_url, '/v1/chat/completions') - background_task = node_manager.create_background_tasks(node_url, start) - return StreamingResponse(response, background=background_task, media_type='text/event-stream') - else: - response = await node_manager.generate(request_dict, node_url, '/v1/chat/completions') - node_manager.post_call(node_url, start) - return JSONResponse(json.loads(response)) - elif node_manager.serving_strategy == ServingStrategy.DistServe: - request_dict = request.model_dump() - - # Prefill - prefill_request_dict = copy.deepcopy(request_dict) - prefill_request_dict['max_tokens'] = 1 - prefill_request_dict['max_completion_tokens'] = 1 - prefill_request_dict['stream'] = False - prefill_request_dict['with_cache'] = True - prefill_request_dict['preserve_cache'] = True - - prefill_info = {} - p_url = 'dummy:dummy' - if not node_manager.dummy_prefill: - p_url = node_manager.get_node_url(request.model, EngineRole.Prefill) - if not p_url: - return node_manager.handle_unavailable_model(request.model) - logger.info(f'A Prefill request is dispatched to {p_url}') - - start = node_manager.pre_call(p_url) - prefill_info = json.loads(await node_manager.generate(prefill_request_dict, p_url, '/v1/chat/completions')) - node_manager.post_call(p_url, start) - - # # Decode - d_url = node_manager.get_node_url(request.model, EngineRole.Decode) - if not d_url: - return node_manager.handle_unavailable_model(request.model) - logger.info(f'A Decode request is dispatched to {d_url}') - - if not node_manager.dummy_prefill: - if not node_manager.pd_connection_pool.is_connected(p_url, d_url): - await node_manager.pd_connection_pool.connect( - PDConnectionMessage( - p_url=p_url, - d_url=d_url, - protocol=node_manager.migration_protocol, - rdma_config=node_manager.rdma_config, - )) - - remote_session_id = int(prefill_info.get('id')) if prefill_info.get('id') else 0 - remote_block_ids = prefill_info.get('cache_block_ids') or [] - remote_token_id = prefill_info.get('remote_token_ids')[-1] if prefill_info.get('remote_token_ids') else 0 - - request_dict['migration_request'] = MigrationRequest( - protocol=node_manager.migration_protocol, - remote_engine_id=p_url, - remote_session_id=remote_session_id, - remote_block_ids=remote_block_ids, - remote_token_id=remote_token_id, - is_dummy_prefill=node_manager.dummy_prefill).model_dump(mode='json') - - start = node_manager.pre_call(d_url) - if not node_manager.dummy_prefill: - node_manager.pd_connection_pool.shelf_prefill_session((p_url, d_url), prefill_info['id']) - if request.stream is True: - response = node_manager.stream_generate(request_dict, d_url, '/v1/chat/completions') - background_task = node_manager.create_background_tasks(d_url, start) - resp = StreamingResponse(response, background=background_task, media_type='text/event-stream') - else: - response = await node_manager.generate(request_dict, d_url, '/v1/chat/completions') - node_manager.post_call(d_url, start) - resp = JSONResponse(json.loads(response)) - - if not node_manager.dummy_prefill: - node_manager.pd_connection_pool.unshelf_prefill_session((p_url, d_url), prefill_info['id']) - - return resp - - else: - raise ValueError(f'No serving strategy named {node_manager.serving_strategy}') + return await _handle_request(request, '/v1/chat/completions', is_chat=True) @app.post('/v1/completions', dependencies=[Depends(check_api_key)]) -async def completions_v1(request: CompletionRequest, raw_request: Request = None): +async def completions_v1(request: CompletionRequest): """Completion API similar to OpenAI's API. - Go to https://platform.openai.com/docs/api-reference/completions/create - for the API specification. - - The request should be a JSON object with the following fields: - - - **model** (str): model name. Available from /v1/models. - - **prompt** (str): the input prompt. - - **suffix** (str): The suffix that comes after a completion of inserted text. - - **max_completion_tokens** (int | None): output token nums. Default to None. - - **max_tokens** (int): output token nums. Default to 16. - Deprecated: Use max_completion_tokens instead. - - **temperature** (float): to modulate the next token probability - - **top_p** (float): If set to float < 1, only the smallest set of most - probable tokens with probabilities that add up to top_p or higher - are kept for generation. - - **n** (int): How many chat completion choices to generate for each input - message. **Only support one here**. - - **stream**: whether to stream the results or not. Default to false. - - **repetition_penalty** (float): The parameter for repetition penalty. - 1.0 means no penalty - - **user** (str): A unique identifier representing your end-user. - - **stop** (str | List[str] | None): To stop generating further - tokens. Only accept stop words that's encoded to one token idex. - - Additional arguments supported by LMDeploy: - - - **ignore_eos** (bool): indicator for ignoring eos - - **skip_special_tokens** (bool): Whether or not to remove special tokens - in the decoding. Default to be True. - - **top_k** (int): The number of the highest probability vocabulary - tokens to keep for top-k-filtering - - Currently we do not support the following features: - - - **logprobs** (not supported yet) - - **presence_penalty** (replaced with repetition_penalty) - - **frequency_penalty** (replaced with repetition_penalty) + See https://platform.openai.com/docs/api-reference/completions/create """ - check_response = await node_manager.check_request_model(request.model) - if check_response is not None: - return check_response - if node_manager.serving_strategy == ServingStrategy.Hybrid: - node_url = node_manager.get_node_url(request.model) - if not node_url: - return node_manager.handle_unavailable_model(request.model) - - logger.info(f'A request is dispatched to {node_url}') - request_dict = request.model_dump() - start = node_manager.pre_call(node_url) - if request.stream is True: - response = node_manager.stream_generate(request_dict, node_url, '/v1/completions') - background_task = node_manager.create_background_tasks(node_url, start) - return StreamingResponse(response, background=background_task, media_type='text/event-stream') - else: - response = await node_manager.generate(request_dict, node_url, '/v1/completions') - node_manager.post_call(node_url, start) - return JSONResponse(json.loads(response)) - elif node_manager.serving_strategy == ServingStrategy.DistServe: - request_dict = request.model_dump() - - # Prefill - prefill_request_dict = copy.deepcopy(request_dict) - prefill_request_dict['max_tokens'] = 1 - prefill_request_dict['stream'] = False - prefill_request_dict['with_cache'] = True - prefill_request_dict['preserve_cache'] = True - - if not node_manager.dummy_prefill: - try: - p_url = node_manager.get_node_url(request.model, EngineRole.Prefill) - except Exception as e: - logger.error(f'error Msg: {str(e)}') - return {'status': 'Instance sch error, cannot find available p_url'} - - if not p_url: - return node_manager.handle_unavailable_model(request.model) - logger.info(f'A Prefill request is dispatched to {p_url}') - - start = node_manager.pre_call(p_url) - prefill_info = json.loads(await node_manager.generate(prefill_request_dict, p_url, '/v1/completions')) - node_manager.post_call(p_url, start) - else: - p_url = 'dummy:dummy' - prefill_info = {} - - # Decode - try: - d_url = node_manager.get_node_url(request.model, EngineRole.Decode) - except Exception as e: - logger.error(f'error Msg: {str(e)}') - return {'status': 'Instance sch error, cannot find available p_url'} - - if not d_url: - return node_manager.handle_unavailable_model(request.model) - logger.info(f'A Decode request is dispatched to {d_url}') - - if not node_manager.dummy_prefill: - if not node_manager.pd_connection_pool.is_connected(p_url, d_url): - try: - await node_manager.pd_connection_pool.connect( - PDConnectionMessage( - p_url=p_url, - d_url=d_url, - protocol=node_manager.migration_protocol, - rdma_config=node_manager.rdma_config, - )) - except Exception as e: - logger.error(f'error Msg: {str(e)}') - return {'status': f'Connection error, cannot establish connection {(p_url, d_url)}'} - node_manager.pd_connection_pool.shelf_prefill_session((p_url, d_url), prefill_info['id']) - - remote_session_id = int(prefill_info.get('id')) if prefill_info.get('id') else 0 - remote_block_ids = prefill_info.get('cache_block_ids') or [] - remote_token_id = prefill_info.get('remote_token_ids')[-1] if prefill_info.get('remote_token_ids') else 0 - request_dict['migration_request'] = MigrationRequest( - protocol=node_manager.migration_protocol, - remote_engine_id=p_url, - remote_session_id=remote_session_id, - remote_block_ids=remote_block_ids, - remote_token_id=remote_token_id, - is_dummy_prefill=node_manager.dummy_prefill).model_dump(mode='json') - - start = node_manager.pre_call(d_url) - if not node_manager.dummy_prefill: - node_manager.pd_connection_pool.shelf_prefill_session((p_url, d_url), prefill_info['id']) - if request.stream is True: - response = node_manager.stream_generate(request_dict, d_url, '/v1/completions') - background_task = node_manager.create_background_tasks(d_url, start) - resp = StreamingResponse(response, background=background_task, media_type='text/event-stream') - else: - response = await node_manager.generate(request_dict, d_url, '/v1/completions') - node_manager.post_call(d_url, start) - resp = JSONResponse(json.loads(response)) - if not node_manager.dummy_prefill: - node_manager.pd_connection_pool.unshelf_prefill_session((p_url, d_url), prefill_info.get('id')) - return resp - else: - raise ValueError(f'No serving strategy named {node_manager.serving_strategy}') + return await _handle_request(request, '/v1/completions', is_chat=False) def proxy(server_name: str = '0.0.0.0', server_port: int = 8000, serving_strategy: Literal['Hybrid', 'DistServe'] = 'Hybrid', routing_strategy: Literal['random', 'min_expected_latency', 'min_observed_latency'] = 'min_expected_latency', - api_keys: Optional[Union[List[str], str]] = None, + api_keys: Optional[Union[list[str], str]] = None, ssl: bool = False, log_level: str = 'INFO', - disable_cache_status: bool = False, link_type: Literal['RoCE', 'IB'] = 'RoCE', migration_protocol: Literal['RDMA'] = 'RDMA', dummy_prefill: bool = False, @@ -853,25 +410,20 @@ def proxy(server_name: str = '0.0.0.0', a single api_key. Default to None, which means no api key applied. ssl (bool): Enable SSL. Requires OS Environment variables 'SSL_KEYFILE' and 'SSL_CERTFILE'. log_level (str): Set the log level. Default to INFO. - disable_cache_status (str): Whether to cache the proxy status to - proxy_config.yml. migration_protocol: migration protocol when PD disaggregation. RDMA default. """ # noqa - node_manager.serving_strategy = ServingStrategy[serving_strategy] - node_manager.routing_strategy = RoutingStrategy.from_str(routing_strategy) - node_manager.migration_protocol = MigrationProtocol[migration_protocol] - node_manager.dummy_prefill = dummy_prefill - - node_manager.rdma_config = DistServeRDMAConfig( + app_settings.serving_strategy = ServingStrategy[serving_strategy] + app_settings.migration_protocol = MigrationProtocol[migration_protocol] + app_settings.dummy_prefill = dummy_prefill + app_settings.rdma_config = DistServeRDMAConfig( link_type=RDMALinkType[link_type], with_gdr=True, ) - node_manager.cache_status = not disable_cache_status + node_manager.routing_strategy = RoutingStrategy.from_str(routing_strategy) if api_keys is not None: if isinstance(api_keys, str): api_keys = api_keys.split(',') - from lmdeploy.serve.openai.api_server import VariableInterface - VariableInterface.api_keys = api_keys + app_settings.api_keys = api_keys ssl_keyfile, ssl_certfile = None, None if ssl: ssl_keyfile = os.environ['SSL_KEYFILE']