Skip to content

Commit 9eafc4c

Browse files
committed
Refactor server to use factory
1 parent dd9ad1c commit 9eafc4c

File tree

3 files changed

+47
-31
lines changed

3 files changed

+47
-31
lines changed

llama_cpp/server/__main__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,10 @@
2424
import os
2525
import uvicorn
2626

27-
from llama_cpp.server.app import app, init_llama
27+
from llama_cpp.server.app import create_app
2828

2929
if __name__ == "__main__":
30-
init_llama()
30+
app = create_app()
3131

3232
uvicorn.run(
3333
app, host=os.getenv("HOST", "localhost"), port=int(os.getenv("PORT", 8000))

llama_cpp/server/app.py

Lines changed: 30 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,18 @@
22
import json
33
from threading import Lock
44
from typing import List, Optional, Union, Iterator, Dict
5-
from typing_extensions import TypedDict, Literal
5+
from typing_extensions import TypedDict, Literal, Annotated
66

77
import llama_cpp
88

9-
from fastapi import Depends, FastAPI
9+
from fastapi import Depends, FastAPI, APIRouter
1010
from fastapi.middleware.cors import CORSMiddleware
1111
from pydantic import BaseModel, BaseSettings, Field, create_model_from_typeddict
1212
from sse_starlette.sse import EventSourceResponse
1313

1414

1515
class Settings(BaseSettings):
16-
model: str = os.environ.get("MODEL", "null")
16+
model: str
1717
n_ctx: int = 2048
1818
n_batch: int = 512
1919
n_threads: int = max((os.cpu_count() or 2) // 2, 1)
@@ -27,25 +27,29 @@ class Settings(BaseSettings):
2727
vocab_only: bool = False
2828

2929

30-
app = FastAPI(
31-
title="🦙 llama.cpp Python API",
32-
version="0.0.1",
33-
)
34-
app.add_middleware(
35-
CORSMiddleware,
36-
allow_origins=["*"],
37-
allow_credentials=True,
38-
allow_methods=["*"],
39-
allow_headers=["*"],
40-
)
30+
router = APIRouter()
31+
32+
llama: Optional[llama_cpp.Llama] = None
4133

42-
llama: llama_cpp.Llama = None
43-
def init_llama(settings: Settings = None):
34+
35+
def create_app(settings: Optional[Settings] = None):
4436
if settings is None:
4537
settings = Settings()
38+
app = FastAPI(
39+
title="🦙 llama.cpp Python API",
40+
version="0.0.1",
41+
)
42+
app.add_middleware(
43+
CORSMiddleware,
44+
allow_origins=["*"],
45+
allow_credentials=True,
46+
allow_methods=["*"],
47+
allow_headers=["*"],
48+
)
49+
app.include_router(router)
4650
global llama
4751
llama = llama_cpp.Llama(
48-
settings.model,
52+
model_path=settings.model,
4953
f16_kv=settings.f16_kv,
5054
use_mlock=settings.use_mlock,
5155
use_mmap=settings.use_mmap,
@@ -60,12 +64,17 @@ def init_llama(settings: Settings = None):
6064
if settings.cache:
6165
cache = llama_cpp.LlamaCache()
6266
llama.set_cache(cache)
67+
return app
68+
6369

6470
llama_lock = Lock()
71+
72+
6573
def get_llama():
6674
with llama_lock:
6775
yield llama
6876

77+
6978
class CreateCompletionRequest(BaseModel):
7079
prompt: Union[str, List[str]]
7180
suffix: Optional[str] = Field(None)
@@ -102,7 +111,7 @@ class Config:
102111
CreateCompletionResponse = create_model_from_typeddict(llama_cpp.Completion)
103112

104113

105-
@app.post(
114+
@router.post(
106115
"/v1/completions",
107116
response_model=CreateCompletionResponse,
108117
)
@@ -148,7 +157,7 @@ class Config:
148157
CreateEmbeddingResponse = create_model_from_typeddict(llama_cpp.Embedding)
149158

150159

151-
@app.post(
160+
@router.post(
152161
"/v1/embeddings",
153162
response_model=CreateEmbeddingResponse,
154163
)
@@ -202,7 +211,7 @@ class Config:
202211
CreateChatCompletionResponse = create_model_from_typeddict(llama_cpp.ChatCompletion)
203212

204213

205-
@app.post(
214+
@router.post(
206215
"/v1/chat/completions",
207216
response_model=CreateChatCompletionResponse,
208217
)
@@ -256,7 +265,7 @@ class ModelList(TypedDict):
256265
GetModelResponse = create_model_from_typeddict(ModelList)
257266

258267

259-
@app.get("/v1/models", response_model=GetModelResponse)
268+
@router.get("/v1/models", response_model=GetModelResponse)
260269
def get_models() -> ModelList:
261270
return {
262271
"object": "list",

tests/test_llama.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,11 @@ def test_llama_patch(monkeypatch):
2222
## Set up mock function
2323
def mock_eval(*args, **kwargs):
2424
return 0
25-
25+
2626
def mock_get_logits(*args, **kwargs):
27-
return (llama_cpp.c_float * n_vocab)(*[llama_cpp.c_float(0) for _ in range(n_vocab)])
27+
return (llama_cpp.c_float * n_vocab)(
28+
*[llama_cpp.c_float(0) for _ in range(n_vocab)]
29+
)
2830

2931
monkeypatch.setattr("llama_cpp.llama_cpp.llama_eval", mock_eval)
3032
monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)
@@ -88,6 +90,7 @@ def mock_sample(*args, **kwargs):
8890
def test_llama_pickle():
8991
import pickle
9092
import tempfile
93+
9194
fp = tempfile.TemporaryFile()
9295
llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)
9396
pickle.dump(llama, fp)
@@ -101,6 +104,7 @@ def test_llama_pickle():
101104

102105
assert llama.detokenize(llama.tokenize(text)) == text
103106

107+
104108
def test_utf8(monkeypatch):
105109
llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)
106110
n_vocab = int(llama_cpp.llama_n_vocab(llama.ctx))
@@ -110,7 +114,9 @@ def mock_eval(*args, **kwargs):
110114
return 0
111115

112116
def mock_get_logits(*args, **kwargs):
113-
return (llama_cpp.c_float * n_vocab)(*[llama_cpp.c_float(0) for _ in range(n_vocab)])
117+
return (llama_cpp.c_float * n_vocab)(
118+
*[llama_cpp.c_float(0) for _ in range(n_vocab)]
119+
)
114120

115121
monkeypatch.setattr("llama_cpp.llama_cpp.llama_eval", mock_eval)
116122
monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)
@@ -143,11 +149,12 @@ def mock_sample(*args, **kwargs):
143149

144150
def test_llama_server():
145151
from fastapi.testclient import TestClient
146-
from llama_cpp.server.app import app, init_llama, Settings
147-
s = Settings()
148-
s.model = MODEL
149-
s.vocab_only = True
150-
init_llama(s)
152+
from llama_cpp.server.app import create_app, Settings
153+
154+
settings = Settings()
155+
settings.model = MODEL
156+
settings.vocab_only = True
157+
app = create_app(settings)
151158
client = TestClient(app)
152159
response = client.get("/v1/models")
153160
assert response.json() == {

0 commit comments

Comments
 (0)