Skip to content

Commit 46dae9b

Browse files
committed
Replace zibai with uvicorn
1 parent dc8c834 commit 46dae9b

File tree

2 files changed

+24
-22
lines changed

2 files changed

+24
-22
lines changed

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ dependencies = [
2626
"wandb>=0.15.11",
2727
"grpcio>=1.58.0",
2828
"kui>=1.6.0",
29-
"zibai-server>=0.9.0",
29+
"uvicorn>=0.30.0",
3030
"loguru>=0.6.0",
3131
"loralib>=0.1.2",
3232
"natsort>=8.4.0",

tools/api.py

+23-21
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,17 @@
1616
import pyrootutils
1717
import soundfile as sf
1818
import torch
19-
from kui.wsgi import (
19+
from kui.asgi import (
2020
Body,
21+
FileResponse,
2122
HTTPException,
2223
HttpView,
2324
JSONResponse,
2425
Kui,
2526
OpenAPI,
2627
StreamResponse,
2728
)
28-
from kui.wsgi.routing import MultimethodRoutes
29+
from kui.asgi.routing import MultimethodRoutes
2930
from loguru import logger
3031
from pydantic import BaseModel, Field
3132
from transformers import AutoTokenizer
@@ -57,7 +58,7 @@ def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
5758

5859

5960
# Define utils for web server
60-
def http_execption_handler(exc: HTTPException):
61+
async def http_execption_handler(exc: HTTPException):
6162
return JSONResponse(
6263
dict(
6364
statusCode=exc.status_code,
@@ -69,7 +70,7 @@ def http_execption_handler(exc: HTTPException):
6970
)
7071

7172

72-
def other_exception_handler(exc: "Exception"):
73+
async def other_exception_handler(exc: "Exception"):
7374
traceback.print_exc()
7475

7576
status = HTTPStatus.INTERNAL_SERVER_ERROR
@@ -334,8 +335,17 @@ def inference(req: InvokeRequest):
334335
yield fake_audios
335336

336337

338+
async def inference_async(req: InvokeRequest):
339+
for chunk in inference(req):
340+
yield chunk
341+
342+
343+
async def buffer_to_async_generator(buffer):
344+
yield buffer
345+
346+
337347
@routes.http.post("/v1/invoke")
338-
def api_invoke_model(
348+
async def api_invoke_model(
339349
req: Annotated[InvokeRequest, Body(exclusive=True)],
340350
):
341351
"""
@@ -354,22 +364,21 @@ def api_invoke_model(
354364
content="Streaming only supports WAV format",
355365
)
356366

357-
generator = inference(req)
358367
if req.streaming:
359368
return StreamResponse(
360-
iterable=generator,
369+
iterable=inference_async(req),
361370
headers={
362371
"Content-Disposition": f"attachment; filename=audio.{req.format}",
363372
},
364373
content_type=get_content_type(req.format),
365374
)
366375
else:
367-
fake_audios = next(generator)
376+
fake_audios = next(inference(req))
368377
buffer = io.BytesIO()
369378
sf.write(buffer, fake_audios, decoder_model.sampling_rate, format=req.format)
370379

371380
return StreamResponse(
372-
iterable=[buffer.getvalue()],
381+
iterable=buffer_to_async_generator(buffer.getvalue()),
373382
headers={
374383
"Content-Disposition": f"attachment; filename=audio.{req.format}",
375384
},
@@ -378,7 +387,7 @@ def api_invoke_model(
378387

379388

380389
@routes.http.post("/v1/health")
381-
def api_health():
390+
async def api_health():
382391
"""
383392
Health check
384393
"""
@@ -409,6 +418,7 @@ def parse_args():
409418
parser.add_argument("--compile", action="store_true")
410419
parser.add_argument("--max-text-length", type=int, default=0)
411420
parser.add_argument("--listen", type=str, default="127.0.0.1:8000")
421+
parser.add_argument("--workers", type=int, default=1)
412422

413423
return parser.parse_args()
414424

@@ -433,7 +443,7 @@ def parse_args():
433443
if __name__ == "__main__":
434444
import threading
435445

436-
from zibai import create_bind_socket, serve
446+
import uvicorn
437447

438448
args = parse_args()
439449
args.precision = torch.half if args.half else torch.bfloat16
@@ -480,13 +490,5 @@ def parse_args():
480490
)
481491

482492
logger.info(f"Warming up done, starting server at http://{args.listen}")
483-
sock = create_bind_socket(args.listen)
484-
sock.listen()
485-
486-
# Start server
487-
serve(
488-
app=app,
489-
bind_sockets=[sock],
490-
max_workers=10,
491-
graceful_exit=threading.Event(),
492-
)
493+
host, port = args.listen.split(":")
494+
uvicorn.run(app, host=host, port=int(port), workers=args.workers, log_level="info")

0 commit comments

Comments
 (0)