
Commit 016bdff

compat with old vllm version

Author: xusenlin
Parent: 87097bf

3 files changed: +21 -13 lines


api/models.py (+7 -1)
@@ -92,12 +92,15 @@ def create_hf_llm():
 def create_vllm_engine():
     """ get vllm generate engine for chat or completion. """
     try:
+        import vllm
         from vllm.engine.arg_utils import AsyncEngineArgs
         from vllm.engine.async_llm_engine import AsyncLLMEngine
         from api.core.vllm_engine import VllmEngine, LoRA
     except ImportError:
         raise ValueError("VLLM engine not available")
 
+    vllm_version = vllm.__version__
+
     include = {
         "tokenizer_mode",
         "trust_remote_code",
@@ -106,11 +109,14 @@ def create_vllm_engine():
         "gpu_memory_utilization",
         "max_num_seqs",
         "enforce_eager",
-        "max_seq_len_to_capture",
         "max_loras",
         "max_lora_rank",
         "lora_extra_vocab_size",
     }
+
+    if vllm_version >= "0.4.3":
+        include.add("max_seq_len_to_capture")
+
     kwargs = dictify(SETTINGS, include=include)
     engine_args = AsyncEngineArgs(
         model=SETTINGS.model_path,
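
Note on the gate above: vllm_version >= "0.4.3" compares strings lexicographically, so a release such as "0.10.0" sorts below "0.4.3" and would take the legacy path. A minimal sketch of a numeric comparison instead, assuming the packaging dependency is acceptable (it is not used by this commit):

import vllm
from packaging.version import parse as parse_version

# String comparison misorders multi-digit components:
print("0.10.0" >= "0.4.3")  # False, although 0.10.0 is the newer release

# parse() returns Version objects that compare numerically:
VLLM_GE_0_4_3 = parse_version(vllm.__version__) >= parse_version("0.4.3")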

api/vllm_routes/chat.py (+7 -6)
@@ -5,6 +5,7 @@
 from typing import AsyncIterator
 
 import anyio
+import vllm
 from fastapi import APIRouter, Depends, status
 from fastapi import HTTPException, Request
 from loguru import logger
@@ -38,6 +39,7 @@
 )
 
 chat_router = APIRouter(prefix="/chat")
+vllm_version = vllm.__version__
 
 
 def get_engine():
@@ -105,17 +107,16 @@ async def create_chat_completion(
     try:
         from vllm.model_executor.guided_decoding import get_guided_decoding_logits_processor
 
-        decoding_config = await engine.model.get_decoding_config()
-
-        try:
+        if vllm_version >= "0.4.3":
+            decoding_config = await engine.model.get_decoding_config()
             guided_decode_logits_processor = (
                 await get_guided_decoding_logits_processor(
                     request.guided_decoding_backend or decoding_config.guided_decoding_backend,
                     request,
                     engine.tokenizer,
                 )
             )
-        except TypeError:
+        else:
             guided_decode_logits_processor = (
                 await get_guided_decoding_logits_processor(
                     request,
@@ -128,7 +129,7 @@ async def create_chat_completion(
     except ImportError:
         pass
 
-    try:
+    if vllm_version >= "0.4.3":
         result_generator = engine.model.generate(
             {
                 "prompt": prompt if isinstance(prompt, str) else None,
@@ -138,7 +139,7 @@ async def create_chat_completion(
             request_id,
             lora_request,
         )
-    except TypeError:
+    else:
         result_generator = engine.model.generate(
             prompt if isinstance(prompt, str) else None,
             sampling_params,
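
chat.py and completion.py now duplicate the same generate() branch. A minimal sketch of hoisting it into one helper; generate_compat is a hypothetical name, engine_model stands for the engine.model object the routes already hold, and the pre-0.4.3 argument order is an assumption based on older vLLM releases, not something this diff shows:

import vllm

VLLM_GE_0_4_3 = vllm.__version__ >= "0.4.3"  # the same check the commit uses


def generate_compat(engine_model, prompt, prompt_token_ids,
                    sampling_params, request_id, lora_request):
    """Dispatch to whichever generate() signature the installed vLLM expects."""
    if VLLM_GE_0_4_3:
        # vLLM >= 0.4.3 accepts a single inputs dict.
        return engine_model.generate(
            {"prompt": prompt, "prompt_token_ids": prompt_token_ids},
            sampling_params,
            request_id,
            lora_request,
        )
    # Assumed pre-0.4.3 shape: prompt and token ids passed separately.
    return engine_model.generate(
        prompt,
        sampling_params,
        request_id,
        prompt_token_ids,
        lora_request,
    )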

api/vllm_routes/completion.py (+7 -6)
@@ -6,6 +6,7 @@
 from typing import AsyncIterator, Tuple
 
 import anyio
+import vllm
 from fastapi import APIRouter, Depends
 from fastapi import Request
 from loguru import logger
@@ -27,6 +28,7 @@
 )
 
 completion_router = APIRouter()
+vllm_version = vllm.__version__
 
 
 def get_engine():
@@ -144,17 +146,16 @@ async def create_completion(
     try:
         from vllm.model_executor.guided_decoding import get_guided_decoding_logits_processor
 
-        decoding_config = await engine.model.get_decoding_config()
-
-        try:
+        if vllm_version >= "0.4.3":
+            decoding_config = await engine.model.get_decoding_config()
             guided_decode_logits_processor = (
                 await get_guided_decoding_logits_processor(
                     request.guided_decoding_backend or decoding_config.guided_decoding_backend,
                     request,
                     engine.tokenizer,
                 )
             )
-        except TypeError:
+        else:
             guided_decode_logits_processor = (
                 await get_guided_decoding_logits_processor(
                     request,
@@ -176,7 +177,7 @@ async def create_completion(
     else:
         input_ids = engine.convert_to_inputs(prompt=prompt, max_tokens=request.max_tokens)
 
-    try:
+    if vllm_version >= "0.4.3":
         generator = engine.model.generate(
             {
                 "prompt": prompt,
@@ -186,7 +187,7 @@ async def create_completion(
             request_id,
             lora_request,
         )
-    except TypeError:
+    else:
         generator = engine.model.generate(
             prompt,
             sampling_params,
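
Design note: both routes replace the earlier try/except TypeError probing with an explicit version gate. The if/else makes the two supported call shapes visible at the call site and no longer swallows unrelated TypeErrors raised inside generate(); the trade-off is that the hard-coded "0.4.3" threshold has to track vLLM's API as it evolves. A helper like the hypothetical generate_compat sketched after the chat.py diff would keep that gate in one place for both files.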
