diff --git a/docs/en/api/pipeline.rst b/docs/en/api/pipeline.rst index 4fea54f540..436822682b 100644 --- a/docs/en/api/pipeline.rst +++ b/docs/en/api/pipeline.rst @@ -5,12 +5,11 @@ Inference pipeline Pipeline -------- .. autofunction:: pipeline - -Serving --------- -.. autofunction:: lmdeploy.api.serve -.. autofunction:: lmdeploy.api.client - +.. autoclass:: Pipeline + :undoc-members: + :show-inheritance: + :members: __init__, infer, stream_infer, chat, get_ppl + :member-order: bysource Config ------------------- diff --git a/docs/en/llm/pipeline.md b/docs/en/llm/pipeline.md index f3361c05e5..506878ced1 100644 --- a/docs/en/llm/pipeline.md +++ b/docs/en/llm/pipeline.md @@ -123,7 +123,7 @@ from lmdeploy import pipeline, GenerationConfig pipe = pipeline('internlm/internlm2_5-7b-chat') -gen_config=GenerationConfig(output_logits='generation' +gen_config=GenerationConfig(output_logits='generation', max_new_tokens=10) response = pipe(['Hi, pls intro yourself', 'Shanghai is'], gen_config=gen_config) diff --git a/docs/zh_cn/api/pipeline.rst b/docs/zh_cn/api/pipeline.rst index 839fd6ab4a..d03ce1998f 100644 --- a/docs/zh_cn/api/pipeline.rst +++ b/docs/zh_cn/api/pipeline.rst @@ -5,12 +5,11 @@ Pipeline -------- .. autofunction:: pipeline - -Serving --------- -.. autofunction:: lmdeploy.api.serve -.. autofunction:: lmdeploy.api.client - +.. autoclass:: Pipeline + :undoc-members: + :show-inheritance: + :members: __init__, infer, stream_infer, chat, get_ppl + :member-order: bysource Config ------------------- diff --git a/lmdeploy/__init__.py b/lmdeploy/__init__.py index 548772663b..644fb7b0e5 100644 --- a/lmdeploy/__init__.py +++ b/lmdeploy/__init__.py @@ -3,10 +3,11 @@ from .api import client, pipeline, serve from .messages import GenerationConfig, PytorchEngineConfig, TurbomindEngineConfig, VisionConfig from .model import ChatTemplateConfig +from .pipeline import Pipeline from .tokenizer import Tokenizer from .version import __version__, version_info __all__ = [ 'pipeline', 'serve', 'client', 'Tokenizer', 'GenerationConfig', '__version__', 'version_info', 'ChatTemplateConfig', - 'PytorchEngineConfig', 'TurbomindEngineConfig', 'VisionConfig' + 'PytorchEngineConfig', 'TurbomindEngineConfig', 'VisionConfig', 'Pipeline' ] diff --git a/lmdeploy/api.py b/lmdeploy/api.py index e6aa322fb9..4f0ff34315 100644 --- a/lmdeploy/api.py +++ b/lmdeploy/api.py @@ -1,18 +1,23 @@ # Copyright (c) OpenMMLab. All rights reserved. 
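For reference, the corrected example from `docs/en/llm/pipeline.md` lines up with the new top-level `Pipeline` export roughly as sketched below; reading `logits` off the returned `Response` objects is an assumption based on the `logits` field of `GenOut` elsewhere in this patch.

```python
# Minimal sketch mirroring the corrected docs snippet above. Accessing
# `x.logits` on the responses is an assumption (GenOut carries a `logits`
# field that to_response() is expected to forward).
from lmdeploy import GenerationConfig, pipeline

pipe = pipeline('internlm/internlm2_5-7b-chat')
gen_config = GenerationConfig(output_logits='generation', max_new_tokens=10)
response = pipe(['Hi, pls intro yourself', 'Shanghai is'], gen_config=gen_config)
logits = [x.logits for x in response]
```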
-import os -from typing import List, Literal, Optional, Union +from __future__ import annotations -from .archs import autoget_backend_config, get_task -from .messages import PytorchEngineConfig, SpeculativeConfig, TurbomindEngineConfig -from .model import ChatTemplateConfig +from typing import TYPE_CHECKING, List, Literal + +from typing_extensions import deprecated + +from .pipeline import Pipeline + +if TYPE_CHECKING: + from .messages import PytorchEngineConfig, SpeculativeConfig, TurbomindEngineConfig + from .model import ChatTemplateConfig def pipeline(model_path: str, - backend_config: Optional[Union[TurbomindEngineConfig, PytorchEngineConfig]] = None, - chat_template_config: Optional[ChatTemplateConfig] = None, + backend_config: 'TurbomindEngineConfig' | 'PytorchEngineConfig' | None = None, + chat_template_config: 'ChatTemplateConfig' | None = None, log_level: str = 'WARNING', - max_log_len: int = None, - speculative_config: SpeculativeConfig = None, + max_log_len: int | None = None, + speculative_config: 'SpeculativeConfig' | None = None, **kwargs): """ Args: @@ -59,141 +64,46 @@ def pipeline(model_path: str, print(response) """ # noqa E501 - if os.getenv('TM_LOG_LEVEL') is None: - os.environ['TM_LOG_LEVEL'] = log_level - from lmdeploy.utils import get_logger, get_model - logger = get_logger('lmdeploy') - logger.setLevel(log_level) - - # model_path is not local path. - if not os.path.exists(model_path): - download_dir = backend_config.download_dir \ - if backend_config is not None else None - revision = backend_config.revision \ - if backend_config is not None else None - model_path = get_model(model_path, download_dir, revision) - - # spec model - if speculative_config is not None and speculative_config.model and not os.path.exists(speculative_config.model): - download_dir = backend_config.download_dir \ - if backend_config is not None else None - speculative_config.model = get_model(speculative_config.model, download_dir) - - _, pipeline_class = get_task(model_path) - if not isinstance(backend_config, PytorchEngineConfig): - # set auto backend mode - backend_config = autoget_backend_config(model_path, backend_config) - backend = 'pytorch' if isinstance(backend_config, PytorchEngineConfig) else 'turbomind' - logger.info(f'Using {backend} engine') - - return pipeline_class(model_path, - backend=backend, - backend_config=backend_config, - chat_template_config=chat_template_config, - max_log_len=max_log_len, - speculative_config=speculative_config, - **kwargs) + + return Pipeline(model_path, + backend_config=backend_config, + chat_template_config=chat_template_config, + log_level=log_level, + max_log_len=max_log_len, + speculative_config=speculative_config, + **kwargs) +@deprecated('This function is no longer available. Please use CLI command "lmdeploy serve api_server" instead.') def serve(model_path: str, - model_name: Optional[str] = None, + model_name: str | None = None, backend: Literal['turbomind', 'pytorch'] = 'turbomind', - backend_config: Optional[Union[TurbomindEngineConfig, PytorchEngineConfig]] = None, - chat_template_config: Optional[ChatTemplateConfig] = None, + backend_config: 'TurbomindEngineConfig' | 'PytorchEngineConfig' | None = None, + chat_template_config: 'ChatTemplateConfig' | None = None, server_name: str = '0.0.0.0', server_port: int = 23333, log_level: str = 'ERROR', - api_keys: Optional[Union[List[str], str]] = None, + api_keys: List[str] | str | None = None, ssl: bool = False, **kwargs): - """This will run the api_server in a subprocess. 
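Because `pipeline()` above is now a thin wrapper that constructs `Pipeline`, the two entry points should be interchangeable; the engine-config value below is purely illustrative.

```python
# Equivalent construction paths after this refactor; the max_batch_size value
# is illustrative only.
from lmdeploy import Pipeline, TurbomindEngineConfig, pipeline

cfg = TurbomindEngineConfig(max_batch_size=64)
pipe_a = pipeline('internlm/internlm2_5-7b-chat', backend_config=cfg)
pipe_b = Pipeline('internlm/internlm2_5-7b-chat', backend_config=cfg)
```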
+ """This function is deprecated and no longer available. - Args: - model_path: the path of a model. - It could be one of the following options: - - - i) A local directory path of a turbomind model which is - converted by ``lmdeploy convert`` command or download from - ii) and iii). - - ii) The model_id of a lmdeploy-quantized model hosted - inside a model repo on huggingface.co, such as - ``InternLM/internlm-chat-20b-4bit``, - ``lmdeploy/llama2-chat-70b-4bit``, etc. - - iii) The model_id of a model hosted inside a model repo - on huggingface.co, such as ``internlm/internlm-chat-7b``, - ``Qwen/Qwen-7B-Chat``, ``baichuan-inc/Baichuan2-7B-Chat`` - and so on. + .. deprecated:: + This function has been removed. Please use alternative methods. - model_name: the name of the served model. It can be accessed - by the RESTful API ``/v1/models``. If it is not specified, - ``model_path`` will be adopted - backend: either ``turbomind`` or ``pytorch`` backend. Default to - ``turbomind`` backend. - backend_config: backend - config instance. Default to none. - chat_template_config: chat template configuration. - Default to None. - server_name: host ip for serving - server_port: server port - log_level: set log level whose value among - [``CRITICAL``, ``ERROR``, ``WARNING``, ``INFO``, ``DEBUG``] - api_keys: Optional list of API keys. Accepts string type as - a single api_key. Default to None, which means no api key applied. - ssl: Enable SSL. Requires OS Environment variables - ``SSL_KEYFILE`` and ``SSL_CERTFILE``. + This will run the api_server in a subprocess. + """ # noqa E501 + raise NotImplementedError("The 'serve' function is no longer available. " + 'This function has been deprecated and removed.') - Return: - APIClient: A client chatbot for LLaMA series models. - Examples: +@deprecated('This function is no longer available. Please use "from lmdeploy.serve import APIClient" instead.') +def client(api_server_url: str = 'http://0.0.0.0:23333', api_key: str | None = None, **kwargs): + """This function is deprecated and no longer available. - .. code-block:: python + .. deprecated:: + This function has been removed. Please use ``from lmdeploy.serve import APIClient`` instead. 
- from lmdeploy.api import serve - client = serve('internlm/internlm-chat-7b', 'internlm-chat-7b') - for output in client.chat('hi', 1): - print(output) - """ # noqa E501 - import time - from multiprocessing import Process - - from lmdeploy.serve.openai.api_client import APIClient - from lmdeploy.serve.openai.api_server import serve - - if type(backend_config) is not PytorchEngineConfig: - # set auto backend mode - backend_config = autoget_backend_config(model_path, backend_config) - backend = 'pytorch' if type(backend_config) is PytorchEngineConfig else 'turbomind' - - task = Process(target=serve, - args=(model_path, ), - kwargs=dict(model_name=model_name, - backend=backend, - backend_config=backend_config, - chat_template_config=chat_template_config, - server_name=server_name, - server_port=server_port, - log_level=log_level, - api_keys=api_keys, - ssl=ssl, - **kwargs), - daemon=True) - task.start() - client = APIClient(f'http://{server_name}:{server_port}') - while True: - time.sleep(1) - try: - client.available_models - print(f'Launched the api_server in process {task.pid}, user can ' - f'kill the server by:\nimport os,signal\nos.kill({task.pid}, ' - 'signal.SIGKILL)') - return client - except: # noqa - pass - - -def client(api_server_url: str = 'http://0.0.0.0:23333', api_key: Optional[str] = None, **kwargs): - """ Args: api_server_url: communicating address ``http://:`` of api_server @@ -202,5 +112,5 @@ def client(api_server_url: str = 'http://0.0.0.0:23333', api_key: Optional[str] Return: Chatbot for LLaMA series models with turbomind as inference engine. """ - from lmdeploy.serve.openai.api_client import APIClient - return APIClient(api_server_url, api_key, **kwargs) + raise NotImplementedError("The 'client' function is no longer available. This function has been deprecated. " + ' Please use "from lmdeploy.serve import APIClient" instead.') diff --git a/lmdeploy/archs.py b/lmdeploy/archs.py index 444a2026a3..dd4aabb219 100644 --- a/lmdeploy/archs.py +++ b/lmdeploy/archs.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. import os -from typing import Dict, List, Literal, Optional, Union +from typing import Dict, List, Literal, Tuple from transformers import AutoConfig @@ -57,8 +57,8 @@ def autoget_backend(model_path: str) -> Literal['turbomind', 'pytorch']: def autoget_backend_config( model_path: str, - backend_config: Optional[Union[PytorchEngineConfig, TurbomindEngineConfig]] = None -) -> Union[PytorchEngineConfig, TurbomindEngineConfig]: + backend_config: PytorchEngineConfig | TurbomindEngineConfig | None = None +) -> Tuple[Literal['turbomind', 'pytorch'], PytorchEngineConfig | TurbomindEngineConfig]: """Get backend config automatically. 
Args: @@ -72,14 +72,14 @@ def autoget_backend_config( """ from dataclasses import asdict + if isinstance(backend_config, PytorchEngineConfig): + return 'pytorch', backend_config + backend = autoget_backend(model_path) - if backend == 'pytorch': - config = PytorchEngineConfig() - else: - config = TurbomindEngineConfig() + config = PytorchEngineConfig() if backend == 'pytorch' else TurbomindEngineConfig() if backend_config is not None: if type(backend_config) == type(config): - return backend_config + config = backend_config else: data = asdict(backend_config) for k, v in data.items(): @@ -90,7 +90,7 @@ def autoget_backend_config( config.block_size = backend_config.cache_block_seq_len else: config.cache_block_seq_len = backend_config.block_size - return config + return backend, config def check_vl_llm(config: dict) -> bool: @@ -126,14 +126,14 @@ def check_vl_llm(config: dict) -> bool: def get_task(model_path: str): """Get pipeline type and pipeline class from model config.""" - from lmdeploy.serve.async_engine import AsyncEngine + from lmdeploy.serve.core import AsyncEngine if os.path.exists(os.path.join(model_path, 'triton_models', 'weights')): # workspace model return 'llm', AsyncEngine _, config = get_model_arch(model_path) if check_vl_llm(config.to_dict()): - from lmdeploy.serve.vl_async_engine import VLAsyncEngine + from lmdeploy.serve.core import VLAsyncEngine return 'vlm', VLAsyncEngine # default task, pipeline_class @@ -146,40 +146,27 @@ def get_model_arch(model_path: str): Args: model_path(str): the model path """ - if os.path.exists(os.path.join(model_path, 'triton_models', 'weights')): - # the turbomind model - import yaml - config_file = os.path.join(model_path, 'triton_models', 'weights', 'config.yaml') - with open(config_file, 'r') as f: - config = yaml.safe_load(f) - - from .turbomind.deploy.config import TurbomindModelConfig - tm_config = TurbomindModelConfig.from_dict(config) - - return tm_config.model_config.model_arch, tm_config + try: + cfg = AutoConfig.from_pretrained(model_path, trust_remote_code=True) + except Exception as e: # noqa + from transformers import PretrainedConfig + cfg = PretrainedConfig.from_pretrained(model_path, trust_remote_code=True) + + _cfg = cfg.to_dict() + if _cfg.get('architectures', None): + arch = _cfg['architectures'][0] + if _cfg.get('auto_map'): + for _, v in _cfg['auto_map'].items(): + if 'InternLMXComposer2ForCausalLM' in v: + arch = 'InternLMXComposer2ForCausalLM' + elif _cfg.get('auto_map', None) and 'AutoModelForCausalLM' in _cfg['auto_map']: + arch = _cfg['auto_map']['AutoModelForCausalLM'].split('.')[-1] + elif _cfg.get('language_config', None) and _cfg['language_config'].get( + 'auto_map', None) and 'AutoModelForCausalLM' in _cfg['language_config']['auto_map']: + arch = _cfg['language_config']['auto_map']['AutoModelForCausalLM'].split('.')[-1] else: - # transformers model - try: - cfg = AutoConfig.from_pretrained(model_path, trust_remote_code=True) - except Exception as e: # noqa - from transformers import PretrainedConfig - cfg = PretrainedConfig.from_pretrained(model_path, trust_remote_code=True) - - _cfg = cfg.to_dict() - if _cfg.get('architectures', None): - arch = _cfg['architectures'][0] - if _cfg.get('auto_map'): - for _, v in _cfg['auto_map'].items(): - if 'InternLMXComposer2ForCausalLM' in v: - arch = 'InternLMXComposer2ForCausalLM' - elif _cfg.get('auto_map', None) and 'AutoModelForCausalLM' in _cfg['auto_map']: - arch = _cfg['auto_map']['AutoModelForCausalLM'].split('.')[-1] - elif _cfg.get('language_config', None) and 
_cfg['language_config'].get( - 'auto_map', None) and 'AutoModelForCausalLM' in _cfg['language_config']['auto_map']: - arch = _cfg['language_config']['auto_map']['AutoModelForCausalLM'].split('.')[-1] - else: - raise RuntimeError(f'Could not find model architecture from config: {_cfg}') - return arch, cfg + raise RuntimeError(f'Could not find model architecture from config: {_cfg}') + return arch, cfg def search_nested_config(config, key): diff --git a/lmdeploy/cli/chat.py b/lmdeploy/cli/chat.py index 8c1e32c84e..2921a85363 100644 --- a/lmdeploy/cli/chat.py +++ b/lmdeploy/cli/chat.py @@ -1,4 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. +from contextlib import closing + import fire from lmdeploy import GenerationConfig, PytorchEngineConfig, TurbomindEngineConfig, pipeline @@ -70,36 +72,38 @@ def main(model_path, backend, **kwargs): if backend != 'pytorch': # set auto backend mode backend = autoget_backend(model_path) - - pipe = build_pipe(model_path, backend, **kwargs) - gen_config = build_gen_config(**kwargs) - adapter_name = get_adapter_name(**kwargs) - quit = False - while not quit: - with pipe.session(gen_config) as sess: - while True: - try: - prompt = input_prompt() - except KeyboardInterrupt: - quit = True - break - if prompt == 'end': - sess.close() - break - if prompt == 'exit': - quit = True - break - if prompt.strip() == '': - continue - resps = sess(prompt, adapter_name=adapter_name) - try: - for resp in resps: - print(resp.text, end='', flush=True) - except KeyboardInterrupt: - sess.stop() - else: - print('exiting...') + with build_pipe(model_path, backend, **kwargs) as pipe: + gen_config = build_gen_config(**kwargs) + adapter_name = get_adapter_name(**kwargs) + while not quit: + with closing(pipe.session()) as sess: + while True: + try: + prompt = input_prompt() + except KeyboardInterrupt: + quit = True + break + if prompt == 'end': + sess.close() + break + if prompt == 'exit': + quit = True + break + if prompt.strip() == '': + continue + resps = pipe.chat(prompt, + session=sess, + gen_config=gen_config, + adapter_name=adapter_name, + stream_response=True) + try: + for resp in resps: + print(resp.text, end='', flush=True) + except KeyboardInterrupt: + sess.abort() + else: + print('exiting...') if __name__ == '__main__': diff --git a/lmdeploy/pipeline.py b/lmdeploy/pipeline.py new file mode 100644 index 0000000000..a81e4ae300 --- /dev/null +++ b/lmdeploy/pipeline.py @@ -0,0 +1,572 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
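The rewritten CLI loop above maps onto the public API roughly as follows; the model path and prompts are illustrative, and the session handling mirrors the `closing(pipe.session())` pattern used in `lmdeploy/cli/chat.py`.

```python
# Condensed version of the interactive loop in lmdeploy/cli/chat.py, expressed
# through the public pipeline API. Prompts and model path are placeholders.
from contextlib import closing

from lmdeploy import pipeline

with pipeline('internlm/internlm2_5-7b-chat') as pipe:
    with closing(pipe.session()) as sess:
        for prompt in ['Hi, pls intro yourself', 'Shanghai is']:
            for resp in pipe.chat(prompt, session=sess, stream_response=True):
                print(resp.text, end='', flush=True)
            print()
```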
+import asyncio +import atexit +import concurrent.futures +import os +from contextlib import closing +from functools import partial +from queue import Queue +from threading import Thread +from typing import TYPE_CHECKING, Dict, Iterator, List, Tuple + +import torch +import tqdm +from typing_extensions import deprecated + +from .archs import autoget_backend_config, get_task +from .messages import GenerationConfig, PytorchEngineConfig, Response, SpeculativeConfig, TurbomindEngineConfig +from .model import ChatTemplateConfig +from .serve.processors import MultimodalProcessor +from .utils import get_logger, get_model + +if TYPE_CHECKING: + from PIL.Image import Image + + from .serve.managers import Session + +logger = get_logger('lmdeploy') + + +class Pipeline: + """Pipeline - User-facing API layer for inference.""" + + def __init__(self, + model_path: str, + backend_config: TurbomindEngineConfig | PytorchEngineConfig | None = None, + chat_template_config: ChatTemplateConfig | None = None, + log_level: str = 'WARNING', + max_log_len: int | None = None, + speculative_config: SpeculativeConfig | None = None, + **kwargs): + """Initialize Pipeline. + + Args: + model_path: Path to the model. + backend_config: Backend configuration. + chat_template_config: Chat template configuration. + log_level: Log level. + max_log_len: Max number of prompt characters or prompt tokens being printed in log. + speculative_config: Speculative decoding configuration. + **kwargs: Additional keyword arguments. + """ + + os.environ.setdefault('TM_LOG_LEVEL', log_level) + logger.setLevel(log_level) + + # Download model if the path does not exist locally + if not os.path.exists(model_path): + download_dir = backend_config.download_dir if backend_config else None + revision = backend_config.revision if backend_config else None + model_path = get_model(model_path, download_dir, revision) + + # Download speculative model if the path does not exist locally + if speculative_config and speculative_config.model and not os.path.exists(speculative_config.model): + download_dir = backend_config.download_dir if backend_config else None + speculative_config.model = get_model(speculative_config.model, download_dir) + + # Create inference engine + _, pipeline_class = get_task(model_path) + backend, backend_config = autoget_backend_config(model_path, backend_config) + self.async_engine = pipeline_class(model_path, + backend=backend, + backend_config=backend_config, + chat_template_config=chat_template_config, + max_log_len=max_log_len, + speculative_config=speculative_config, + **kwargs) + self.internal_thread = _EventLoopThread(daemon=True) + self.limiter: asyncio.Semaphore = None + self.session_mgr = self.async_engine.session_mgr + self.backend_config = self.async_engine.backend_config + self.async_engine.start_loop(self.internal_thread.loop, use_async_api=False) + + def infer(self, + prompts: List[str] | str | List[Dict] | List[List[Dict]] | Tuple | List[Tuple], + gen_config: GenerationConfig | List[GenerationConfig] | None = None, + do_preprocess: bool = True, + adapter_name: str | None = None, + use_tqdm: bool = False, + **kwargs): + """Inference prompts. + + Args: + prompts: Prompts to inference. It can be a single prompt, a list of prompts, a list of tuples, or a tuple. + Tuple can be (prompt, image or [images]) or (image or [images], prompt). + gen_config(GenerationConfig | List[GenerationConfig] | None): Generation configuration(s). + do_preprocess(bool): Whether to pre-process messages. 
+ adapter_name(str | None): Adapter name. + use_tqdm(bool): Whether to use progress bar. + **kwargs(dict): Additional keyword arguments. + """ + is_single = self._is_single(prompts) + # format prompts to openai message format, which is a list of dicts + prompts = MultimodalProcessor.format_prompts(prompts) + pbar = tqdm.tqdm(total=len(prompts)) if use_tqdm else None + outputs = [] + try: + requests = self._request_generator(prompts, + gen_config=gen_config, + do_preprocess=do_preprocess, + adapter_name=adapter_name, + stream_response=False, + **kwargs) + for g in self._infer(requests, multiplex=False, pbar=pbar): + res = None + for out in g: + res = res.extend(out) if res else out + outputs.append(res) + finally: + if pbar: pbar.close() # noqa + if is_single: + return outputs[0] + return outputs + + @deprecated('This method is deprecated. Please use "Pipeline.infer" instead.') + def batch_infer(self, *args, **kwargs): + return self.infer(*args, **kwargs) + + def stream_infer(self, + prompts: List[str] | str | List[Dict] | List[List[Dict]] | Tuple | List[Tuple], + sessions: 'Session' | List['Session'] | None = None, + gen_config: GenerationConfig | List[GenerationConfig] | None = None, + do_preprocess: bool = True, + adapter_name: str | None = None, + stream_response: bool = True, + **kwargs): + """Stream inference. + + Args: + prompts(List[str] | str | List[Dict] | List[List[Dict]] | Tuple | List[Tuple]): Prompts to inference. + It can be a single prompt, a list of prompts, a list of tuples, or a tuple. + Tuple can be (prompt, image or [images]) or (image or [images], prompt). + sessions(Session | List[Session] | None): Sessions. Each of which corresponds to a prompt. + gen_config(GenerationConfig | List[GenerationConfig] | None): Generation configuration(s). + do_preprocess(bool): Whether to pre-process messages. + adapter_name(str | None): Adapter name. + stream_response(bool): Whether to stream the response. If True, the generator will stream the response. + Otherwise, the generator will run until finish and return the final response. This argument + is introduced to support the streaming and non-streaming modes of Pipeline.chat. + **kwargs(dict): Additional keyword arguments. + + Returns: + Generator: A generator that yields the output (i.e. instance of class `Response`) of the inference. + """ + prompts = MultimodalProcessor.format_prompts(prompts) + requests = self._request_generator(prompts, + sessions=sessions, + gen_config=gen_config, + do_preprocess=do_preprocess, + adapter_name=adapter_name, + stream_response=stream_response, + **kwargs) + return self._infer(requests, multiplex=True) + + def close(self): + """Close the pipeline.""" + self.internal_thread.close() + self.async_engine.close() + + def chat(self, + prompt: str | Tuple[str, 'Image' | List['Image']], + session=None, + gen_config: GenerationConfig | None = None, + stream_response=False, + adapter_name=None, + **kwargs) -> 'Session' | Iterator: + """Chat. + + Args: + prompt (str): prompt + session (Session): the chat session + gen_config (GenerationConfig | None): a instance of + GenerationConfig. Default to None. + stream_response (bool): whether to stream the response. + adapter_name (str): adapter name. + **kwargs (dict): additional keyword arguments. 
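A hedged usage sketch of the batch and streaming entry points defined above; prompts are illustrative, and `batch_infer()` is kept only as a deprecated alias of `infer()`.

```python
# Batch and streaming inference through the new Pipeline front end.
from lmdeploy import GenerationConfig, pipeline

pipe = pipeline('internlm/internlm2_5-7b-chat')
gen_config = GenerationConfig(max_new_tokens=64)

# blocking batch call: one Response per prompt, optional progress bar
responses = pipe.infer(['Hi, pls intro yourself', 'Shanghai is'],
                       gen_config=gen_config, use_tqdm=True)
print([r.text for r in responses])

# streaming call: the multiplexed generator yields Response chunks whose
# `index` points back to the originating prompt
for chunk in pipe.stream_infer(['Hi, pls intro yourself', 'Shanghai is'],
                               gen_config=gen_config):
    print(chunk.index, chunk.text)
```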
+ """ + if session is None: + session = self.session_mgr.get() + session.update(prompt=prompt, response=None) + + prompt = MultimodalProcessor.format_prompts(prompt) + + sequence_start = session.step == 0 + generator = self.stream_infer(prompts=prompt, + sessions=session, + gen_config=gen_config, + stream_response=stream_response, + adapter_name=adapter_name, + multiplex=True, + sequence_start=sequence_start, + sequence_end=False, + step=session.step) + + def _gen(): + resp = None + try: + for out in generator: + resp = resp.extend(out) if resp else out + yield out + except: # noqa + self._run(coro=session.async_abort()) + raise + else: + session.response = resp + session.step += resp.generate_token_len + resp.input_token_len + session.history.append((session.prompt, resp.text)) + + if stream_response: + return _gen() + else: + # run the generator until finish + with closing(_gen()) as gen: + for _ in gen: + pass + session.generator = None + + return session + + def session(self) -> 'Session': + """Create a new session.""" + return self.session_mgr.get() + + def get_reward_score(self, input_ids: List) -> List[float]: + """Get reward score. + + Args: + input_ids(List): a list of token_id or a list of token_id list or token_id tensor + Return: + reward score in a list. If the input_ids is a list of token_id, the return value + is still a list with length 1. + """ + supported_reward_models = ['InternLM2ForRewardModel', 'Qwen2ForRewardModel'] + arch = self.async_engine.arch + if arch not in supported_reward_models: + raise ValueError(f'{arch} is not in reward model list: {supported_reward_models}') + assert isinstance(input_ids, List) + assert all(isinstance(x, int) for x in input_ids) or all(isinstance(x, List) for x in input_ids) + # Make input_ids a list of token_id list + input_ids = [input_ids] if isinstance(input_ids[0], int) else input_ids + logits = self._run(coro=self.async_engine.async_get_logits(input_ids=input_ids)).result() + logits = [x.squeeze() for x in logits] + scores = [x[-1].cpu().item() for x in logits] + return scores + + def get_ppl(self, input_ids: List[int] | List[List[int]]) -> List[float]: + """Get perplexity scores given a list of input tokens that have to be + of the same length. + + Args: + input_ids (List[int] | List[List[int]]): the batch of input token ids + + Returns: + List[float]: A list of perplexity scores. 
+ """ + assert isinstance(input_ids, List) + if isinstance(input_ids[0], int): + input_ids = [input_ids] + assert all(len(_) > 1 for _ in input_ids) + + # TODO: a better way to determine `max_input_len`, at most allocate + # 2G mem for logits with shape [bs, max_input_len, vocab_size] + vocab_size = self.async_engine.hf_cfg.vocab_size + max_input_len = 2 * 1024**3 // (vocab_size * 4) + sizes = [len(_) for _ in input_ids] + result = [] + sorted_index_values = sorted(list(enumerate(sizes)), key=lambda x: x[1], reverse=True) + sizes = [value for index, value in sorted_index_values] + indices = [index for index, value in sorted_index_values] + logger.info(f'sorted sizes: {sizes}') + logger.info(f'sorted indices: {indices}') + for (start, end) in self._batch_iterator(sizes, max_input_len): + logger.info(f'start: {start}, end: {end}') + if start == end: + _input_ids = input_ids[indices[start]] + res = self._get_long_text_ppl(input_ids=_input_ids, max_input_len=max_input_len) + result.append(res) + else: + _input_ids = [input_ids[indices[i]] for i in range(start, end)] + res = self._get_ppl( + input_ids=_input_ids, + max_input_len=max_input_len, + ) + result.extend(res) + output = list(range(len(result))) + for index, sorted_index in enumerate(indices): + output[sorted_index] = result[index] + return output + + def __call__(self, + prompts: List[str] | str | List[Dict] | List[List[Dict]], + gen_config: GenerationConfig | List[GenerationConfig] | None = None, + **kwargs): + return self.infer(prompts, gen_config=gen_config, **kwargs) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.close() + + @deprecated('This method is deprecated. Please use "AsyncEngine.generate" instead.') + async def generate(self, *args, **kwargs): + """Generate responses as an async generator. + + This method delegates to async_engine.generate and forwards all yielded values. + """ + async for item in self.async_engine.generate(*args, **kwargs): + yield item + + @staticmethod + def _is_single(prompts): + """Check if prompts is a single prompt.""" + return (isinstance(prompts, str) or (isinstance(prompts, tuple) and len(prompts) == 2) + or (isinstance(prompts, list) and len(prompts) > 0 and isinstance(prompts[0], Dict))) + + def _request_generator(self, + prompts: List[str] | str | List[Dict] | List[List[Dict]], + sessions: List['Session'] | 'Session' | None = None, + gen_config: GenerationConfig | List[GenerationConfig] | None = None, + **kwargs): + """Generate requests.""" + is_single = self._is_single(prompts) + prompts = [prompts] if is_single else prompts + + if sessions is None: + sessions = [self.session_mgr.get() for _ in prompts] + elif isinstance(sessions, list): + sessions = sessions + else: + sessions = [sessions] + + if len(prompts) != len(sessions): + raise ValueError(f'prompts and sessions should have the same length. ' + f'Got {len(prompts)} prompts and {len(sessions)} sessions') + + if gen_config is None: + gen_configs = [GenerationConfig()] * len(prompts) + elif isinstance(gen_config, list): + gen_configs = gen_config + else: + gen_configs = [gen_config] * len(prompts) + + if len(prompts) != len(gen_configs): + raise ValueError(f'input gen_config length differs from the length of prompts. ' + f'Got {len(prompts)} prompts and {len(gen_configs)} gen_configs') + + for prompt, gen_cfg, session in zip(prompts, gen_configs, sessions): + # Use session_id is for backward compatibility. We will remove it in the future. 
+ # Since AsyncEngine.generate defines session_id in the argument lists, here we + # use session_id to pass the session to the AsyncEngine.generate. It's + yield dict(session_id=session, messages=prompt, gen_config=gen_cfg, **kwargs) + + def _get_limiter(self): + if not self.limiter: + self.limiter = asyncio.Semaphore(self.backend_config.max_batch_size) + return self.limiter + + def _infer(self, requests: Iterator[Dict], multiplex: bool, pbar=None, loop=None) -> Iterator[Iterator[Response]]: + + async def _sync_resp(g, que: Queue, idx: int, sem: asyncio.Semaphore): + async for out in g: + que.put(out.to_response(idx)) + sem.release() + if not multiplex: + que.put(None) # sentinel of inner generator + if pbar: + pbar.update(1) + + que = Queue() + + async def _infer(): + sem = self._get_limiter() + tasks = [] + for idx, req in enumerate(requests): + await sem.acquire() + gen = self.async_engine.generate(**req) + dst = que if multiplex else Queue() + if not multiplex: + que.put(iter(dst.get, None)) + # create a task to send the responses + task = asyncio.create_task(_sync_resp(gen, dst, idx, sem)) + tasks.append(task) + if not multiplex: # sentinel of outer generator + que.put(None) + await asyncio.gather(*tasks) + if multiplex: + que.put(None) # sentinel of inner generator + + loop = loop or self.internal_thread.loop + # submit the coroutine to async world + asyncio.run_coroutine_threadsafe(_infer(), + loop).add_done_callback(lambda f: None if f.cancelled() else f.result()) + + return iter(que.get, None) + + def _run(self, fn=None, coro=None): + assert (fn or coro) and not (fn and coro) + loop = self.internal_thread.loop + if fn: + + async def _coro(): + return fn() + + coro = _coro() + return asyncio.run_coroutine_threadsafe(coro, loop) + + def _batch_iterator(self, sizes, max_value): + """Return an iterator that calculates intervals (start, end) of a + descend-order list, in which the sum of values in the range is the + maximum number not less than max_value. 
By "the sum of values", + + here it means $$len(sizes[start:end]) * sizes[start]$$ + """ + i = 0 + while i < len(sizes): + current_sum = 0 + start_index = i + + while i < len(sizes) and current_sum + sizes[start_index] <= max_value: + current_sum += sizes[start_index] + i += 1 + + yield (start_index, i) + if i > start_index: + continue + else: + i += 1 + + def _get_long_text_ppl(self, input_ids, max_input_len): + assert all(isinstance(_, int) for _ in input_ids) + seq_len = len(input_ids) + assert seq_len > max_input_len + logger.info(f'get long text ppl: seq_len {seq_len}') + + losses = [] + target_counts = [] + for i in range(0, seq_len, max_input_len): + token_ids = input_ids[i:i + max_input_len] + step = [i] + # shift token_ids by 1 to the left + target_ids = input_ids[i + 1:i + 1 + max_input_len] + loss = self._get_ppl(input_ids=[token_ids], + max_input_len=len(token_ids), + target_ids=[target_ids], + steps=step, + sequence_start=(i == 0), + sequence_end=False) + losses.extend(loss) + target_counts.append(len(target_ids)) + losses = [loss * target_count for loss, target_count in zip(losses, target_counts)] + loss_sum = sum(losses) + target_count = sum(target_counts) + return loss_sum / target_count + + def _get_ppl(self, + input_ids, + max_input_len, + target_ids=None, + steps=None, + sequence_start: bool = True, + sequence_end: bool = True): + assert (isinstance(input_ids, List) and all(isinstance(_, List) for _ in input_ids)) + assert steps is None or len(steps) == len(input_ids) + assert target_ids is None or len(target_ids) == len(input_ids) + + lens = [len(_) for _ in input_ids] + total_len = sum(lens) + assert sum(lens) <= max_input_len + + logger.info(f'get_ppl: bs: {len(input_ids)}, lens: {lens}, ' + f'total_len: {total_len}, steps: {steps}') + torch.cuda.empty_cache() + + logits = self._run(coro=self.async_engine.async_get_logits( + input_ids=input_ids, steps=steps, sequence_start=sequence_start, sequence_end=sequence_end)).result() + padding_token_id = -100 + if target_ids is None: + target_ids = [x[1:] + [padding_token_id] for x in input_ids] + else: + target_ids = [ + target_ids[i] + [padding_token_id] if len(target_ids[i]) < len(input_ids[i]) else target_ids[i] + for i in range(len(input_ids)) + ] + target_ids = [torch.Tensor(torch.LongTensor(_target_ids)) for _target_ids in target_ids] + + result = [] + for _logits, _target_ids in zip(logits, target_ids): + _logits = _logits.float() + vocab_size = _logits.shape[-1] + _target_ids = _target_ids.to(_logits.device) + target_mask = _target_ids != padding_token_id + # compute cross entropy loss + flat_logits = _logits.contiguous().view(-1, vocab_size) + flat_target_ids = _target_ids.contiguous().view(-1) + flat_loss_matrix = torch.nn.functional.cross_entropy(flat_logits, + flat_target_ids, + reduction='none', + ignore_index=padding_token_id) + loss = flat_loss_matrix.sum() + target_count = target_mask.sum() + result.append(loss.item() / target_count.item()) + logger.info(f'ppl result: {result}') + return result + + +class _EventLoopThread: + + def __init__(self, daemon=False): + fut = concurrent.futures.Future() + self.thread = Thread(target=partial(self._thread_entry, fut), daemon=daemon) + self.thread.start() + self.loop: asyncio.AbstractEventLoop = fut.result() + self.closed = False + if daemon: + atexit.register(self.close) + + def _thread_entry(self, fut): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + fut.set_result(loop) + try: + loop.run_forever() + except BaseException as e: + 
logger.error(f'[internal_thread] {type(e).__name__} {e}') + finally: + try: + self._cancel_all_tasks() + loop.run_until_complete(loop.shutdown_asyncgens()) + finally: + asyncio.set_event_loop(None) + loop.close() + + def _cancel_all_tasks(self): + """Modified from asyncio/runners.py.""" + to_cancel = asyncio.all_tasks(self.loop) + if not to_cancel: + return + + for task in to_cancel: + task.cancel() + + async def _gather(): + await asyncio.gather(*to_cancel, return_exceptions=True) + + self.loop.run_until_complete(_gather()) + + for task in to_cancel: + if task.cancelled(): + continue + if task.exception() is not None: + self.loop.call_exception_handler({ + 'message': 'unhandled exception during worker thread shutdown', + 'exception': task.exception(), + 'task': task, + }) + + def close(self): + if self.closed: + return + self.closed = True + self.loop.call_soon_threadsafe(self.loop.stop) + self.thread.join() diff --git a/lmdeploy/serve/__init__.py b/lmdeploy/serve/__init__.py index ef101fec61..3fa68c54e3 100644 --- a/lmdeploy/serve/__init__.py +++ b/lmdeploy/serve/__init__.py @@ -1 +1,12 @@ # Copyright (c) OpenMMLab. All rights reserved. +from .core import AsyncEngine, VLAsyncEngine +from .managers import Session, SessionManager +from .processors import MultimodalProcessor + +__all__ = [ + 'AsyncEngine', + 'VLAsyncEngine', + 'SessionManager', + 'Session', + 'MultimodalProcessor', +] diff --git a/lmdeploy/serve/core/__init__.py b/lmdeploy/serve/core/__init__.py new file mode 100644 index 0000000000..1f3cba0a57 --- /dev/null +++ b/lmdeploy/serve/core/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .async_engine import AsyncEngine +from .vl_async_engine import VLAsyncEngine + +__all__ = ['AsyncEngine', 'VLAsyncEngine'] diff --git a/lmdeploy/serve/async_engine.py b/lmdeploy/serve/core/async_engine.py similarity index 50% rename from lmdeploy/serve/async_engine.py rename to lmdeploy/serve/core/async_engine.py index 98cdcc0b72..1fed0eecbc 100644 --- a/lmdeploy/serve/async_engine.py +++ b/lmdeploy/serve/core/async_engine.py @@ -1,21 +1,15 @@ # Copyright (c) OpenMMLab. All rights reserved. 
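The package re-exports introduced above settle the public import paths; the sketch below only uses names that appear in the new `__init__` files.

```python
# Import layout established by lmdeploy/serve/__init__.py and
# lmdeploy/serve/core/__init__.py in this patch.
from lmdeploy import Pipeline, pipeline
from lmdeploy.serve import (AsyncEngine, MultimodalProcessor, Session,
                            SessionManager, VLAsyncEngine)
from lmdeploy.serve.core import AsyncEngine as CoreAsyncEngine  # same class as above
```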
import asyncio -import atexit import concurrent.futures import dataclasses import random -from contextlib import asynccontextmanager, closing +from contextlib import asynccontextmanager from copy import deepcopy -from functools import partial -from itertools import count -from queue import Queue -from threading import Thread -from typing import Any, AsyncIterator, Dict, Iterator, List, Literal, Optional, Tuple, Union +from typing import Any, Dict, List, Literal -import tqdm +import torch -from lmdeploy import Tokenizer from lmdeploy.archs import get_model_arch from lmdeploy.logger import RequestLogger from lmdeploy.messages import (GenerationConfig, PytorchEngineConfig, Response, ResponseType, SpeculativeConfig, @@ -25,11 +19,13 @@ from lmdeploy.model import ChatTemplateConfig, get_chat_template from lmdeploy.pytorch.disagg.conn.protocol import (DistServeConnectionRequest, DistServeDropConnectionRequest, DistServeInitRequest) -from lmdeploy.serve.multimodal_processor import MultimodalProcessor -from lmdeploy.serve.utils import LogitsMixin -from lmdeploy.tokenizer import DetokenizeState +from lmdeploy.serve.managers import Session, SessionManager +from lmdeploy.serve.processors import MultimodalProcessor +from lmdeploy.tokenizer import DetokenizeState, Tokenizer from lmdeploy.utils import _get_and_verify_max_len, _stop_words, get_hf_gen_cfg, get_logger +from .exceptions import SafeRunException + logger = get_logger('lmdeploy') @@ -40,12 +36,12 @@ class GenOut: history_token_len: int input_token_len: int generate_token_len: int - finish_reason: Optional[Literal['stop', 'length', 'error']] = None - token_ids: List[int] = None - logprobs: List[Dict[int, float]] = None + finish_reason: Literal['stop', 'length', 'error'] | None = None + token_ids: List[int] | None = None + logprobs: List[Dict[int, float]] | None = None logits: Any = None last_hidden_state: Any = None - cache_block_ids: List[int] = None # for disaggregation + cache_block_ids: List[int] | None = None # for disaggregation routed_experts: Any = None # for RL router replay def to_response(self, index: int = 0) -> Response: @@ -66,144 +62,8 @@ def to_response(self, index: int = 0) -> Response: index=index) -class Session: - """Session for AsyncEngine.chat. - - Args: - _id (int): session_id for internal use. - _step (int): the offset of the k/v cache for internal use. - _prompt (Any): input prompt for internal use. - _response (Reaponse): model output for prompt. - _engine (Any): engine for internal use. - history (List[Any, str]): chat history. 
- """ - - def __init__(self, session_id: int, engine: Any, gen_config: GenerationConfig = None): - self._id: int = session_id - self._engine = engine - self._step: int = 0 - self._prompt: Any = None - self._response: Response = None - self._gen_config = gen_config - self.history: List[Tuple[Any, str]] = [] - - def _merge_response(self, resp: Response, step: Union[Response, GenOut]): - """Merge response.""" - resp.text += step.text if isinstance(step, Response) else step.response - resp.input_token_len = step.input_token_len - resp.generate_token_len = step.generate_token_len - resp.finish_reason = step.finish_reason - return resp - - @property - def response(self) -> Response: - """Return response.""" - return self._response - - def close(self): - """Release engine storage for this session.""" - if self._engine and self._prompt: - self._engine._run(coro=self._engine.end_session(self._id)).result() - self._engine = None - - def stop(self): - if self._engine and self._prompt: - self._engine._run(coro=self._engine.stop_session(self._id)).result() - - def __repr__(self) -> str: - res = '' - for user, assistant in self.history: - if isinstance(user, list): - user = str(user) - res += f'USER: \n{user}\nASSISTANT: \n{assistant}\n' - return res - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_value, traceback): - self.close() - - def __call__(self, - prompt: str, - gen_config: Optional[GenerationConfig] = None, - stream_response: bool = True, - do_preprocess: bool = True, - adapter_name: str = None, - **kwargs) -> Union[Response, Iterator[Response]]: - self._engine.chat(prompt, - gen_config=gen_config or self._gen_config, - stream_response=stream_response, - do_preprocess=do_preprocess, - session=self, - adapter_name=adapter_name, - **kwargs) - if stream_response: - return self.generator - else: - return self.response - - -class _EventLoopThread: - - def __init__(self, daemon=False): - fut = concurrent.futures.Future() - self.thread = Thread(target=partial(self._thread_entry, fut), daemon=daemon) - self.thread.start() - self.loop: asyncio.AbstractEventLoop = fut.result() - self.closed = False - if daemon: - atexit.register(self.close) - - def _thread_entry(self, fut): - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - fut.set_result(loop) - try: - loop.run_forever() - except BaseException as e: - logger.error(f'[internal_thread] {type(e).__name__} {e}') - finally: - try: - self._cancel_all_tasks() - loop.run_until_complete(loop.shutdown_asyncgens()) - finally: - asyncio.set_event_loop(None) - loop.close() - - def _cancel_all_tasks(self): - """Modified from asyncio/runners.py.""" - to_cancel = asyncio.all_tasks(self.loop) - if not to_cancel: - return - - for task in to_cancel: - task.cancel() - - async def _gather(): - await asyncio.gather(*to_cancel, return_exceptions=True) - - self.loop.run_until_complete(_gather()) - - for task in to_cancel: - if task.cancelled(): - continue - if task.exception() is not None: - self.loop.call_exception_handler({ - 'message': 'unhandled exception during worker thread shutdown', - 'exception': task.exception(), - 'task': task, - }) - - def close(self): - if self.closed: - return - self.closed = True - self.loop.call_soon_threadsafe(self.loop.stop) - self.thread.join() - - -class AsyncEngine(LogitsMixin): +# class AsyncEngine(LogitsMixin): +class AsyncEngine: """Async inference engine. Maintaining a bunch of tm_model instances. 
Args: @@ -235,12 +95,12 @@ class AsyncEngine(LogitsMixin): def __init__(self, model_path: str, - model_name: Optional[str] = None, + model_name: str | None = None, backend: Literal['turbomind', 'pytorch'] = 'turbomind', - backend_config: Optional[Union[TurbomindEngineConfig, PytorchEngineConfig]] = None, - chat_template_config: Optional[ChatTemplateConfig] = None, - max_log_len: int = None, - speculative_config: SpeculativeConfig = None, + backend_config: TurbomindEngineConfig | PytorchEngineConfig | None = None, + chat_template_config: ChatTemplateConfig | None = None, + max_log_len: int | None = None, + speculative_config: SpeculativeConfig | None = None, **kwargs) -> None: logger.info(f'input backend={backend}, backend_config={backend_config}') logger.info(f'speculative_config={speculative_config}') @@ -277,26 +137,20 @@ def __init__(self, if self.stop_words is not None: self.stop_words = self.stop_words[0][0].tolist() self.backend = backend - self.instance_num = self.backend_config.max_batch_size - self.id2step = {} - self.id2inst = {} - self.free_insts: asyncio.Queue = None - self.instances = [self.engine.create_instance() for _ in range(self.instance_num)] - self._session_id = count(0) self.request_logger = RequestLogger(max_log_len) - self.internal_thread = _EventLoopThread(daemon=True) - self.limiter: asyncio.Semaphore = None + self.num_spec_token = 0 if backend == 'turbomind' or speculative_config is None \ else speculative_config.num_speculative_tokens + self.session_mgr = SessionManager() + self.session_mgr.build_request_handle_pool(self.engine, self.backend_config.max_batch_size) + # build stat loggers self._build_stat_loggers() self.epoch = 0 def close(self): - self.internal_thread.close() - self.free_insts = None - self.instances.clear() + self.session_mgr.clear() self.engine.close() def __enter__(self): @@ -305,23 +159,15 @@ def __enter__(self): def __exit__(self, exc_type, exc_value, traceback): self.close() - def _get_free_insts(self): - if self.free_insts is None: - # `asyncio.Queue` must be created in an async context - self.free_insts = asyncio.Queue() - for inst in self.instances: - self.free_insts.put_nowait(inst) - return self.free_insts - - def _build_turbomind(self, model_path: str, backend_config: TurbomindEngineConfig = None, **kwargs): + def _build_turbomind(self, model_path: str, backend_config: TurbomindEngineConfig | None = None, **kwargs): """Inner build method for turbomind backend.""" from lmdeploy import turbomind as tm return tm.TurboMind.from_pretrained(model_path, engine_config=backend_config, **kwargs) def _build_pytorch(self, model_path: str, - backend_config: PytorchEngineConfig = None, - speculative_config: SpeculativeConfig = None, + backend_config: PytorchEngineConfig | None = None, + speculative_config: SpeculativeConfig | None = None, **kwargs): """Inner build method for pytorch backend.""" from lmdeploy.pytorch.engine import Engine @@ -348,37 +194,6 @@ def _build_stat_loggers(self): def get_schedule_metrics(self): return self.engine.get_schedule_metrics() - def __call__(self, - prompts: Union[List[str], str, List[Dict], List[List[Dict]]], - gen_config: Optional[GenerationConfig] = None, - do_preprocess: bool = True, - adapter_name: Optional[str] = None, - use_tqdm: bool = False, - **kwargs): - """Inference a batch of prompts. - - Args: - prompts (List[str] | str | List[Dict] | List[List[Dict]]]): a - batch of prompts. It accepts: string prompt, a list of string - prompts, a chat history in OpenAI format or a list of chat - history. 
- gen_config (GenerationConfig | None): a instance of - GenerationConfig. Default to None. - do_preprocess (bool): whether pre-process the messages. Default to - True, which means chat_template will be applied. - adapter_name (str): the adapter name of slora for pytorch backend. - Pick one from adapters. Default to None, using the base model. - use_tqdm (bool): Whether use the progress bar. Default to False - """ - if gen_config is None: - gen_config = GenerationConfig() - return self.batch_infer(prompts, - gen_config=gen_config, - do_preprocess=do_preprocess, - adapter_name=adapter_name, - use_tqdm=use_tqdm, - **kwargs) - async def do_log_stats(self): """Loop through CLI logger and Prometheus logger and output the metrics.""" @@ -389,41 +204,7 @@ async def stop_all_session(self): """Stop all running sessions.""" logger.info('stop all sessions') self.epoch += 1 - tasks = [] - session_ids = [] - for session_id in list(self.id2inst.keys()): - generator = self.id2inst.get(session_id) - if generator: - session_ids.append(session_id) - logger.debug(f'stop session {session_id}') - tasks.append(generator.async_cancel(session_id)) - await asyncio.gather(*tasks) - logger.info(f'all {len(session_ids)} sessions stopped') - - async def stop_session(self, session_id: int): - """Stop a session by a session_id.""" - logger.info(f'stop session {session_id}') - generator = self.id2inst.get(session_id) - if generator: - await generator.async_cancel(session_id) - logger.info(f'session {session_id} stopped') - # else it's not running at all - - async def end_session(self, session_id: int): - """For ending a session that is not running.""" - logger.info(f'end session {session_id}') - inst = self.id2inst.get(session_id) - if inst: - await inst._active.wait() - assert session_id not in self.id2inst - inst = await self._get_free_insts().get() - try: - await inst.async_end(session_id) - self.id2step[session_id] = 0 - except (Exception, asyncio.CancelledError, GeneratorExit) as e: # noqa - logger.error(f'[end_session] exception caught: {e}') - finally: - self._get_free_insts().put_nowait(inst) + await self.session_mgr.async_abort_all() def sleep(self, level: int = 1): """Sleep the model. @@ -437,7 +218,7 @@ def sleep(self, level: int = 1): self.sleeping_tags = {'weights', 'kv_cache'} self.is_sleeping = True - def wakeup(self, tags: Optional[List[str]] = None): + def wakeup(self, tags: List[str] | None = None): """Wake up the model. 
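A hedged sketch of the sleep/wakeup cycle handled above; the tag names mirror `self.sleeping_tags`, and whether a particular backend supports offloading is not established by this hunk.

```python
# Illustrative sleep/wakeup sequence for an AsyncEngine instance.
from lmdeploy.serve import AsyncEngine

engine = AsyncEngine('internlm/internlm2_5-7b-chat')
engine.sleep(level=1)              # marks {'weights', 'kv_cache'} as sleeping
engine.wakeup(tags=['weights'])    # wake the weights only
engine.wakeup(tags=['kv_cache'])   # turbomind rebuilds its request-handle pool here
assert not engine.is_sleeping
```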
Args: @@ -452,225 +233,70 @@ def wakeup(self, tags: Optional[List[str]] = None): logger.warning(f'some tag in {tags} not in sleeping tags {self.sleeping_tags}') return self.engine.wakeup(tags) - # for TM backend, sleep/wakeup will reset gateway, therefore we need to rebuild instance + # for TM backend, sleep/wakeup will reset gateway, therefore we need to rebuild instances if self.backend == 'turbomind' and 'kv_cache' in tags: - self.instances = [self.engine.create_instance() for _ in range(self.instance_num)] - self.free_insts = None + self.session_mgr.build_request_handle_pool(self.engine, self.backend_config.max_batch_size) self.sleeping_tags = self.sleeping_tags - set(tags) self.is_sleeping = bool(self.sleeping_tags) - def _get_limiter(self): - if not self.limiter: - self.limiter = asyncio.Semaphore(self.instance_num) - return self.limiter - - async def _async_infer(self, requests: AsyncIterator[Dict], **kwargs) -> AsyncIterator[AsyncIterator[Response]]: - async for req in requests: - gen = self.generate(**req, **kwargs) - yield gen - - def _infer(self, requests: Iterator[Dict], multiplex: bool, pbar=None, loop=None) -> Iterator[Iterator[Response]]: - - async def _sync_resp(g, que: Queue, idx: int, sem: asyncio.Semaphore): - async for out in g: - que.put(out.to_response(idx)) - sem.release() - if not multiplex: - que.put(None) # sentinel of inner generator - if pbar: - pbar.update(1) - - que = Queue() - - async def _infer(): - sem = self._get_limiter() - tasks = [] - for idx, req in enumerate(requests): - await sem.acquire() - gen = self.generate(**req) - dst = que if multiplex else Queue() - if not multiplex: - que.put(iter(dst.get, None)) - # create a task to send the responses - task = asyncio.create_task(_sync_resp(gen, dst, idx, sem)) - tasks.append(task) - if not multiplex: # sentinel of outer generator - que.put(None) - await asyncio.gather(*tasks) - if multiplex: - que.put(None) # sentinel of inner generator - - loop = loop or self.internal_thread.loop - # submit the coroutine to async world - asyncio.run_coroutine_threadsafe(_infer(), - loop).add_done_callback(lambda f: None if f.cancelled() else f.result()) - - return iter(que.get, None) - - @staticmethod - def _is_single(prompts): - return isinstance(prompts, str) or isinstance(prompts[0], Dict) - - def infer(self, - prompts: Union[List[str], str, List[Dict], List[List[Dict]]], - gen_config: Optional[Union[GenerationConfig, List[GenerationConfig]]] = None, - do_preprocess: bool = True, - adapter_name: Optional[str] = None, - stream_response: bool = False, - multiplex: bool = False, - pbar: Optional[tqdm.tqdm] = None, - **kwargs): - - prompts = [prompts] if AsyncEngine._is_single(prompts) else prompts - assert isinstance(prompts, List), 'prompts should be a list' - gen_config = gen_config or GenerationConfig() - if not isinstance(gen_config, List): - gen_config = [gen_config] * len(prompts) - assert len(prompts) == len(gen_config), \ - 'input gen_confg length differs from the length of prompts' # noqa - - def requests(): - for prompt, gen_cfg in zip(prompts, gen_config): - r = dict(messages=prompt, - gen_config=gen_cfg, - do_preprocess=do_preprocess, - adapter_name=adapter_name, - stream_response=stream_response, - **kwargs) - r.setdefault('sequence_start', True) - r.setdefault('sequence_end', True) - if 'session_id' not in r: - r['session_id'] = next(self._session_id) - yield r - - return self._infer(requests(), multiplex, pbar) - - def batch_infer(self, - prompts: Union[List[str], str, List[Dict], List[List[Dict]]], - 
gen_config: Optional[Union[GenerationConfig, List[GenerationConfig]]] = None, - do_preprocess: bool = True, - adapter_name: Optional[str] = None, - use_tqdm: bool = False, - **kwargs): - """Inference a batch of prompts. - - Args: - prompts (List[str] | str | List[Dict] | List[List[Dict]]]): a - batch of prompts. It accepts: string prompt, a list of string - prompts, a chat history in OpenAI format or a list of chat - history. - gen_config (GenerationConfig | None): a instance of or a list of - GenerationConfig. Default to None. - do_preprocess (bool): whether pre-process the messages. Default to - True, which means chat_template will be applied. - adapter_name (str): the adapter name of slora for pytorch backend. - Pick one from adapters. Default to None, using the base model. - use_tqdm (bool): Whether use the progress bar. Default to False - """ - is_single = AsyncEngine._is_single(prompts) - outputs = [] - pbar = tqdm.tqdm(total=1 if is_single else len(prompts)) if use_tqdm else None - try: - for g in self.infer(prompts, - gen_config, - do_preprocess, - adapter_name, - stream_response=False, - pbar=pbar, - **kwargs): - res = None - for out in g: - res = res.extend(out) if res else out - outputs.append(res) - finally: - if pbar: pbar.close() # noqa - if is_single: - return outputs[0] - return outputs - - def stream_infer(self, - prompts: Union[List[str], str, List[Dict], List[List[Dict]]], - gen_config: Optional[Union[GenerationConfig, List[GenerationConfig]]] = None, - do_preprocess: bool = True, - adapter_name: Optional[str] = None, - stream_response: bool = True, - **kwargs): - """Inference a batch of prompts with stream mode. - - Args: - prompts (List[str] | str | List[Dict] | List[List[Dict]]]):a - batch of prompts. It accepts: string prompt, a list of string - prompts, a chat history in OpenAI format or a list of chat - history. - gen_config (GenerationConfig | None): a instance of or a list of - GenerationConfig. Default to None. - do_preprocess (bool): whether pre-process the messages. Default to - True, which means chat_template will be applied. - adapter_name (str): the adapter name of slora for pytorch backend. - Pick one from adapters. Default to None, using the base model. 
- """ - return self.infer(prompts, gen_config, do_preprocess, adapter_name, stream_response, multiplex=True, **kwargs) - - @asynccontextmanager - async def model_inst(self, session_id: int): - """A context manager to make sure server's safe running.""" - logger.debug(f'[model_inst] session {session_id} applying for a model instance') - assert session_id not in self.id2inst - free_insts = self._get_free_insts() - inst = await free_insts.get() - inst._active = asyncio.Event() - self.id2inst[session_id] = inst - logger.debug(f'[model_inst] session {session_id} acquired an instance') - try: - yield inst - except (Exception, asyncio.CancelledError, GeneratorExit) as e: - logger.error(f'[model_inst] session {session_id} exception caught: {e}') - if self.backend == 'pytorch': - # manually end pytorch session - await inst.async_end(session_id) - finally: - logger.debug(f'[model_inst] session {session_id} releasing the instance') - self.id2inst.pop(session_id, None) - inst._active.set() - free_insts.put_nowait(inst) + def _determine_gen_config(self, session, input_ids, gen_config: GenerationConfig | None = None) -> GenerationConfig: + """Determine the generation configuration.""" + gen_config = deepcopy(gen_config) or GenerationConfig() + gen_config.convert_stop_bad_words_to_ids(self.tokenizer) + gen_config.stop_token_ids = gen_config.stop_token_ids or self.stop_words + gen_config.update_from_hf_gen_cfg(self.hf_gen_cfg, self.tokenizer.eos_token_id) + if not gen_config.do_sample: + # greedy decode + gen_config.top_k = 1 + # avoid unnecessary process + gen_config.temperature = 1.0 + gen_config.repetition_penalty = 1.0 + # set random if it is not set and sequence_start is True + elif gen_config.random_seed is None and session.step == 0: + gen_config.random_seed = random.getrandbits(64) + if gen_config.n > 1: + logger.warning(f'n({gen_config.n}) > 1 hasn\'t been supported yet. 
Fallback to 1') + gen_config.n = 1 + if gen_config.max_new_tokens is None: + gen_config.max_new_tokens = max(0, self.session_len - session.step - len(input_ids)) + return gen_config @asynccontextmanager - async def safe_run(self, inst, session_id, **kwargs): - generator = inst.async_stream_infer(session_id, **kwargs) + async def safe_run(self, handle, session, **kwargs): + generator = handle.async_stream_infer(session.session_id, **kwargs) try: yield generator except (Exception, asyncio.CancelledError, GeneratorExit) as e: # noqa - logger.error(f'[safe_run] session {session_id} exception caught: {type(e).__name__} {e}') - # TODO: remove session_id from async cancel - await inst.async_cancel(session_id) - raise e + logger.error(f'[safe_run] session {session.session_id} exception caught: {type(e).__name__} {e}') + await session.async_abort() + raise SafeRunException(f'Safe run exception for session {session.session_id}') from e finally: await generator.aclose() async def generate( self, messages, - session_id: int, - gen_config: Optional[GenerationConfig] = None, - tools: Optional[List[object]] = None, - reasoning_effort: Optional[Literal['low', 'medium', 'high']] = None, + session_id: int | Session, + gen_config: GenerationConfig | None = None, + tools: List[object] | None = None, + reasoning_effort: Literal['low', 'medium', 'high'] | None = None, stream_response: bool = True, sequence_start: bool = True, sequence_end: bool = True, # no interactive mode by default step: int = 0, do_preprocess: bool = True, - adapter_name: Optional[str] = None, + adapter_name: str | None = None, rewind_stop_tokens: bool = False, - input_ids: Optional[List] = None, - enable_thinking: Optional[bool] = None, - chat_template_kwargs: Optional[Dict] = None, - mm_processor_kwargs: Optional[Dict[str, Any]] = None, + input_ids: List | None = None, + enable_thinking: bool | None = None, + chat_template_kwargs: Dict | None = None, + mm_processor_kwargs: Dict[str, Any] | None = None, **kwargs): """Generate responses. Args: messages (str | List): chat history or prompt - session_id (int): the session id + session_id (int | Session): the session id or instance of Session gen_config (GenerationConfig | None): a instance of GenerationConfig. Default to None. stream_response (bool): whether return responses streamingly @@ -683,30 +309,13 @@ async def generate( epoch = self.epoch if (messages is not None) ^ (input_ids is None): raise ValueError('You must specify exactly one of messages or input_ids') - if session_id not in self.id2step: - self.id2step[session_id] = 0 - if step != 0: - self.id2step[session_id] = step - if gen_config is None: - gen_config = GenerationConfig() + if isinstance(session_id, Session): + session = session_id + elif isinstance(session_id, int): + session = self.session_mgr.get(session_id, step=step) else: - gen_config = deepcopy(gen_config) - gen_config.convert_stop_bad_words_to_ids(self.tokenizer) - if gen_config.stop_token_ids is None: - gen_config.stop_token_ids = self.stop_words - gen_config.update_from_hf_gen_cfg(self.hf_gen_cfg, self.tokenizer.eos_token_id) - if not gen_config.do_sample: - # greedy decode - gen_config.top_k = 1 - # avoid unnecessary process - gen_config.temperature = 1.0 - gen_config.repetition_penalty = 1.0 - # set random if it is not set and sequence_start is True - elif gen_config.random_seed is None and sequence_start: - gen_config.random_seed = random.getrandbits(64) - if gen_config.n > 1: - logger.warning(f'n({gen_config.n}) > 1 hasn\'t been supported yet. 
Fallback to 1') - gen_config.n = 1 + raise ValueError(f'Invalid session_id: {session_id}. It should be an instance of Session or an integer.') + session_id = session.session_id chat_template_kwargs = chat_template_kwargs or {} if enable_thinking is not None: logger.warning('enable_thinking is deprecated, use chat_template_kwargs["enable_thinking"] instead') @@ -717,7 +326,7 @@ async def generate( 'the value will not be overwritten by enable_thinking') if messages: prompt = messages - self.request_logger.log_prompt(session_id=session_id, prompt=prompt) + self.request_logger.log_prompt(session, prompt=prompt) prompt_input = await self.prompt_processor.get_prompt_input(prompt=prompt, do_preprocess=do_preprocess, sequence_start=sequence_start, @@ -725,59 +334,62 @@ async def generate( tools=tools, reasoning_effort=reasoning_effort, chat_template_kwargs=chat_template_kwargs, + mm_processor_kwargs=mm_processor_kwargs, **kwargs) prompt = prompt_input['prompt'] input_ids = prompt_input['input_ids'] - self.request_logger.log_inputs(session_id=session_id, + self.request_logger.log_inputs(session, prompt=prompt, prompt_token_ids=input_ids, gen_config=gen_config, adapter_name=adapter_name) - logger.info(f'session={session_id}, ' - f'history_tokens={self.id2step[session_id]}, ' - f'input_tokens={len(input_ids)}, ' - f'max_new_tokens={gen_config.max_new_tokens}, ' - f'seq_start={sequence_start}, seq_end={sequence_end}, ' - f'step={step}, prep={do_preprocess}') else: # TODO(lvhan) VLM doesn't support input_ids as an argument. # Figure out a graceful way to handle the invalid input prompt_input = dict(input_ids=input_ids) - if gen_config.max_new_tokens is None: - gen_config.max_new_tokens = max(0, self.session_len - self.id2step[session_id] - len(input_ids)) - if gen_config.max_new_tokens == 0: + max_new_tokens = max(0, self.session_len - session.step - len(input_ids)) + if max_new_tokens == 0: logger.error(f'run out of tokens. 
session={session_id}.') yield GenOut(response='', - history_token_len=self.id2step[session_id], + history_token_len=session.step, input_token_len=len(input_ids), generate_token_len=0, finish_reason='length', token_ids=[]) if sequence_end is True and sequence_start is False: - await self.end_session(session_id) + await session.async_close() return if self.backend_config.enable_prefix_caching and (gen_config.output_last_hidden_state == 'all' or gen_config.output_logits == 'all'): errmsg = ('lmdeploy does not support outputting all token\'s logits or last_hidden_state ' 'when prefix caching is ON') yield GenOut(response=errmsg, - history_token_len=self.id2step[session_id], + history_token_len=session.step, input_token_len=len(input_ids), generate_token_len=0, finish_reason='error', token_ids=[]) return + logger.info(f'session={session_id}, ' + f'history_tokens={session.step}, ' + f'input_tokens={len(input_ids)}, ' + f'max_new_tokens={gen_config.max_new_tokens}, ' + f'seq_start={sequence_start}, seq_end={sequence_end}, ' + f'step={step}, prep={do_preprocess}') def is_error(status): return status not in [ResponseType.SUCCESS, ResponseType.FINISH, ResponseType.CANCEL] + gen_config = self._determine_gen_config(session, input_ids, gen_config=gen_config) + stop_ids = [] if not gen_config.ignore_eos: stop_ids = gen_config.stop_token_ids or [] metrics_processor.increment_total_requests() - async with self.model_inst(session_id) as inst: + + async with session.request_handle() as handle: if epoch != self.epoch: logger.debug(f'[generate] session {session_id} got aborted before starting inference') # TODO(lvhan): metrics_processor.increment_failed_requests('abort') @@ -790,14 +402,14 @@ def is_error(status): token_ids=[]) return token_ids = input_ids.copy() - history_len = self.id2step[session_id] + history_len = session.step input_len = len(input_ids) output_len, gen_len = 0, 0 - state = DetokenizeState(len(input_ids)) + state = DetokenizeState(input_len) response = '' finish_reason = None - async with self.safe_run(inst, - session_id=session_id, + async with self.safe_run(handle, + session=session, **prompt_input, gen_config=gen_config, adapter_name=adapter_name, @@ -885,7 +497,7 @@ def is_error(status): f'"{finish_reason}", input_tokens ' f'{len(input_ids)}, output_tokens {gen_len}') yield GenOut(response, - self.id2step[session_id], + session.step, len(input_ids), gen_len, finish_reason, @@ -895,105 +507,33 @@ def is_error(status): last_hidden_state=last_hidden_state, routed_experts=routed_experts, cache_block_ids=outputs.cache_block_ids) - # Update a session's sequence only when it is in finished status - if outputs.status == ResponseType.FINISH: - if rewind_stop_tokens: - # rewind the step to the token before the stop token - output_len = gen_len - self.id2step[session_id] += input_len + output_len + # Note: We remove the session step update here. Let the caller(e.g., pipeline.chat) take care of it. 
else: logger.error(f'session {session_id} finished, {outputs.status}, ' 'reason "error"') yield GenOut(response=f'internal error happened, status code {outputs.status}', - history_token_len=self.id2step[session_id], + history_token_len=session.step, input_token_len=len(input_ids), generate_token_len=0, finish_reason='error', token_ids=[]) # update step if sequence_end: - self.id2step[session_id] = 0 if self.backend == 'pytorch': # manually end pytorch session - await inst.async_end(session_id) - - def _run(self, fn=None, coro=None, loop=None): - assert (fn or coro) and not (fn and coro) - loop = loop or self.internal_thread.loop - if fn: - - async def _coro(): - return fn() - - coro = _coro() - return asyncio.run_coroutine_threadsafe(coro, loop) - - def session(self, gen_config: GenerationConfig = None): - return Session(self._run(fn=lambda: next(self._session_id)).result(), engine=self, gen_config=gen_config) - - def chat(self, - prompt: str, - session=None, - gen_config: Optional[GenerationConfig] = None, - stream_response=False, - adapter_name=None, - **kwargs) -> Union[Session, Iterator]: - """Chat. - - Args: - prompt (str): prompt - session (Session): the chat session - gen_config (GenerationConfig | None): a instance of - GenerationConfig. Default to None. - do_preprocess (bool): whether pre-process the messages. Default to - True, which means chat_template will be applied. - **kwargs (dict): ad hoc parametrization of `gen_config - """ - if session is None: - session = self.session() - - # sync & init - session._prompt = prompt - session._response = None - - sequence_start = session._step == 0 - - generator = self.infer(prompt, - gen_config, - adapter_name=adapter_name, - sequence_start=sequence_start, - sequence_end=False, - session_id=session._id, - stream_response=stream_response, - multiplex=True, - step=session._step) - - def _gen(): - resp = None - try: - for out in generator: - resp = resp.extend(out) if resp else out - yield out - except: # noqa - self._run(coro=self.stop_session(session._id)).result() - raise - else: - session._response = resp - session._step += resp.generate_token_len + resp.input_token_len - session.history.append((session._prompt, resp.text)) - - if stream_response: - session.generator = _gen() - else: - # run the generator until finish - with closing(_gen()) as gen: - for _ in gen: - pass - session.generator = None - - return session - - def start_loop(self, use_async_api=False): + # note: Using session.async_abort() here results in deadlock + # because it waits for session's _active event to be set, but the event won't be set + # until the session is finished, i.e., session.request_handle() context exits. + await handle.async_end(session.session_id) + self.session_mgr.remove(session) + # if sequence_end: + # if self.backend == 'pytorch': + # # manually end pytorch session. session cannot be ended until session.request_handle() + # # context exits + # await session.async_close() + # self.session_mgr.remove(session) + + def start_loop(self, loop, use_async_api=False): """Start engine loop. When using pytorch backend with dp > 1, all dp_rank should receive at least one request before it can start @@ -1003,6 +543,7 @@ def start_loop(self, use_async_api=False): The purpose of this function is to allow users to choose whether to use the synchronous interface or the asynchronous interface for the pipeline. 
""" + self.session_mgr.attach_event_loop(loop) if hasattr(self.engine, 'start_loop'): if use_async_api: return self.engine.start_loop() @@ -1013,7 +554,7 @@ def _start_loop(fut): res = self.engine.start_loop() fut.set_result(res) - self.internal_thread.loop.call_soon_threadsafe(_start_loop, fut) + loop.call_soon_threadsafe(_start_loop, fut) return fut.result() else: return True @@ -1036,3 +577,57 @@ def p2p_drop_connect(self, drop_conn_request: List[DistServeDropConnectionReques return self.engine.p2p_drop_connect(drop_conn_request) """ DistServe Async Engine API End """ + + async def async_get_reward_score(self, input_ids: List) -> List[float]: + """Async version of get_reward_score.""" + supported_reward_models = ['InternLM2ForRewardModel', 'Qwen2ForRewardModel'] + if self.arch not in supported_reward_models: + raise ValueError(f'{self.arch} is not in reward model list: {supported_reward_models}') + assert isinstance(input_ids, List) + assert all(isinstance(x, int) for x in input_ids) or all(isinstance(x, List) for x in input_ids) + # Make input_ids a list of token_id list + input_ids = [input_ids] if isinstance(input_ids[0], int) else input_ids + + logits = await self.async_get_logits(input_ids=input_ids) + + logits = [x.squeeze() for x in logits] + scores = [x[-1].cpu().item() for x in logits] + return scores + + async def async_get_logits(self, + input_ids, + steps: List[int] | None = None, + sequence_start: bool = True, + sequence_end: bool = True) -> List[torch.Tensor]: + assert input_ids and all(isinstance(_, List) for _ in input_ids) + assert steps is None or (len(steps) == len(input_ids)) + + logits = [None] * len(input_ids) + + async def _proc(session, i): + async with session.request_handle() as handle: + input_len = len(input_ids[i]) + # TODO(lvhan): Fix the ugly code later on + max_new_tokens = 1 if self.backend == 'turbomind' else 0 + # The reason to set `top_k=1` is that pt engine crashes at top_k sampling stage + # when perform inference on a reward model. + gen_config = GenerationConfig(max_new_tokens=max_new_tokens, output_logits='all', top_k=1) + async with self.safe_run(handle, + session=session, + input_ids=input_ids[i], + gen_config=gen_config, + stream_output=False, + sequence_start=sequence_start, + sequence_end=sequence_end, + step=steps[i] if steps else 0) as gen: + async for outputs in gen: + pass + logits[i] = outputs.logits[:input_len, :] + + sessions = [self.session_mgr.get() for _ in range(len(input_ids))] + tasks = [_proc(session, i) for i, session in enumerate(sessions)] + await asyncio.gather(*tasks) + if sequence_end and self.backend == 'pytorch': + for session in sessions: + await session.async_close() + return logits diff --git a/lmdeploy/serve/core/exceptions.py b/lmdeploy/serve/core/exceptions.py new file mode 100644 index 0000000000..3391f3c14e --- /dev/null +++ b/lmdeploy/serve/core/exceptions.py @@ -0,0 +1,10 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""Exceptions for the serve module.""" + + +class SafeRunException(Exception): + """Exception raised by safe_run to avoid upper layer handling the original + exception again. + + This exception wraps the original exception that occurred during safe_run execution. + """ diff --git a/lmdeploy/serve/core/vl_async_engine.py b/lmdeploy/serve/core/vl_async_engine.py new file mode 100644 index 0000000000..44fd97dac6 --- /dev/null +++ b/lmdeploy/serve/core/vl_async_engine.py @@ -0,0 +1,45 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from typing import Literal + +from lmdeploy.messages import PytorchEngineConfig, TurbomindEngineConfig, VisionConfig +from lmdeploy.utils import get_logger + +from .async_engine import AsyncEngine + +logger = get_logger('lmdeploy') + + +class VLAsyncEngine(AsyncEngine): + """Visual Language Async inference engine.""" + + def __init__(self, + model_path: str, + backend: Literal['turbomind', 'pytorch'] = 'turbomind', + backend_config: TurbomindEngineConfig | PytorchEngineConfig | None = None, + vision_config: VisionConfig | None = None, + **kwargs) -> None: + from lmdeploy.serve.processors import MultimodalProcessor + from lmdeploy.utils import try_import_deeplink + from lmdeploy.vl.engine import ImageEncoder + + if backend == 'pytorch': + try_import_deeplink(backend_config.device_type) + if backend_config and backend_config.enable_prefix_caching: + backend_config.enable_prefix_caching = False + logger.warning('Prefix caching is disabled since LMDeploy hasn\'t support in on VL models yet') + self.vl_encoder = ImageEncoder(model_path, backend, vision_config, backend_config=backend_config) + super().__init__(model_path, backend=backend, backend_config=backend_config, **kwargs) + # Update prompt_processor to support multimodal processing + self.prompt_processor = MultimodalProcessor(self.tokenizer, + self.chat_template, + vl_encoder=self.vl_encoder, + backend=backend) + if self.model_name == 'base': + raise RuntimeError( + 'please specify chat template as guided in https://lmdeploy.readthedocs.io/en/latest/inference/vl_pipeline.html#set-chat-template' # noqa: E501 + ) + + def close(self): + if hasattr(self, 'vl_encoder'): + del self.vl_encoder + super().close() diff --git a/lmdeploy/serve/managers/__init__.py b/lmdeploy/serve/managers/__init__.py new file mode 100644 index 0000000000..17ccf480e2 --- /dev/null +++ b/lmdeploy/serve/managers/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .session_manager import Session, SessionManager + +__all__ = ['Session', 'SessionManager'] diff --git a/lmdeploy/serve/managers/session_manager.py b/lmdeploy/serve/managers/session_manager.py new file mode 100644 index 0000000000..2a4ec57f07 --- /dev/null +++ b/lmdeploy/serve/managers/session_manager.py @@ -0,0 +1,235 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from __future__ import annotations + +import asyncio +import itertools +import weakref +from contextlib import asynccontextmanager +from typing import Any, List, Tuple + +from lmdeploy.messages import GenerationConfig, Response +from lmdeploy.serve.core.exceptions import SafeRunException +from lmdeploy.utils import get_logger + +logger = get_logger('lmdeploy') + + +class Session: + """Session for the engine.""" + + def __init__(self, session_id: int, session_mgr: SessionManager, **kwargs): + self.session_id = session_id + self.prompt: Any = None + self.response: Response | None = None + self.history: List[Tuple[Any, str]] = [] + self.gen_config: GenerationConfig | None = None + self.step: int = 0 + # event to wait for the session to be active + self._active: asyncio.Event | None = None + self._handle = None # inference instance + self._session_mgr: SessionManager = weakref.ref(session_mgr) + self.update(**kwargs) + + def update(self, **kwargs): + """Update the session.""" + self.prompt = kwargs.get('prompt', self.prompt) + self.gen_config = kwargs.get('gen_config', self.gen_config) + self.step = kwargs.get('step', self.step) + + def __repr__(self) -> str: + """Return a string representation of the Session object.""" + return (f'Session(session_id={self.session_id}, ' + f'step={self.step}, history_len={len(self.history)}, ' + f'has_response={self.response is not None}, ' + f'has_gen_config={self.gen_config is not None})') + + def __str__(self) -> str: + """Return a human-readable string representation of the Session.""" + res = f'Session(id={self.session_id}, step={self.step})' + if self.history: + res += '\nHistory:\n' + for user, assistant in self.history: + if isinstance(user, list): + user = str(user) + res += f'USER: \n{user}\nASSISTANT: \n{assistant}\n' + return res + + def reset(self): + """Reset the session to initial state. + + This method resets all session data (prompt, response, history, etc.) but keeps the session_id. 
+ """ + self.prompt = None + self.response = None + self.history = [] + self.gen_config = None + self.step = 0 + self._active = None + self._handle = None + self._session_mgr = None + logger.debug(f'Session {self.session_id} has been reset.') + + @asynccontextmanager + async def request_handle(self): + if self._handle is not None: + raise RuntimeError(f'Session {self.session_id} already has an inference instance.') + logger.debug(f'[acquire_request_handle] session {self.session_id} acquiring an instance') + + hnd_pool = self._session_mgr().request_handle_pool + self._handle = await hnd_pool.get() + self._active = asyncio.Event() + logger.debug(f'[acquire_request_handle] session {self.session_id} acquired an instance') + try: + yield self._handle + except SafeRunException: + await self._handle.async_end(self.session_id) + except Exception as e: + logger.error(f'Session {self.session_id} failed to acquire an inference instance: {e}') + raise e + finally: + logger.debug(f'[acquire_request_handle] session {self.session_id} releasing the instance') + # Return inference instance if it was acquired + if self._handle is not None: + hnd_pool.put(self._handle) + self._handle = None + # MUST set the signal after releasing the instance to avoid race condition + # refer to async_end method + self._active.set() + + async def async_abort(self): + """Abort the session.""" + logger.debug(f'Aborting session {self.session_id}') + if self._handle is not None: + await self._handle.async_cancel(self.session_id) + # DO NOT reset the session here because it might be used by other components. + # Leave the cleanup to the caller. + + async def async_close(self): + """End the session.""" + logger.debug(f'Ending session {self.session_id}') + if self._handle is not None: + await self._active.wait() + async with self.request_handle() as handle: + try: + await handle.async_end(self.session_id) + except (Exception, asyncio.CancelledError, GeneratorExit) as e: + logger.error(f'[async_end] exception caught: {e}') + self.reset() + + def abort(self): + """Abort the session in sync mode.""" + self._run(self.async_abort()) + + def close(self): + """End the session in sync mode.""" + self._run(self.async_close()) + + def _run(self, coro): + return asyncio.run_coroutine_threadsafe(coro, self._session_mgr().loop) + + +class RequestHandlePool: + """Manages a pool of request handles for concurrent request processing. + + This class maintains a fixed-size pool of request handles that can be reused + across multiple inference requests. It implements a lazy-initialized queue-based + pool pattern to efficiently manage handle lifecycle and enable concurrent + request handling. + + Each session or request should acquire a handle from the pool before inference and + return it after completion. The manager supports: + - Pool-based handle allocation and deallocation + - Lazy initialization of the async queue (required for asyncio.Queue) + - Handle rebuilding after engine wakeup (e.g., turbomind backend) + - Complete pool cleanup + + Args: + engine (AsyncEngine): The async inference engine that creates handles. + size (int): The size of the handle pool, typically set to max_batch_size. + + Note: + The pool queue is lazily initialized on first access via `get()` method, + as `asyncio.Queue` must be created within an async context. 
+ """ + + def __init__(self, engine, size: int): + self.size = size + self.handles = [engine.create_instance() for _ in range(size)] + # `asyncio.Queue` must be created in an async context, refer to `get` method + self.pool: asyncio.Queue = None + + async def get(self): + """Get a handle from pool.""" + # Lazy initialization: create pool on first use + if self.pool is None: + self.pool = asyncio.Queue() + for inst in self.handles: + self.pool.put_nowait(inst) + + return await self.pool.get() + + def put(self, handle): + """Put a handle back to the pool.""" + if handle is not None and self.pool is not None: + self.pool.put_nowait(handle) + + def clear(self): + """Clear all handles.""" + self.handles = [] + self.pool = None + + +class SessionManager: + """Session manager.""" + + def __init__(self): + """Initialize the session manager.""" + + self.sessions = {} + self.session_id_generator = itertools.count(1) + self.request_handle_pool = None + self.loop = None + + def get(self, session_id: int | None = None, **kwargs) -> Session: + """Create a new session.""" + session_id = session_id or next(self.session_id_generator) + if session_id in self.sessions: + logger.debug(f'[SessionManager] session {session_id} already exists. Updating...') + session = self.sessions[session_id] + session.update(**kwargs) + return session + else: + logger.info(f'[SessionManager] session {session_id} not found. Creating...') + session = Session(session_id, self, **kwargs) + self.sessions[session_id] = session + return session + + async def async_abort_all(self): + """Abort all sessions.""" + tasks = [] + for session in list(self.sessions.values()): + tasks.append(session.async_abort()) + await asyncio.gather(*tasks, return_exceptions=True) + # "abort all" is designed for async RL. The aborted sessions will be no longer used, + # so we reset and clear the sessions here. 
+ for session in list(self.sessions.values()): + session.reset() + self.sessions.clear() + + def has(self, session_id): + return session_id in self.sessions + + def remove(self, session: Session): + self.sessions.pop(session.session_id) + + def clear(self): + self.sessions.clear() + # reset the session id generator + self.session_id_generator = itertools.count(1) + + def attach_event_loop(self, loop): + self.loop = loop + + def build_request_handle_pool(self, engine, size): + """Build the request handle's pool.""" + self.request_handle_pool = RequestHandlePool(engine, size) diff --git a/lmdeploy/serve/openai/api_server.py b/lmdeploy/serve/openai/api_server.py index 3716a038de..8f7493909e 100644 --- a/lmdeploy/serve/openai/api_server.py +++ b/lmdeploy/serve/openai/api_server.py @@ -30,7 +30,7 @@ from lmdeploy.pytorch.disagg.conn.protocol import (DistServeCacheFreeRequest, DistServeConnectionRequest, DistServeDropConnectionRequest, DistServeInitRequest, MigrationRequest) -from lmdeploy.serve.async_engine import AsyncEngine +from lmdeploy.serve.core import AsyncEngine from lmdeploy.serve.openai.harmony_utils import GptOssChatParser from lmdeploy.serve.openai.protocol import ChatCompletionResponse # noqa: E501 from lmdeploy.serve.openai.protocol import (AbortRequest, ChatCompletionRequest, ChatCompletionResponseChoice, @@ -55,7 +55,6 @@ class VariableInterface: """A IO interface maintaining variables.""" async_engine: AsyncEngine = None - session_id: int = 0 api_keys: Optional[List[str]] = None request_hosts = [] # following are for registering to proxy server @@ -68,9 +67,26 @@ class VariableInterface: allow_terminate_by_client: bool = False enable_abort_handling: bool = False + @staticmethod + def get_session(session_id: int) -> int: + session_mgr = VariableInterface.get_session_manager() + if session_id == -1: + return session_mgr.get() + else: + return session_mgr.get(session_id) + + @staticmethod + def get_session_manager(): + return VariableInterface.async_engine.session_mgr + + @staticmethod + def get_engine_config(): + return VariableInterface.async_engine.backend_config + router = APIRouter() get_bearer_token = HTTPBearer(auto_error=False) +server_context = VariableInterface() async def check_api_key(auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token), ) -> str: @@ -146,12 +162,12 @@ def check_request(request) -> Optional[JSONResponse]: check_func = check_request else: # Define an async function that always returns success - def always_success(req, backend_config): + def always_success(req, server_context): return '' check_func = always_success - error_msg = check_func(request, VariableInterface.async_engine.backend_config) + error_msg = check_func(request, server_context) if error_msg: return create_error_response(HTTPStatus.BAD_REQUEST, error_msg) return None @@ -383,6 +399,11 @@ async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Reque - **presence_penalty** (replaced with repetition_penalty) - **frequency_penalty** (replaced with repetition_penalty) """ + error_check_ret = check_request(request) + if error_check_ret is not None: + return error_check_ret + session = VariableInterface.get_session(request.session_id) + json_request = await raw_request.json() migration_request = json_request.pop('migration_request', None) with_cache = json_request.pop('with_cache', False) @@ -390,20 +411,11 @@ async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Reque if migration_request: migration_request = 
MigrationRequest.model_validate(migration_request) - if request.session_id == -1: - VariableInterface.session_id += 1 - request.session_id = VariableInterface.session_id - error_check_ret = check_request(request) - if error_check_ret is not None: - return error_check_ret - if VariableInterface.async_engine.id2step.get(request.session_id, 0) != 0: - return create_error_response(HTTPStatus.BAD_REQUEST, f'The session_id {request.session_id!r} is occupied.') - model_name = request.model adapter_name = None if model_name != VariableInterface.async_engine.model_name: adapter_name = model_name # got a adapter name - request_id = str(request.session_id) + request_id = str(session.session_id) created_time = int(time.time()) gpt_oss_parser = None if VariableInterface.async_engine.arch == 'GptOssForCausalLM': @@ -486,7 +498,7 @@ async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Reque enable_thinking = chat_template_kwargs.get('enable_thinking', None) result_generator = VariableInterface.async_engine.generate( request.messages, - request.session_id, + session, gen_config=gen_config, tools=tools, reasoning_effort=request.reasoning_effort, @@ -533,7 +545,7 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: res.logprobs) # Only stream chunk `usage` in the final chunk according to OpenAI API spec if (res.finish_reason and request.stream_options and request.stream_options.include_usage): - total_tokens = sum([res.history_token_len, res.input_token_len, res.generate_token_len]) + total_tokens = sum([res.input_token_len, res.generate_token_len]) usage = UsageInfo( prompt_tokens=res.input_token_len, completion_tokens=res.generate_token_len, @@ -610,7 +622,7 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: async for res in result_generator: if await raw_request.is_disconnected(): # Abort the request if the client disconnects. 
- await VariableInterface.async_engine.stop_session(request.session_id) + await session.async_abort() return create_error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected') final_res = res text += res.response @@ -671,7 +683,7 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: cache_block_ids = cache_block_ids[0] remote_token_ids = [remote_token_ids[0][-1]] - total_tokens = sum([final_res.history_token_len, final_res.input_token_len, final_res.generate_token_len]) + total_tokens = sum([final_res.input_token_len, final_res.generate_token_len]) usage = UsageInfo( prompt_tokens=final_res.input_token_len, completion_tokens=final_res.generate_token_len, @@ -744,6 +756,10 @@ async def completions_v1(request: CompletionRequest, raw_request: Request = None - **presence_penalty** (replaced with repetition_penalty) - **frequency_penalty** (replaced with repetition_penalty) """ + error_check_ret = check_request(request) + if error_check_ret is not None: + return error_check_ret + json_request = await raw_request.json() migration_request = json_request.pop('migration_request', None) with_cache = json_request.pop('with_cache', False) @@ -751,23 +767,19 @@ async def completions_v1(request: CompletionRequest, raw_request: Request = None if migration_request: migration_request = MigrationRequest.model_validate(migration_request) - if request.session_id == -1: - VariableInterface.session_id += 1 - request.session_id = VariableInterface.session_id - error_check_ret = check_request(request) - if error_check_ret is not None: - return error_check_ret - if VariableInterface.async_engine.id2step.get(request.session_id, 0) != 0: - return create_error_response(HTTPStatus.BAD_REQUEST, f'The session_id {request.session_id!r} is occupied.') - model_name = request.model adapter_name = None if model_name != VariableInterface.async_engine.model_name: adapter_name = model_name # got a adapter name request_id = str(request.session_id) created_time = int(time.time()) + sessions = [] if isinstance(request.prompt, str): request.prompt = [request.prompt] + sessions.append(VariableInterface.get_session(request.session_id)) + elif isinstance(request.prompt, list): + for i in range(len(request.prompt)): + sessions.append(VariableInterface.get_session()) if isinstance(request.stop, str): request.stop = [request.stop] random_seed = request.seed if request.seed else None @@ -792,10 +804,10 @@ async def completions_v1(request: CompletionRequest, raw_request: Request = None preserve_cache=preserve_cache, ) generators = [] - for i in range(len(request.prompt)): + for prompt, session in zip(request.prompt, sessions): result_generator = VariableInterface.async_engine.generate( - request.prompt[i], - request.session_id + i, + prompt, + session, gen_config=gen_config, stream_response=True, # always use stream to enable batching sequence_start=True, @@ -842,8 +854,7 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: # Only stream chunk `usage` in the final chunk according to OpenAI API spec if (res.finish_reason and request.stream_options and request.stream_options.include_usage): final_res = res - total_tokens = sum( - [final_res.history_token_len, final_res.input_token_len, final_res.generate_token_len]) + total_tokens = sum([final_res.input_token_len, final_res.generate_token_len]) usage = UsageInfo( prompt_tokens=final_res.input_token_len, completion_tokens=final_res.generate_token_len, @@ -915,7 +926,7 @@ async def _inner_call(i, generator): cache_block_ids = cache_block_ids[0] 
remote_token_ids = [remote_token_ids[0][-1]] - total_tokens = sum([final_res.history_token_len, final_res.input_token_len, final_res.generate_token_len]) + total_tokens = sum([final_res.input_token_len, final_res.generate_token_len]) usage.prompt_tokens += final_res.input_token_len usage.completion_tokens += final_res.generate_token_len usage.total_tokens += total_tokens @@ -939,14 +950,10 @@ async def _inner_call(i, generator): @router.post('/generate', dependencies=[Depends(check_api_key)]) async def generate(request: GenerateReqInput, raw_request: Request = None): - if request.session_id == -1: - VariableInterface.session_id += 1 - request.session_id = VariableInterface.session_id error_check_ret = check_request(request) if error_check_ret is not None: return error_check_ret - if VariableInterface.async_engine.id2step.get(request.session_id, 0) != 0: - return create_error_response(HTTPStatus.BAD_REQUEST, f'The session_id {request.session_id!r} is occupied.') + session = VariableInterface.get_session(request.session_id) prompt = request.prompt input_ids = request.input_ids @@ -985,7 +992,7 @@ async def generate(request: GenerateReqInput, raw_request: Request = None): result_generator = VariableInterface.async_engine.generate( messages=prompt, - session_id=request.session_id, + session_id=session, input_ids=input_ids, gen_config=gen_config, stream_response=True, # always use stream to enable batching @@ -1036,7 +1043,7 @@ async def _inner_call(): async for res in result_generator: if await raw_request.is_disconnected(): # Abort the request if the client disconnects. - await VariableInterface.async_engine.stop_session(request.session_id) + await session.async_abort() return create_error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected') text += res.response or '' output_ids.extend(res.token_ids or []) @@ -1129,7 +1136,7 @@ async def pooling(request: PoolingRequest, raw_request: Request = None): else: return create_error_response(HTTPStatus.BAD_REQUEST, 'Input must be a string or a list.') - batch_scores = await async_engine._async_get_reward_score(input_ids) + batch_scores = await async_engine.async_get_reward_score(input_ids) prompt_tokens = sum(len(ids) for ids in input_ids) usage = UsageInfo(prompt_tokens=prompt_tokens, completion_tokens=0, total_tokens=prompt_tokens) @@ -1228,7 +1235,8 @@ async def abort_request(request: AbortRequest, raw_request: Request = None): if request.abort_all: await VariableInterface.async_engine.stop_all_session() else: - await VariableInterface.async_engine.stop_session(request.session_id) + session = VariableInterface.get_session(request.session_id) + await session.async_abort() return Response(status_code=200) @@ -1255,7 +1263,7 @@ def dummy_get_device_id(): @router.on_event('startup') async def startup_event(): async_engine = VariableInterface.async_engine - async_engine.start_loop(use_async_api=True) + async_engine.start_loop(asyncio.get_running_loop(), use_async_api=True) if VariableInterface.proxy_url is None: return diff --git a/lmdeploy/serve/openai/serving_chat_completion.py b/lmdeploy/serve/openai/serving_chat_completion.py index 6d2afec789..415b2afeb7 100644 --- a/lmdeploy/serve/openai/serving_chat_completion.py +++ b/lmdeploy/serve/openai/serving_chat_completion.py @@ -4,15 +4,14 @@ from .protocol import ChatCompletionRequest if TYPE_CHECKING: - from lmdeploy.messages import PytorchEngineConfig, TurbomindEngineConfig + from .api_server import VariableInterface -def check_request(request: ChatCompletionRequest, engine_config: 
'TurbomindEngineConfig | PytorchEngineConfig') -> str: - if not isinstance(request, ChatCompletionRequest): - raise TypeError(f'Invalid request type, expected ChatCompletionRequest, got {type(request)}') - - # Check logprobs settings +def check_request(request: ChatCompletionRequest, server_context: 'VariableInterface') -> str: + engine_config = server_context.get_engine_config() + session_manager = server_context.get_session_manager() try: + # Check logprobs settings logprobs_mode = engine_config.logprobs_mode logprobs = request.logprobs top_logprobs = request.top_logprobs or 0 @@ -25,6 +24,9 @@ def check_request(request: ChatCompletionRequest, engine_config: 'TurbomindEngin except AttributeError: pass + if session_manager.has(request.session_id): + return f'The session_id {request.session_id!r} is occupied.' + # check sampling settings if request.n <= 0: return f'The n {request.n!r} must be a positive int.' diff --git a/lmdeploy/serve/openai/serving_completion.py b/lmdeploy/serve/openai/serving_completion.py index 76339c8cb4..759972db36 100644 --- a/lmdeploy/serve/openai/serving_completion.py +++ b/lmdeploy/serve/openai/serving_completion.py @@ -4,15 +4,14 @@ from .protocol import CompletionRequest if TYPE_CHECKING: - from lmdeploy.messages import PytorchEngineConfig, TurbomindEngineConfig + from .api_server import VariableInterface -def check_request(request: CompletionRequest, engine_config: 'TurbomindEngineConfig | PytorchEngineConfig') -> str: - if not isinstance(request, CompletionRequest): - raise TypeError(f'Invalid request type, expected CompletionRequest, got {type(request)}') - - # Check logprobs settings +def check_request(request: CompletionRequest, server_context: 'VariableInterface') -> str: + engine_config = server_context.get_engine_config() + session_manager = server_context.get_session_manager() try: + # Check logprobs settings logprobs_mode = engine_config.logprobs_mode logprobs = request.logprobs or 0 if logprobs > 0 and logprobs_mode is None: @@ -22,6 +21,9 @@ def check_request(request: CompletionRequest, engine_config: 'TurbomindEngineCon except AttributeError: pass + if session_manager.has(request.session_id): + return f'The session_id {request.session_id!r} is occupied.' + # check sampling settings if request.n <= 0: return f'The n {request.n!r} must be a positive int.' 
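
The occupied-session check added to `check_request` above relies on the new `SessionManager` instead of the removed `id2step` bookkeeping. The following sketch is illustrative only and not part of the patch; it uses nothing beyond the `SessionManager` API introduced in `lmdeploy/serve/managers/session_manager.py` to show the behaviour those checks depend on:

```python
# Minimal sketch (not part of the patch) of the SessionManager behaviour the new
# check_request functions rely on: a session_id counts as "occupied" once
# SessionManager.get() has created it, and is freed again by remove().
from lmdeploy.serve.managers import SessionManager

mgr = SessionManager()

assert not mgr.has(7)        # unknown id -> the request would pass the check
session = mgr.get(7)         # creates (and thereby occupies) session 7
assert mgr.has(7)            # a second request with session_id=7 is rejected as occupied

auto = mgr.get()             # session_id omitted -> auto-assigned by the id generator
assert auto.session_id == 1  # counter starts at 1, per itertools.count(1)

mgr.remove(session)          # releasing the session frees the id again
assert not mgr.has(7)
```

Because `SessionManager.get()` both looks up and creates sessions, the occupancy check has to run before the handler calls `VariableInterface.get_session()`, which is why `check_request` is now invoked at the top of each endpoint in `api_server.py`.
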
diff --git a/lmdeploy/serve/openai/serving_generate.py b/lmdeploy/serve/openai/serving_generate.py index f6e4b6f85e..4615d9caea 100644 --- a/lmdeploy/serve/openai/serving_generate.py +++ b/lmdeploy/serve/openai/serving_generate.py @@ -4,15 +4,14 @@ from .protocol import GenerateReqInput if TYPE_CHECKING: - from lmdeploy.messages import PytorchEngineConfig, TurbomindEngineConfig + from .api_server import VariableInterface -def check_request(request: GenerateReqInput, engine_config: 'TurbomindEngineConfig | PytorchEngineConfig') -> str: - if not isinstance(request, GenerateReqInput): - raise TypeError(f'Invalid request type, expected GenerateReqInput, got {type(request)}') - - # Check logprobs settings +def check_request(request: GenerateReqInput, server_context: 'VariableInterface') -> str: + engine_config = server_context.get_engine_config() + session_manager = server_context.get_session_manager() try: + # Check logprobs settings logprobs_mode = engine_config.logprobs_mode return_logprob = request.return_logprob if logprobs_mode is None and return_logprob: @@ -32,6 +31,9 @@ def check_request(request: GenerateReqInput, engine_config: 'TurbomindEngineConf if request.max_tokens is not None and request.max_tokens <= 0: return f'The max_tokens {request.max_tokens!r} must be a positive integer.' + if session_manager.has(request.session_id): + return f'The session_id {request.session_id!r} is occupied.' + # check sampling settings if not (0 < request.top_p <= 1): return f'The top_p {request.top_p!r} must be in (0, 1].' diff --git a/lmdeploy/serve/processors/__init__.py b/lmdeploy/serve/processors/__init__.py new file mode 100644 index 0000000000..7097f7ca91 --- /dev/null +++ b/lmdeploy/serve/processors/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .multimodal import MultimodalProcessor + +__all__ = ['MultimodalProcessor'] diff --git a/lmdeploy/serve/multimodal_processor.py b/lmdeploy/serve/processors/multimodal.py similarity index 76% rename from lmdeploy/serve/multimodal_processor.py rename to lmdeploy/serve/processors/multimodal.py index 5aef724953..2bc9b7e80e 100644 --- a/lmdeploy/serve/multimodal_processor.py +++ b/lmdeploy/serve/processors/multimodal.py @@ -1,9 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. import asyncio -from typing import Any, Dict, List, Literal, Optional, Union +from typing import Any, Dict, List, Literal, Tuple + +import PIL -from lmdeploy import Tokenizer from lmdeploy.model import MODELS, BaseChatTemplate +from lmdeploy.tokenizer import Tokenizer from lmdeploy.utils import get_logger logger = get_logger('lmdeploy') @@ -17,7 +19,7 @@ def __init__(self, tokenizer: Tokenizer, chat_template: BaseChatTemplate, vl_encoder=None, - backend: Optional[str] = None): + backend: str | None = None): """Initialize MultimodalProcessor. 
Args: @@ -180,14 +182,14 @@ def _inner_call(i, in_messages, out_messages): return out_messages async def get_prompt_input(self, - prompt: Union[str, List[Dict]], + prompt: str | List[Dict], do_preprocess: bool, sequence_start: bool, adapter_name: str, - tools: Optional[List[object]] = None, - reasoning_effort: Optional[Literal['low', 'medium', 'high']] = None, - chat_template_kwargs: Optional[Dict] = None, - mm_processor_kwargs: Optional[Dict[str, Any]] = None, + tools: List[object] | None = None, + reasoning_effort: Literal['low', 'medium', 'high'] | None = None, + chat_template_kwargs: Dict | None = None, + mm_processor_kwargs: Dict[str, Any] | None = None, **kwargs): """Process prompt and return prompt string and input_ids. @@ -248,6 +250,83 @@ async def get_prompt_input(self, else: raise RuntimeError(f'unsupported prompt type: {type(prompt)}') + @staticmethod + def format_prompts(prompts: Any) -> List[Dict]: + """Format prompts.""" + if not isinstance(prompts, list): + prompts = [prompts] + # str or batch of str + if all(isinstance(prompt, str) for prompt in prompts): + return prompts + if (MultimodalProcessor._is_openai_message(prompts) + or all(MultimodalProcessor._is_openai_message(prompt) for prompt in prompts)): + return prompts + if all(MultimodalProcessor._is_str_images_pair(prompt) for prompt in prompts): + # batch of (prompt, image or [images]) or (image or [images], prompt) -> + # [[openai_gpt4v_message], [openai_gpt4v_message], ...] + return [[MultimodalProcessor._re_format_prompt_images_pair(prompt)] for prompt in prompts] + raise ValueError(f'Unsupported prompts: {prompts}. Only support str, openai message format, ' + 'or (prompt, image or [images]) or (image or [images], prompt) pair.') + + @staticmethod + def _is_openai_message(message) -> bool: + """Check if the message conforms to openai message format.""" + return isinstance(message, list) and all(isinstance(msg, dict) for msg in message) + + @staticmethod + def _is_str_images_pair(message) -> bool: + """Check if the message is a (prompt, image or [images]) or (image or + [images], prompt) pair.""" + if not (isinstance(message, tuple) and len(message) == 2): + return False + _1, _2 = message + if MultimodalProcessor._is_image(_1) or MultimodalProcessor._is_image_list(_1): + _1, _2 = _2, _1 + return isinstance(_1, str) and (MultimodalProcessor._is_image(_2) or MultimodalProcessor._is_image_list(_2)) + + @staticmethod + def _is_image(obj) -> bool: + # image or image url or base64-encoded image data + return (isinstance(obj, PIL.Image.Image) + or isinstance(obj, str) and (obj.startswith('http') or obj.startswith('data:image'))) + + @staticmethod + def _is_image_list(obj) -> bool: + return isinstance(obj, list) and all(MultimodalProcessor._is_image(img) for img in obj) + + @staticmethod + def _re_format_prompt_images_pair(prompt: Tuple) -> Dict: + """Reformat the prompt to openai message format.""" + from lmdeploy.vl.utils import load_image + + messages = {'role': 'user', 'content': []} + prompt, images = prompt + prompt_first = True + if MultimodalProcessor._is_image(prompt) or MultimodalProcessor._is_image_list(prompt): + prompt, images = images, prompt + prompt_first = False + image_contents = [] + images = images if isinstance(images, list) else [images] + for image in images: + # 'image_url': means url or local path to image. + # 'image_data': means PIL.Image.Image object. 
+ if isinstance(image, str): + image = load_image(image) + item = {'type': 'image_data', 'image_data': {'data': image}} + elif isinstance(image, PIL.Image.Image): + item = {'type': 'image_data', 'image_data': {'data': image}} + else: + raise ValueError('image should be a str(url/path) or PIL.Image.Image') + image_contents.append(item) + + if prompt_first: + messages['content'].append({'type': 'text', 'text': prompt}) + messages['content'].extend(image_contents) + else: + messages['content'].extend(image_contents) + messages['content'].append({'type': 'text', 'text': prompt}) + return messages + def _has_multimodal_input(self, messages: List[Dict]) -> bool: """Check if messages contain multimodal input (images).""" return any( @@ -255,13 +334,13 @@ def _has_multimodal_input(self, messages: List[Dict]) -> bool: item.get('type') in ['image_url', 'image_data'] for item in message['content']) for message in messages) async def _get_text_prompt_input(self, - prompt: Union[str, List[Dict]], + prompt: str | List[Dict], do_preprocess: bool, sequence_start: bool, adapter_name: str, - tools: Optional[List[object]] = None, - reasoning_effort: Optional[Literal['low', 'medium', 'high']] = None, - chat_template_kwargs: Optional[Dict] = None, + tools: List[object] | None = None, + reasoning_effort: Literal['low', 'medium', 'high'] | None = None, + chat_template_kwargs: Dict | None = None, **kwargs): """Process text-only prompt and return prompt string and input_ids.""" # Change multimodal data to openai text messages @@ -292,9 +371,9 @@ async def _get_multimodal_prompt_input(self, do_preprocess: bool, sequence_start: bool, adapter_name: str, - tools: Optional[List[object]] = None, - chat_template_kwargs: Optional[Dict] = None, - mm_processor_kwargs: Optional[Dict[str, Any]] = None, + tools: List[object] | None = None, + chat_template_kwargs: Dict | None = None, + mm_processor_kwargs: Dict[str, Any] | None = None, **kwargs): """Process multimodal prompt and return processed data for inference engines.""" diff --git a/lmdeploy/serve/utils.py b/lmdeploy/serve/utils.py deleted file mode 100644 index a7c2d49e8f..0000000000 --- a/lmdeploy/serve/utils.py +++ /dev/null @@ -1,238 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import asyncio -from typing import Dict, List, Tuple, Union - -import numpy as np -import torch - -from lmdeploy.messages import GenerationConfig -from lmdeploy.utils import get_logger - -logger = get_logger('lmdeploy') - -InputIdsType = List[int] -InputEmbsType = Union[None, List[Union[torch.Tensor, np.ndarray]]] -InputEmbRngsType = Union[None, List[Tuple[int, int]]] -PromptType = Union[str, List[Dict]] - - -class LogitsMixin: - """Helper class to get logits, reward score and calculate ppl.""" - - def get_reward_score(self, input_ids: List) -> List[float]: - """ - Args: - input_ids(List): a list of token_id or a list of token_id list or a tensor containing - token_ids - Return: - reward score in a list. If the input_ids is a list of token_id, the return value - is still a list with length 1. 
- """ - supported_reward_models = ['InternLM2ForRewardModel', 'Qwen2ForRewardModel'] - if self.arch not in supported_reward_models: - raise ValueError(f'{self.arch} is not in reward model list: {supported_reward_models}') - assert isinstance(input_ids, List) - assert all(isinstance(x, int) for x in input_ids) or all(isinstance(x, List) for x in input_ids) - # Make input_ids a list of token_id list - input_ids = [input_ids] if isinstance(input_ids[0], int) else input_ids - logits = self._run(coro=self._async_get_logits(input_ids=input_ids)).result() - logits = [x.squeeze() for x in logits] - scores = [x[-1].cpu().item() for x in logits] - return scores - - async def _async_get_reward_score(self, input_ids: List) -> List[float]: - """Async version of get_reward_score.""" - supported_reward_models = ['InternLM2ForRewardModel', 'Qwen2ForRewardModel'] - if self.arch not in supported_reward_models: - raise ValueError(f'{self.arch} is not in reward model list: {supported_reward_models}') - assert isinstance(input_ids, List) - assert all(isinstance(x, int) for x in input_ids) or all(isinstance(x, List) for x in input_ids) - # Make input_ids a list of token_id list - input_ids = [input_ids] if isinstance(input_ids[0], int) else input_ids - - logits = await self._async_get_logits(input_ids=input_ids) - - logits = [x.squeeze() for x in logits] - scores = [x[-1].cpu().item() for x in logits] - return scores - - async def _async_get_logits(self, - input_ids, - steps: List[int] = None, - sequence_start: bool = True, - sequence_end: bool = True) -> List[torch.Tensor]: - assert input_ids and all(isinstance(_, List) for _ in input_ids) - assert steps is None or (len(steps) == len(input_ids)) - - logits = [None] * len(input_ids) - - async def _proc(i): - async with self.model_inst(session_id=i) as inst: - input_len = len(input_ids[i]) - # TODO(lvhan): Fix the ugly code later on - max_new_tokens = 1 if self.backend == 'turbomind' else 0 - # The reason to set `top_k=1` is that pt engine crashes at top_k sampling stage - # when perform inference on a reward model. - gen_config = GenerationConfig(max_new_tokens=max_new_tokens, output_logits='all', top_k=1) - async with self.safe_run(inst, - session_id=i, - input_ids=input_ids[i], - gen_config=gen_config, - stream_output=False, - sequence_start=sequence_start, - sequence_end=sequence_end, - step=steps[i] if steps else 0) as gen: - async for outputs in gen: - pass - logits[i] = outputs.logits[:input_len, :] - - session_ids = list(range(len(input_ids))) - tasks = [_proc(i) for i in range(len(input_ids))] - await asyncio.gather(*tasks) - if sequence_end and self.backend == 'pytorch': - for session_id in session_ids: - await self.end_session(session_id) - return logits - - def get_ppl(self, input_ids: Union[List[int], List[List[int]]]) -> List[float]: - """Get perplexity scores given a list of input tokens that have to be - of the same length. - - Args: - input_ids (Union[List[int], List[List[int]]]): the batch of - input token ids - - Returns: - List[float]: A list of perplexity scores. 
- """ - assert isinstance(input_ids, List) - if isinstance(input_ids[0], int): - input_ids = [input_ids] - assert all(len(_) > 1 for _ in input_ids) - - # TODO: a better way to determine `max_input_len`, at most allocate - # 2G mem for logits with shape [bs, max_input_len, vocab_size] - vocab_size = self.hf_cfg.vocab_size - max_input_len = 2 * 1024**3 // (vocab_size * 4) - sizes = [len(_) for _ in input_ids] - result = [] - sorted_index_values = sorted(list(enumerate(sizes)), key=lambda x: x[1], reverse=True) - sizes = [value for index, value in sorted_index_values] - indices = [index for index, value in sorted_index_values] - logger.info(f'sorted sizes: {sizes}') - logger.info(f'sorted indices: {indices}') - for (start, end) in self._batch_iterator(sizes, max_input_len): - logger.info(f'start: {start}, end: {end}') - if start == end: - _input_ids = input_ids[indices[start]] - res = self._get_long_text_ppl(input_ids=_input_ids, max_input_len=max_input_len) - result.append(res) - else: - _input_ids = [input_ids[indices[i]] for i in range(start, end)] - res = self._get_ppl( - input_ids=_input_ids, - max_input_len=max_input_len, - ) - result.extend(res) - output = list(range(len(result))) - for index, sorted_index in enumerate(indices): - output[sorted_index] = result[index] - return output - - def _batch_iterator(self, sizes, max_value): - """Return an iterator that calculates intervals (start, end) of a - descend-order list, in which the sum of values in the range is the - maximum number not less than max_value. By "the sum of values", - - here it means $$len(sizes[start:end]) * sizes[start]$$ - """ - i = 0 - while i < len(sizes): - current_sum = 0 - start_index = i - - while i < len(sizes) and current_sum + sizes[start_index] <= max_value: - current_sum += sizes[start_index] - i += 1 - - yield (start_index, i) - if i > start_index: - continue - else: - i += 1 - - def _get_long_text_ppl(self, input_ids, max_input_len): - assert all(isinstance(_, int) for _ in input_ids) - seq_len = len(input_ids) - assert seq_len > max_input_len - logger.info(f'get long text ppl: seq_len {seq_len}') - - losses = [] - target_counts = [] - for i in range(0, seq_len, max_input_len): - token_ids = input_ids[i:i + max_input_len] - step = [i] - # shift token_ids by 1 to the left - target_ids = input_ids[i + 1:i + 1 + max_input_len] - loss = self._get_ppl(input_ids=[token_ids], - max_input_len=len(token_ids), - target_ids=[target_ids], - steps=step, - sequence_start=(i == 0), - sequence_end=False) - losses.extend(loss) - target_counts.append(len(target_ids)) - losses = [loss * target_count for loss, target_count in zip(losses, target_counts)] - loss_sum = sum(losses) - target_count = sum(target_counts) - return loss_sum / target_count - - def _get_ppl(self, - input_ids, - max_input_len, - target_ids=None, - steps=None, - sequence_start: bool = True, - sequence_end: bool = True): - assert (isinstance(input_ids, List) and all(isinstance(_, List) for _ in input_ids)) - assert steps is None or len(steps) == len(input_ids) - assert target_ids is None or len(target_ids) == len(input_ids) - - lens = [len(_) for _ in input_ids] - total_len = sum(lens) - assert sum(lens) <= max_input_len - - logger.info(f'get_ppl: bs: {len(input_ids)}, lens: {lens}, ' - f'total_len: {total_len}, steps: {steps}') - torch.cuda.empty_cache() - - logits = self._run(coro=self._async_get_logits( - input_ids=input_ids, steps=steps, sequence_start=sequence_start, sequence_end=sequence_end)).result() - padding_token_id = -100 - if target_ids is 
None: - target_ids = [x[1:] + [padding_token_id] for x in input_ids] - else: - target_ids = [ - target_ids[i] + [padding_token_id] if len(target_ids[i]) < len(input_ids[i]) else target_ids[i] - for i in range(len(input_ids)) - ] - target_ids = [torch.Tensor(torch.LongTensor(_target_ids)) for _target_ids in target_ids] - - result = [] - for _logits, _target_ids in zip(logits, target_ids): - _logits = _logits.float() - vocab_size = _logits.shape[-1] - _target_ids = _target_ids.to(_logits.device) - target_mask = _target_ids != padding_token_id - # compute cross entropy loss - flat_logits = _logits.contiguous().view(-1, vocab_size) - flat_target_ids = _target_ids.contiguous().view(-1) - flat_loss_matrix = torch.nn.functional.cross_entropy(flat_logits, - flat_target_ids, - reduction='none', - ignore_index=padding_token_id) - loss = flat_loss_matrix.sum() - target_count = target_mask.sum() - result.append(loss.item() / target_count.item()) - logger.info(f'ppl result: {result}') - return result diff --git a/lmdeploy/serve/vl_async_engine.py b/lmdeploy/serve/vl_async_engine.py deleted file mode 100644 index 7009385f36..0000000000 --- a/lmdeploy/serve/vl_async_engine.py +++ /dev/null @@ -1,126 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from typing import Dict, List, Literal, Optional, Tuple, Union - -import PIL - -from lmdeploy.messages import PytorchEngineConfig, TurbomindEngineConfig, VisionConfig -from lmdeploy.serve.async_engine import AsyncEngine -from lmdeploy.serve.multimodal_processor import MultimodalProcessor -from lmdeploy.utils import get_logger, try_import_deeplink -from lmdeploy.vl.engine import ImageEncoder -from lmdeploy.vl.utils import load_image - -logger = get_logger('lmdeploy') - -VLPromptType = Union[str, Tuple[str, PIL.Image.Image], Tuple[str, List[PIL.Image.Image]]] - - -class VLAsyncEngine(AsyncEngine): - """Visual Language Async inference engine.""" - - def __init__(self, - model_path: str, - backend: Literal['turbomind', 'pytorch'] = 'turbomind', - backend_config: Optional[Union[TurbomindEngineConfig, PytorchEngineConfig]] = None, - vision_config: Optional[VisionConfig] = None, - **kwargs) -> None: - if backend == 'pytorch': - try_import_deeplink(backend_config.device_type) - if backend_config and backend_config.enable_prefix_caching: - backend_config.enable_prefix_caching = False - logger.warning('Prefix caching is disabled since LMDeploy hasn\'t support in on VL models yet') - self.vl_encoder = ImageEncoder(model_path, backend, vision_config, backend_config=backend_config) - super().__init__(model_path, backend=backend, backend_config=backend_config, **kwargs) - # Update prompt_processor to support multimodal processing - self.prompt_processor = MultimodalProcessor(self.tokenizer, - self.chat_template, - vl_encoder=self.vl_encoder, - backend=backend) - if self.model_name == 'base': - raise RuntimeError( - 'please specify chat template as guided in https://lmdeploy.readthedocs.io/en/latest/inference/vl_pipeline.html#set-chat-template' # noqa: E501 - ) - - @classmethod - def _convert_prompts(cls, prompts: Union[VLPromptType, List[Dict], List[VLPromptType], List[List[Dict]]]): - """Convert prompts to openai GPT4V format.""" - if isinstance(prompts, str) or isinstance(prompts, tuple): - _prompts = cls.prompt_to_messages(prompts) - elif isinstance(prompts[0], tuple) or isinstance(prompts[0], str): - _prompts = [cls.prompt_to_messages(x) for x in prompts] - else: - _prompts = prompts - return _prompts - - @classmethod - async def async_convert_to_pil_images(cls, 
messages: List[Dict]) -> List[Dict]: - """Convert messages to PIL images. - - Delegates to MultimodalProcessor. - """ - return await MultimodalProcessor.async_convert_to_pil_images(messages) - - def batch_infer(self, prompts: Union[VLPromptType, List[Dict], List[VLPromptType], List[List[Dict]]], *args, - **kwargs): - """Inference a batch of prompts.""" - prompts = self._convert_prompts(prompts) - return super().batch_infer(prompts, *args, **kwargs) - - def stream_infer(self, prompts: Union[VLPromptType, List[Dict], List[VLPromptType], List[List[Dict]]], *args, - **kwargs): - """Inference a batch of prompts with stream mode.""" - prompts = self._convert_prompts(prompts) - return super().stream_infer(prompts, *args, **kwargs) - - def __call__(self, prompts: Union[VLPromptType, List[Dict], List[VLPromptType], List[List[Dict]]], *args, **kwargs): - """Inference a batch of prompts.""" - return super().__call__(prompts, *args, **kwargs) - - def close(self): - if hasattr(self, 'vl_encoder'): - del self.vl_encoder - super().close() - - def chat(self, prompts: VLPromptType, *args, **kwargs): - """chat.""" - _prompts = self._convert_prompts(prompts) - sess = super().chat(_prompts, *args, **kwargs) - - # recover prompts & history - sess._prompt = prompts - if sess.history: - last_round = sess.history[-1] - sess.history[-1] = (prompts, last_round[-1]) - return sess - - @classmethod - def prompt_to_messages(cls, prompt: VLPromptType): - """Convert prompt to GTP4V format.""" - messages = { - 'role': 'user', - 'content': [{ - 'type': 'text', - 'text': '', - }] - } - if isinstance(prompt, str): - messages['content'][0]['text'] = prompt - else: - prompt, images = prompt - if not isinstance(images, list): - images = [images] - messages['content'][0]['text'] = prompt - for image in images: - # 'image_url': means url or local path to image. - # 'image_data': means PIL.Image.Image object. - if isinstance(image, str): - image = load_image(image) - item = {'type': 'image_data', 'image_data': {'data': image}} - elif isinstance(image, PIL.Image.Image): - item = {'type': 'image_data', 'image_data': {'data': image}} - else: - raise ValueError('image should be a str(url/path) or PIL.Image.Image') - - messages['content'].append(item) - - return [messages] diff --git a/requirements/readthedocs.txt b/requirements/readthedocs.txt index 7f975b99cb..db477834b6 100644 --- a/requirements/readthedocs.txt +++ b/requirements/readthedocs.txt @@ -5,6 +5,7 @@ mmengine-lite openai_harmony partial_json_parser pillow +pybase64 pydantic pyyaml shortuuid diff --git a/tests/test_lmdeploy/test_content_merge.py b/tests/test_lmdeploy/test_content_merge.py index bb1f26e593..8574505c9b 100644 --- a/tests/test_lmdeploy/test_content_merge.py +++ b/tests/test_lmdeploy/test_content_merge.py @@ -1,6 +1,6 @@ import pytest -from lmdeploy.serve.multimodal_processor import MultimodalProcessor +from lmdeploy.serve.processors import MultimodalProcessor class TestMergeMessageContent:
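
The relocated tests above exercise `MultimodalProcessor` from its new `lmdeploy.serve.processors` package. With `VLAsyncEngine` dropping its `prompt_to_messages`/`_convert_prompts` helpers, the new `MultimodalProcessor.format_prompts` static method covers the same prompt shapes. The sketch below is illustrative only and not part of the patch (the image is a synthetic placeholder); it shows the inputs `format_prompts` accepts and the OpenAI-style message a `(text, image)` pair is converted into:

```python
# Minimal sketch (not part of the patch) of MultimodalProcessor.format_prompts.
# The PIL image is a synthetic placeholder created in memory.
import PIL.Image

from lmdeploy.serve.processors import MultimodalProcessor

image = PIL.Image.new('RGB', (32, 32))

# Plain strings pass through untouched.
assert MultimodalProcessor.format_prompts(['Hi', 'Describe the weather']) == ['Hi', 'Describe the weather']

# A (prompt, image) pair is rewritten into one OpenAI GPT-4V style message.
[[message]] = MultimodalProcessor.format_prompts([('Describe this picture', image)])
assert message['role'] == 'user'
assert message['content'][0] == {'type': 'text', 'text': 'Describe this picture'}
assert message['content'][1]['type'] == 'image_data'
assert message['content'][1]['image_data']['data'] is image
```

String prompts and already-formatted OpenAI messages are returned unchanged, so text-only callers are unaffected by the normalisation step.
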