11# Copyright (c) OpenMMLab. All rights reserved.
22
33import asyncio
4- import atexit
54import concurrent .futures
65import dataclasses
76import random
87from contextlib import asynccontextmanager
98from copy import deepcopy
10- from functools import partial
11- from queue import Queue
12- from threading import Thread
13- from typing import Any , Dict , Iterator , List , Literal
9+ from typing import Any , Dict , List , Literal
1410
1511from lmdeploy .archs import get_model_arch
1612from lmdeploy .logger import RequestLogger
@@ -65,65 +61,6 @@ def to_response(self, index: int = 0) -> Response:
6561 index = index )
6662
6763
68- class _EventLoopThread :
69-
70- def __init__ (self , daemon = False ):
71- fut = concurrent .futures .Future ()
72- self .thread = Thread (target = partial (self ._thread_entry , fut ), daemon = daemon )
73- self .thread .start ()
74- self .loop : asyncio .AbstractEventLoop = fut .result ()
75- self .closed = False
76- if daemon :
77- atexit .register (self .close )
78-
79- def _thread_entry (self , fut ):
80- loop = asyncio .new_event_loop ()
81- asyncio .set_event_loop (loop )
82- fut .set_result (loop )
83- try :
84- loop .run_forever ()
85- except BaseException as e :
86- logger .error (f'[internal_thread] { type (e ).__name__ } { e } ' )
87- finally :
88- try :
89- self ._cancel_all_tasks ()
90- loop .run_until_complete (loop .shutdown_asyncgens ())
91- finally :
92- asyncio .set_event_loop (None )
93- loop .close ()
94-
95- def _cancel_all_tasks (self ):
96- """Modified from asyncio/runners.py."""
97- to_cancel = asyncio .all_tasks (self .loop )
98- if not to_cancel :
99- return
100-
101- for task in to_cancel :
102- task .cancel ()
103-
104- async def _gather ():
105- await asyncio .gather (* to_cancel , return_exceptions = True )
106-
107- self .loop .run_until_complete (_gather ())
108-
109- for task in to_cancel :
110- if task .cancelled ():
111- continue
112- if task .exception () is not None :
113- self .loop .call_exception_handler ({
114- 'message' : 'unhandled exception during worker thread shutdown' ,
115- 'exception' : task .exception (),
116- 'task' : task ,
117- })
118-
119- def close (self ):
120- if self .closed :
121- return
122- self .closed = True
123- self .loop .call_soon_threadsafe (self .loop .stop )
124- self .thread .join ()
125-
126-
12764class AsyncEngine (LogitsMixin ):
12865 """Async inference engine. Maintaining a bunch of tm_model instances.
12966
@@ -199,21 +136,18 @@ def __init__(self,
199136 self .stop_words = self .stop_words [0 ][0 ].tolist ()
200137 self .backend = backend
201138 self .request_logger = RequestLogger (max_log_len )
202- self .internal_thread = _EventLoopThread (daemon = True )
203- self .limiter : asyncio .Semaphore = None
139+
204140 self .num_spec_token = 0 if backend == 'turbomind' or speculative_config is None \
205141 else speculative_config .num_speculative_tokens
206142
207143 self .session_mgr = SessionManager ()
208- self .session_mgr .attach_event_loop (self .internal_thread .loop )
209144 self .session_mgr .build_request_handle_pool (self .engine , self .backend_config .max_batch_size )
210145
211146 # build stat loggers
212147 self ._build_stat_loggers ()
213148 self .epoch = 0
214149
215150 def close (self ):
216- self .internal_thread .close ()
217151 self .session_mgr .clear ()
218152 self .engine .close ()
219153
@@ -303,49 +237,6 @@ def wakeup(self, tags: List[str] | None = None):
303237 self .sleeping_tags = self .sleeping_tags - set (tags )
304238 self .is_sleeping = bool (self .sleeping_tags )
305239
306- def _get_limiter (self ):
307- if not self .limiter :
308- self .limiter = asyncio .Semaphore (self .backend_config .max_batch_size )
309- return self .limiter
310-
311- def _infer (self , requests : Iterator [Dict ], multiplex : bool , pbar = None , loop = None ) -> Iterator [Iterator [Response ]]:
312-
313- async def _sync_resp (g , que : Queue , idx : int , sem : asyncio .Semaphore ):
314- async for out in g :
315- que .put (out .to_response (idx ))
316- sem .release ()
317- if not multiplex :
318- que .put (None ) # sentinel of inner generator
319- if pbar :
320- pbar .update (1 )
321-
322- que = Queue ()
323-
324- async def _infer ():
325- sem = self ._get_limiter ()
326- tasks = []
327- for idx , req in enumerate (requests ):
328- await sem .acquire ()
329- gen = self .generate (** req )
330- dst = que if multiplex else Queue ()
331- if not multiplex :
332- que .put (iter (dst .get , None ))
333- # create a task to send the responses
334- task = asyncio .create_task (_sync_resp (gen , dst , idx , sem ))
335- tasks .append (task )
336- if not multiplex : # sentinel of outer generator
337- que .put (None )
338- await asyncio .gather (* tasks )
339- if multiplex :
340- que .put (None ) # sentinel of inner generator
341-
342- loop = loop or self .internal_thread .loop
343- # submit the coroutine to async world
344- asyncio .run_coroutine_threadsafe (_infer (),
345- loop ).add_done_callback (lambda f : None if f .cancelled () else f .result ())
346-
347- return iter (que .get , None )
348-
349240 def _determine_gen_config (self , session , input_ids , gen_config : GenerationConfig | None = None ) -> GenerationConfig :
350241 """Determine the generation configuration."""
351242 gen_config = deepcopy (gen_config ) or GenerationConfig ()
@@ -640,18 +531,7 @@ def is_error(status):
640531 # await session.async_close()
641532 # self.session_mgr.remove(session)
642533
643- def _run (self , fn = None , coro = None , loop = None ):
644- assert (fn or coro ) and not (fn and coro )
645- loop = loop or self .internal_thread .loop
646- if fn :
647-
648- async def _coro ():
649- return fn ()
650-
651- coro = _coro ()
652- return asyncio .run_coroutine_threadsafe (coro , loop )
653-
654- def start_loop (self , use_async_api = False ):
534+ def start_loop (self , loop , use_async_api = False ):
655535 """Start engine loop.
656536
657537 When using pytorch backend with dp > 1, all dp_rank should receive at least one request before it can start
@@ -661,6 +541,7 @@ def start_loop(self, use_async_api=False):
661541 The purpose of this function is to allow users to choose whether to use the synchronous interface or the
662542 asynchronous interface for the pipeline.
663543 """
544+ self .session_mgr .attach_event_loop (loop )
664545 if hasattr (self .engine , 'start_loop' ):
665546 if use_async_api :
666547 return self .engine .start_loop ()
@@ -671,7 +552,7 @@ def _start_loop(fut):
671552 res = self .engine .start_loop ()
672553 fut .set_result (res )
673554
674- self . internal_thread . loop .call_soon_threadsafe (_start_loop , fut )
555+ loop .call_soon_threadsafe (_start_loop , fut )
675556 return fut .result ()
676557 else :
677558 return True
0 commit comments