16
16
import pyrootutils
17
17
import soundfile as sf
18
18
import torch
19
- from kui .wsgi import (
19
+ from kui .asgi import (
20
20
Body ,
21
+ FileResponse ,
21
22
HTTPException ,
22
23
HttpView ,
23
24
JSONResponse ,
24
25
Kui ,
25
26
OpenAPI ,
26
27
StreamResponse ,
27
28
)
28
- from kui .wsgi .routing import MultimethodRoutes
29
+ from kui .asgi .routing import MultimethodRoutes
29
30
from loguru import logger
30
31
from pydantic import BaseModel , Field
31
32
from transformers import AutoTokenizer
@@ -57,7 +58,7 @@ def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
57
58
58
59
59
60
# Define utils for web server
60
- def http_execption_handler (exc : HTTPException ):
61
+ async def http_execption_handler (exc : HTTPException ):
61
62
return JSONResponse (
62
63
dict (
63
64
statusCode = exc .status_code ,
@@ -69,7 +70,7 @@ def http_execption_handler(exc: HTTPException):
69
70
)
70
71
71
72
72
- def other_exception_handler (exc : "Exception" ):
73
+ async def other_exception_handler (exc : "Exception" ):
73
74
traceback .print_exc ()
74
75
75
76
status = HTTPStatus .INTERNAL_SERVER_ERROR
@@ -334,8 +335,17 @@ def inference(req: InvokeRequest):
334
335
yield fake_audios
335
336
336
337
338
+ async def inference_async (req : InvokeRequest ):
339
+ for chunk in inference (req ):
340
+ yield chunk
341
+
342
+
343
+ async def buffer_to_async_generator (buffer ):
344
+ yield buffer
345
+
346
+
337
347
@routes .http .post ("/v1/invoke" )
338
- def api_invoke_model (
348
+ async def api_invoke_model (
339
349
req : Annotated [InvokeRequest , Body (exclusive = True )],
340
350
):
341
351
"""
@@ -354,22 +364,21 @@ def api_invoke_model(
354
364
content = "Streaming only supports WAV format" ,
355
365
)
356
366
357
- generator = inference (req )
358
367
if req .streaming :
359
368
return StreamResponse (
360
- iterable = generator ,
369
+ iterable = inference_async ( req ) ,
361
370
headers = {
362
371
"Content-Disposition" : f"attachment; filename=audio.{ req .format } " ,
363
372
},
364
373
content_type = get_content_type (req .format ),
365
374
)
366
375
else :
367
- fake_audios = next (generator )
376
+ fake_audios = next (inference ( req ) )
368
377
buffer = io .BytesIO ()
369
378
sf .write (buffer , fake_audios , decoder_model .sampling_rate , format = req .format )
370
379
371
380
return StreamResponse (
372
- iterable = [ buffer .getvalue ()] ,
381
+ iterable = buffer_to_async_generator ( buffer .getvalue ()) ,
373
382
headers = {
374
383
"Content-Disposition" : f"attachment; filename=audio.{ req .format } " ,
375
384
},
@@ -378,7 +387,7 @@ def api_invoke_model(
378
387
379
388
380
389
@routes .http .post ("/v1/health" )
381
- def api_health ():
390
+ async def api_health ():
382
391
"""
383
392
Health check
384
393
"""
@@ -409,6 +418,7 @@ def parse_args():
409
418
parser .add_argument ("--compile" , action = "store_true" )
410
419
parser .add_argument ("--max-text-length" , type = int , default = 0 )
411
420
parser .add_argument ("--listen" , type = str , default = "127.0.0.1:8000" )
421
+ parser .add_argument ("--workers" , type = int , default = 1 )
412
422
413
423
return parser .parse_args ()
414
424
@@ -433,7 +443,7 @@ def parse_args():
433
443
if __name__ == "__main__" :
434
444
import threading
435
445
436
- from zibai import create_bind_socket , serve
446
+ import uvicorn
437
447
438
448
args = parse_args ()
439
449
args .precision = torch .half if args .half else torch .bfloat16
@@ -480,13 +490,5 @@ def parse_args():
480
490
)
481
491
482
492
logger .info (f"Warming up done, starting server at http://{ args .listen } " )
483
- sock = create_bind_socket (args .listen )
484
- sock .listen ()
485
-
486
- # Start server
487
- serve (
488
- app = app ,
489
- bind_sockets = [sock ],
490
- max_workers = 10 ,
491
- graceful_exit = threading .Event (),
492
- )
493
+ host , port = args .listen .split (":" )
494
+ uvicorn .run (app , host = host , port = int (port ), workers = args .workers , log_level = "info" )
0 commit comments