Skip to content

Commit a9c9590

Browse files
committed
move EventLoopThread from async_engine to pipeline
1 parent b3ce74f commit a9c9590

3 files changed

Lines changed: 133 additions & 129 deletions

File tree

lmdeploy/pipeline.py

Lines changed: 127 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,19 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
2+
import asyncio
3+
import atexit
4+
import concurrent.futures
25
import os
36
from contextlib import closing
7+
from functools import partial
8+
from queue import Queue
9+
from threading import Thread
410
from typing import TYPE_CHECKING, Dict, Iterator, List, Tuple
511

612
import tqdm
713
from typing_extensions import deprecated
814

915
from .archs import autoget_backend_config, get_task
10-
from .messages import GenerationConfig, PytorchEngineConfig, SpeculativeConfig, TurbomindEngineConfig
16+
from .messages import GenerationConfig, PytorchEngineConfig, Response, SpeculativeConfig, TurbomindEngineConfig
1117
from .model import ChatTemplateConfig
1218
from .serve.processors import MultimodalProcessor
1319
from .utils import get_logger, get_model
@@ -67,8 +73,11 @@ def __init__(self,
6773
max_log_len=max_log_len,
6874
speculative_config=speculative_config,
6975
**kwargs)
76+
self.internal_thread = _EventLoopThread(daemon=True)
77+
self.limiter: asyncio.Semaphore = None
7078
self.session_mgr = self.async_engine.session_mgr
7179
self.backend_config = self.async_engine.backend_config
80+
self.async_engine.start_loop(self.internal_thread.loop, use_async_api=False)
7281

