Commit b4b195b: fix max seq len (#489)

1 parent: 20b0d88

4 files changed: +8 -8 lines

vllm/config.py (+2 -2)

@@ -204,10 +204,10 @@ class SchedulerConfig:
     """
 
     def __init__(self, max_num_batched_tokens: int, max_num_seqs: int,
-                 max_seq_len: int) -> None:
+                 max_model_len: int) -> None:
         self.max_num_batched_tokens = max_num_batched_tokens
         self.max_num_seqs = max_num_seqs
-        self.max_seq_len = max_seq_len
+        self.max_model_len = max_model_len
 
 
 _STR_DTYPE_TO_TORCH_DTYPE = {
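With the rename, SchedulerConfig takes the model's context length (max_model_len) directly instead of a pre-clamped max_seq_len. A minimal construction sketch, with hypothetical values:

    scheduler_config = SchedulerConfig(max_num_batched_tokens=2560,
                                       max_num_seqs=256,
                                       max_model_len=2048)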

vllm/core/scheduler.py (+3 -1)

@@ -190,7 +190,9 @@ def _schedule(
                 break
 
             num_prompt_tokens = seq_group.get_seqs()[0].get_len()
-            if num_prompt_tokens > self.scheduler_config.max_seq_len:
+            if num_prompt_tokens > min(
+                    self.scheduler_config.max_model_len,
+                    self.scheduler_config.max_num_batched_tokens):
                 logger.warning(
                     f"Input prompt ({num_prompt_tokens} tokens) is too long"
                     " and exceeds limit of "

vllm/engine/arg_utils.py (+2 -3)

@@ -155,11 +155,10 @@ def create_engine_configs(
         parallel_config = ParallelConfig(self.pipeline_parallel_size,
                                          self.tensor_parallel_size,
                                          self.worker_use_ray)
-        model_max_len = getattr(model_config.hf_config,
+        max_model_len = getattr(model_config.hf_config,
                                 'max_position_embeddings', float('inf'))
-        max_seq_len = min(self.max_num_batched_tokens, model_max_len)
         scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
-                                           self.max_num_seqs, max_seq_len)
+                                           self.max_num_seqs, max_model_len)
         return model_config, cache_config, parallel_config, scheduler_config
 
 
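Here max_model_len is read straight from the Hugging Face model config; when max_position_embeddings is absent, the getattr default of float('inf') leaves the length effectively uncapped by the model. A small sketch of that fallback, using a stand-in config object:

    from types import SimpleNamespace

    hf_config = SimpleNamespace(max_position_embeddings=2048)
    print(getattr(hf_config, 'max_position_embeddings', float('inf')))  # 2048

    hf_config = SimpleNamespace()  # hypothetical config lacking the attribute
    print(getattr(hf_config, 'max_position_embeddings', float('inf')))  # inf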

vllm/engine/llm_engine.py (+1 -2)

@@ -300,8 +300,7 @@ def _stop_sequences(self, seq_groups: List[SequenceGroup]) -> None:
                 continue
 
                 # Check if the sequence has reached max_seq_len.
-                if (seq.get_len() >
-                        self.scheduler.scheduler_config.max_seq_len):
+                if seq.get_len() > self.scheduler_config.max_model_len:
                     self.scheduler.free_seq(
                         seq, SequenceStatus.FINISHED_LENGTH_CAPPED)
                     continue
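Since a sequence is finished as FINISHED_LENGTH_CAPPED once its total length exceeds max_model_len, the prompt and the generated tokens share one budget. A worked example with hypothetical numbers:

    max_model_len = 2048
    prompt_len = 2000
    # At most 48 new tokens keep the sequence within the 2048-token window.
    max_new_tokens = max_model_len - prompt_len
    print(max_new_tokens)  # 48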
