
Commit 85ed6b5

Use glm 4 to build an OpenAI-compatible service
1 parent 3f79b54 commit 85ed6b5

File tree

3 files changed: +138 −0 lines changed


inference/README.md

Lines changed: 21 additions & 0 deletions
@@ -102,6 +102,27 @@ python vllm_cli_demo.py # LLM Such as GLM-4-9B-0414
vllm serve THUDM/GLM-4-9B-0414 --tensor_parallel_size 2
```

### Use glm-4 to build an OpenAI-compatible service

Start the server:

```shell
python glm4_server.py THUDM/GLM-4-9B-0414
```

Client request:

```shell
curl -X POST http://localhost:8000/v1/chat/completions \
  -H 'Content-Type: application/json' \
  -d \
  "{ \
    \"messages\": [ \
      {\"role\": \"user\", \"content\": \"Who are you?\"} \
    ] \
  }"
```

### Use glm-4v to build an OpenAI-compatible service

Start the server:
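
For reference, the request sent by the curl example above can also be issued from Python. The snippet below is a minimal illustrative sketch, not part of the commit: it uses the third-party requests library, assumes glm4_server.py is running locally on port 8000, and relies only on the field names defined by its Request and Response models.

```python
# Minimal Python client for the /v1/chat/completions endpoint served by glm4_server.py.
# Assumes the server was started with: python glm4_server.py THUDM/GLM-4-9B-0414
import requests

payload = {
    "messages": [
        {"role": "user", "content": "Who are you?"}
    ],
    # Optional sampling parameters; these defaults mirror the server's Request model.
    "temperature": 0.8,
    "top_p": 0.8,
    "max_tokens": 1024,
}

resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload, timeout=300)
resp.raise_for_status()

data = resp.json()
# The response carries the model name and a single choice holding the assistant message.
print(data["model"])
print(data["choices"][0]["message"]["content"])
```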

inference/README_zh.md

Lines changed: 21 additions & 0 deletions
@@ -102,6 +102,27 @@ python vllm_cli_demo.py # LLM Such as GLM-4-9B-0414
vllm serve THUDM/GLM-4-9B-0414 --tensor_parallel_size 2
```

### Use glm-4 to build an OpenAI-compatible service

Start the server:

```shell
python glm4_server.py THUDM/GLM-4-9B-0414
```

Client request:

```shell
curl -X POST http://localhost:8000/v1/chat/completions \
  -H 'Content-Type: application/json' \
  -d \
  "{ \
    \"messages\": [ \
      {\"role\": \"user\", \"content\": \"Who are you?\"} \
    ] \
  }"
```

### Use glm-4v to build an OpenAI-compatible service

Start the server:

inference/glm4_server.py

Lines changed: 96 additions & 0 deletions
@@ -0,0 +1,96 @@
import sys
import torch
from threading import Thread
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel
from typing import List, Literal, Optional


app = FastAPI()


# Minimal OpenAI-style request/response schemas.
class MessageInput(BaseModel):
    role: Literal["user", "assistant", "system"]
    content: str
    name: Optional[str] = None


class MessageOutput(BaseModel):
    role: Literal["assistant"]
    content: str = None
    name: Optional[str] = None


class Choice(BaseModel):
    message: MessageOutput


class Request(BaseModel):
    messages: List[MessageInput]
    temperature: Optional[float] = 0.8
    top_p: Optional[float] = 0.8
    max_tokens: Optional[int] = 1024
    repetition_penalty: Optional[float] = 1.0


class Response(BaseModel):
    model: str
    choices: List[Choice]


@app.post("/v1/chat/completions", response_model=Response)
async def create_chat_completion(request: Request):
    global model, tokenizer

    # Build the prompt from the chat history using the model's chat template.
    messages = [message.model_dump() for message in request.messages]
    model_inputs = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
    ).to(model.device)
    streamer = TextIteratorStreamer(tokenizer=tokenizer, timeout=60, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = {
        "input_ids": model_inputs["input_ids"],
        "attention_mask": model_inputs["attention_mask"],
        "streamer": streamer,
        "max_new_tokens": request.max_tokens,
        "do_sample": True,
        "top_p": request.top_p,
        "temperature": request.temperature,
        "repetition_penalty": request.repetition_penalty,
        "eos_token_id": model.config.eos_token_id,
    }

    # Run generation in a background thread and collect the streamed tokens.
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    result = ""
    for new_token in streamer:
        result += new_token

    print("\033[91m--generated_text\033[0m", result)

    # Wrap the generated text in an OpenAI-style response.
    message = MessageOutput(
        role="assistant",
        content=result,
    )
    choice = Choice(
        message=message,
    )
    response = Response(model=sys.argv[1].split('/')[-1].lower(), choices=[choice])
    return response


torch.cuda.empty_cache()

if __name__ == "__main__":
    MODEL_PATH = sys.argv[1]

    # Load the tokenizer and model from the path given on the command line.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_PATH,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    ).eval()

    uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
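
For quick local testing without starting uvicorn, the endpoint above can also be exercised in-process with FastAPI's TestClient. The snippet below is an illustrative sketch rather than part of the commit: the file name test_glm4_server.py is hypothetical, it assumes glm4_server.py is importable from the working directory, and the model path must be passed as the first CLI argument because the handler also reads sys.argv[1] when naming the model in the response.

```python
# Hypothetical in-process smoke test for /v1/chat/completions (not part of this commit).
# Run as: python test_glm4_server.py THUDM/GLM-4-9B-0414
import sys

import torch
from fastapi.testclient import TestClient
from transformers import AutoModelForCausalLM, AutoTokenizer

import glm4_server  # the module added in this commit

MODEL_PATH = sys.argv[1]

# The request handler reads the module-level `model` and `tokenizer` globals,
# which are normally set in glm4_server's __main__ block; set them here instead.
glm4_server.tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
glm4_server.model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH, torch_dtype=torch.bfloat16, device_map="auto"
).eval()

client = TestClient(glm4_server.app)
response = client.post(
    "/v1/chat/completions",
    json={"messages": [{"role": "user", "content": "Who are you?"}]},
)
print(response.status_code)
print(response.json()["choices"][0]["message"]["content"])
```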

0 commit comments
