@@ -2,18 +2,18 @@
 import json
 from threading import Lock
 from typing import List, Optional, Union, Iterator, Dict
-from typing_extensions import TypedDict, Literal
+from typing_extensions import TypedDict, Literal, Annotated
 
 import llama_cpp
 
-from fastapi import Depends, FastAPI
+from fastapi import Depends, FastAPI, APIRouter
 from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel, BaseSettings, Field, create_model_from_typeddict
 from sse_starlette.sse import EventSourceResponse
 
 
 class Settings(BaseSettings):
-    model: str = os.environ.get("MODEL", "null")
+    model: str
     n_ctx: int = 2048
     n_batch: int = 512
     n_threads: int = max((os.cpu_count() or 2) // 2, 1)
@@ -27,25 +27,29 @@ class Settings(BaseSettings):
     vocab_only: bool = False
 
 
-app = FastAPI(
-    title="🦙 llama.cpp Python API",
-    version="0.0.1",
-)
-app.add_middleware(
-    CORSMiddleware,
-    allow_origins=["*"],
-    allow_credentials=True,
-    allow_methods=["*"],
-    allow_headers=["*"],
-)
+router = APIRouter()
+
+llama: Optional[llama_cpp.Llama] = None
 
-llama: llama_cpp.Llama = None
-def init_llama(settings: Settings = None):
+
+def create_app(settings: Optional[Settings] = None):
     if settings is None:
         settings = Settings()
+    app = FastAPI(
+        title="🦙 llama.cpp Python API",
+        version="0.0.1",
+    )
+    app.add_middleware(
+        CORSMiddleware,
+        allow_origins=["*"],
+        allow_credentials=True,
+        allow_methods=["*"],
+        allow_headers=["*"],
+    )
+    app.include_router(router)
     global llama
     llama = llama_cpp.Llama(
-        settings.model,
+        model_path=settings.model,
         f16_kv=settings.f16_kv,
         use_mlock=settings.use_mlock,
         use_mmap=settings.use_mmap,
@@ -60,12 +64,17 @@ def init_llama(settings: Settings = None):
     if settings.cache:
         cache = llama_cpp.LlamaCache()
         llama.set_cache(cache)
+    return app
+
 
 llama_lock = Lock()
+
+
 def get_llama():
     with llama_lock:
         yield llama
 
+
 class CreateCompletionRequest(BaseModel):
     prompt: Union[str, List[str]]
     suffix: Optional[str] = Field(None)
@@ -102,7 +111,7 @@ class Config:
 CreateCompletionResponse = create_model_from_typeddict(llama_cpp.Completion)
 
 
-@app.post(
+@router.post(
     "/v1/completions",
     response_model=CreateCompletionResponse,
 )
@@ -148,7 +157,7 @@ class Config:
 CreateEmbeddingResponse = create_model_from_typeddict(llama_cpp.Embedding)
 
 
-@app.post(
+@router.post(
     "/v1/embeddings",
     response_model=CreateEmbeddingResponse,
 )
@@ -202,7 +211,7 @@ class Config:
 CreateChatCompletionResponse = create_model_from_typeddict(llama_cpp.ChatCompletion)
 
 
-@app.post(
+@router.post(
     "/v1/chat/completions",
     response_model=CreateChatCompletionResponse,
 )
@@ -256,7 +265,7 @@ class ModelList(TypedDict):
 GetModelResponse = create_model_from_typeddict(ModelList)
 
 
-@app.get("/v1/models", response_model=GetModelResponse)
+@router.get("/v1/models", response_model=GetModelResponse)
 def get_models() -> ModelList:
     return {
         "object": "list",
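With this change the module no longer builds a FastAPI app at import time; routes are registered on a shared APIRouter and callers obtain an app from the create_app() factory. A minimal usage sketch (not part of this diff; it assumes uvicorn is installed and that this module is importable as llama_cpp.server.app):

    import os
    import uvicorn
    from llama_cpp.server.app import Settings, create_app  # assumed import path

    # model is now a required Settings field rather than defaulting to $MODEL
    settings = Settings(model=os.environ["MODEL"])
    app = create_app(settings=settings)

    uvicorn.run(app, host="localhost", port=8000)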