7382
def infer(self,
7483
prompts: List[str] | str | List[Dict] | List[List[Dict]] | Tuple | List[Tuple],
@@ -100,7 +109,7 @@ def infer(self,
100109
adapter_name=adapter_name,
101110
stream_response=False,
102111
**kwargs)
103-
for g in self.async_engine._infer(requests, multiplex=False, pbar=pbar):
112+
for g in self._infer(requests, multiplex=False, pbar=pbar):
104113
res = None
105114
for out in g:
106115
res = res.extend(out) if res else out
@@ -149,10 +158,11 @@ def stream_infer(self,
149158
adapter_name=adapter_name,
150159
stream_response=stream_response,
151160
**kwargs)
152-
return self.async_engine._infer(requests, multiplex=True)
161+
return self._infer(requests, multiplex=True)
153162

154163
def close(self):
155164
"""Close the pipeline."""
165+
self.internal_thread.close()
156166
self.async_engine.close()
157167

158168
def chat(self,
@@ -197,7 +207,7 @@ def _gen():
197207
resp = resp.extend(out) if resp else out
198208
yield out
199209
except: # noqa
200-
self.async_engine._run(coro=session.async_abort())
210+
self._run(coro=session.async_abort())
201211
raise
202212
else:
203213
session.response = resp
@@ -295,3 +305,116 @@ def _request_generator(self,
295305
# Since AsyncEngine.generate defines session_id in the argument lists, here we
296306
# use session_id to pass the session to the AsyncEngine.generate. It's
297307
yield dict(session_id=session, messages=prompt, gen_config=gen_cfg, **kwargs)
308+
309+
def _get_limiter(self):
310+
if not self.limiter:
311+
self.limiter = asyncio.Semaphore(self.backend_config.max_batch_size)
312+
return self.limiter
313+
314+
def _infer(self, requests: Iterator[Dict], multiplex: bool, pbar=None, loop=None) -> Iterator[Iterator[Response]]:
315+
316+
async def _sync_resp(g, que: Queue, idx: int, sem: asyncio.Semaphore):
317+
async for out in g:
318+
que.put(out.to_response(idx))
319+
sem.release()
320+
if not multiplex:
321+
que.put(None) # sentinel of inner generator
322+
if pbar:
323+
pbar.update(1)
324+
325+
que = Queue()
326+
327+
async def _infer():
328+
sem = self._get_limiter()
329+
tasks = []
330+
for idx, req in enumerate(requests):
331+
await sem.acquire()
332+
gen = self.async_engine.generate(**req)
333+
dst = que if multiplex else Queue()
334+
if not multiplex:
335+
que.put(iter(dst.get, None))
336+
# create a task to send the responses
337+
task = asyncio.create_task(_sync_resp(gen, dst, idx, sem))
338+
tasks.append(task)
339+
if not multiplex: # sentinel of outer generator
340+
que.put(None)
341+
await asyncio.gather(*tasks)
342+
if multiplex:
343+
que.put(None) # sentinel of inner generator
344+
345+
loop = loop or self.internal_thread.loop
346+
# submit the coroutine to async world
347+
asyncio.run_coroutine_threadsafe(_infer(),
348+
loop).add_done_callback(lambda f: None if f.cancelled() else f.result())
349+
350+
return iter(que.get, None)
351+
352+
def _run(self, fn=None, coro=None):
353+
assert (fn or coro) and not (fn and coro)
354+
loop = self.internal_thread.loop
355+
if fn:
356+
357+
async def _coro():
358+
return fn()
359+
360+
coro = _coro()
361+
return asyncio.run_coroutine_threadsafe(coro, loop)
362+
363+
364+
class _EventLoopThread:
365+
366+
def __init__(self, daemon=False):
367+
fut = concurrent.futures.Future()
368+
self.thread = Thread(target=partial(self._thread_entry, fut), daemon=daemon)
369+
self.thread.start()
370+
self.loop: asyncio.AbstractEventLoop = fut.result()
371+
self.closed = False
372+
if daemon:
373+
atexit.register(self.close)
374+
375+
def _thread_entry(self, fut):
376+
loop = asyncio.new_event_loop()
377+
asyncio.set_event_loop(loop)
378+
fut.set_result(loop)
379+
try:
380+
loop.run_forever()
381+
except BaseException as e:
382+
logger.error(f'[internal_thread] {type(e).__name__} {e}')
383+
finally:
384+
try:
385+
self._cancel_all_tasks()
386+
loop.run_until_complete(loop.shutdown_asyncgens())
387+
finally:
388+
asyncio.set_event_loop(None)
389+
loop.close()
390+
391+
def _cancel_all_tasks(self):
392+
"""Modified from asyncio/runners.py."""
393+
to_cancel = asyncio.all_tasks(self.loop)
394+
if not to_cancel:
395+
return
396+
397+
for task in to_cancel:
398+
task.cancel()
399+
400+
async def _gather():
401+
await asyncio.gather(*to_cancel, return_exceptions=True)
402+
403+
self.loop.run_until_complete(_gather())
404+
405+
for task in to_cancel:
406+
if task.cancelled():
407+
continue
408+
if task.exception() is not None:
409+
self.loop.call_exception_handler({
410+
'message': 'unhandled exception during worker thread shutdown',
411+
'exception': task.exception(),
412+
'task': task,
413+
})
414+
415+
def close(self):
416+
if self.closed:
417+
return
418+
self.closed = True
419+
self.loop.call_soon_threadsafe(self.loop.stop)
420+
self.thread.join()

lmdeploy/serve/core/async_engine.py

Lines changed: 5 additions & 124 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,12 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
22

33
import asyncio
4-
import atexit
54
import concurrent.futures
65
import dataclasses
76
import random
87
from contextlib import asynccontextmanager
98
from copy import deepcopy
10-
from functools import partial
11-
from queue import Queue
12-
from threading import Thread
13-
from typing import Any, Dict, Iterator, List, Literal
9+
from typing import Any, Dict, List, Literal
1410

1511
from lmdeploy.archs import get_model_arch
1612
from lmdeploy.logger import RequestLogger
@@ -65,65 +61,6 @@ def to_response(self, index: int = 0) -> Response:
6561
index=index)
6662

6763

68-
class _EventLoopThread:
69-
70-
def __init__(self, daemon=False):
71-
fut = concurrent.futures.Future()
72-
self.thread = Thread(target=partial(self._thread_entry, fut), daemon=daemon)
73-
self.thread.start()
74-
self.loop: asyncio.AbstractEventLoop = fut.result()
75-
self.closed = False
76-
if daemon:
77-
atexit.register(self.close)
78-
79-
def _thread_entry(self, fut):
80-
loop = asyncio.new_event_loop()
81-
asyncio.set_event_loop(loop)
82-
fut.set_result(loop)
83-
try:
84-
loop.run_forever()
85-
except BaseException as e:
86-
logger.error(f'[internal_thread] {type(e).__name__} {e}')
87-
finally:
88-
try:
89-
self._cancel_all_tasks()
90-
loop.run_until_complete(loop.shutdown_asyncgens())
91-
finally:
92-
asyncio.set_event_loop(None)
93-
loop.close()
94-
95-
def _cancel_all_tasks(self):
96-
"""Modified from asyncio/runners.py."""
97-
to_cancel = asyncio.all_tasks(self.loop)
98-
if not to_cancel:
99-
return
100-
101-
for task in to_cancel:
102-
task.cancel()
103-
104-
async def _gather():
105-
await asyncio.gather(*to_cancel, return_exceptions=True)
106-
107-
self.loop.run_until_complete(_gather())
108-
109-
for task in to_cancel:
110-
if task.cancelled():
111-
continue
112-
if task.exception() is not None:
113-
self.loop.call_exception_handler({
114-
'message': 'unhandled exception during worker thread shutdown',
115-
'exception': task.exception(),
116-
'task': task,
117-
})
118-
119-
def close(self):
120-
if self.closed:
121-
return
122-
self.closed = True
123-
self.loop.call_soon_threadsafe(self.loop.stop)
124-
self.thread.join()
125-
126-
12764
class AsyncEngine(LogitsMixin):
12865
"""Async inference engine. Maintaining a bunch of tm_model instances.
12966
@@ -199,21 +136,18 @@ def __init__(self,
199136
self.stop_words = self.stop_words[0][0].tolist()
200137
self.backend = backend
201138
self.request_logger = RequestLogger(max_log_len)
202-
self.internal_thread = _EventLoopThread(daemon=True)
203-
self.limiter: asyncio.Semaphore = None
139+
204140
self.num_spec_token = 0 if backend == 'turbomind' or speculative_config is None \
205141
else speculative_config.num_speculative_tokens
206142

207143
self.session_mgr = SessionManager()
208-
self.session_mgr.attach_event_loop(self.internal_thread.loop)
209144
self.session_mgr.build_request_handle_pool(self.engine, self.backend_config.max_batch_size)
210145

211146
# build stat loggers
212147
self._build_stat_loggers()
213148
self.epoch = 0
214149

215150
def close(self):
216-
self.internal_thread.close()
217151
self.session_mgr.clear()
218152
self.engine.close()
219153

@@ -303,49 +237,6 @@ def wakeup(self, tags: List[str] | None = None):
303237
self.sleeping_tags = self.sleeping_tags - set(tags)
304238
self.is_sleeping = bool(self.sleeping_tags)
305239

306-
def _get_limiter(self):
307-
if not self.limiter:
308-
self.limiter = asyncio.Semaphore(self.backend_config.max_batch_size)
309-
return self.limiter
310-
311-
def _infer(self, requests: Iterator[Dict], multiplex: bool, pbar=None, loop=None) -> Iterator[Iterator[Response]]:
312-
313-
async def _sync_resp(g, que: Queue, idx: int, sem: asyncio.Semaphore):
314-
async for out in g:
315-
que.put(out.to_response(idx))
316-
sem.release()
317-
if not multiplex:
318-
que.put(None) # sentinel of inner generator
319-
if pbar:
320-
pbar.update(1)
321-
322-
que = Queue()
323-
324-
async def _infer():
325-
sem = self._get_limiter()
326-
tasks = []
327-
for idx, req in enumerate(requests):
328-
await sem.acquire()
329-
gen = self.generate(**req)
330-
dst = que if multiplex else Queue()
331-
if not multiplex:
332-
que.put(iter(dst.get, None))
333-
# create a task to send the responses
334-
task = asyncio.create_task(_sync_resp(gen, dst, idx, sem))
335-
tasks.append(task)
336-
if not multiplex: # sentinel of outer generator
337-
que.put(None)
338-
await asyncio.gather(*tasks)
339-
if multiplex:
340-
que.put(None) # sentinel of inner generator
341-
342-
loop = loop or self.internal_thread.loop
343-
# submit the coroutine to async world
344-
asyncio.run_coroutine_threadsafe(_infer(),
345-
loop).add_done_callback(lambda f: None if f.cancelled() else f.result())
346-
347-
return iter(que.get, None)
348-
349240
def _determine_gen_config(self, session, input_ids, gen_config: GenerationConfig | None = None) -> GenerationConfig:
350241
"""Determine the generation configuration."""
351242
gen_config = deepcopy(gen_config) or GenerationConfig()
@@ -640,18 +531,7 @@ def is_error(status):
640531
# await session.async_close()
641532
# self.session_mgr.remove(session)
642533

643-
def _run(self, fn=None, coro=None, loop=None):
644-
assert (fn or coro) and not (fn and coro)
645-
loop = loop or self.internal_thread.loop
646-
if fn:
647-
648-
async def _coro():
649-
return fn()
650-
651-
coro = _coro()
652-
return asyncio.run_coroutine_threadsafe(coro, loop)
653-
654-
def start_loop(self, use_async_api=False):
534+
def start_loop(self, loop, use_async_api=False):
655535
"""Start engine loop.
656536
657537
When using pytorch backend with dp > 1, all dp_rank should receive at least one request before it can start
@@ -661,6 +541,7 @@ def start_loop(self, use_async_api=False):
661541
The purpose of this function is to allow users to choose whether to use the synchronous interface or the
662542
asynchronous interface for the pipeline.
663543
"""
544+
self.session_mgr.attach_event_loop(loop)
664545
if hasattr(self.engine, 'start_loop'):
665546
if use_async_api:
666547
return self.engine.start_loop()
@@ -671,7 +552,7 @@ def _start_loop(fut):
671552
res = self.engine.start_loop()
672553
fut.set_result(res)
673554

674-
self.internal_thread.loop.call_soon_threadsafe(_start_loop, fut)
555+
loop.call_soon_threadsafe(_start_loop, fut)
675556
return fut.result()
676557
else:
677558
return True

lmdeploy/serve/openai/api_server.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1263,7 +1263,7 @@ def dummy_get_device_id():
12631263
@router.on_event('startup')
12641264
async def startup_event():
12651265
async_engine = VariableInterface.async_engine
1266-
async_engine.start_loop(use_async_api=True)
1266+
async_engine.start_loop(asyncio.get_running_loop(), use_async_api=True)
12671267

12681268
if VariableInterface.proxy_url is None:
12691269
return

0 commit comments

Comments
 (0)