diff --git a/tpu_commons/runner/tpu_jax_runner.py b/tpu_commons/runner/tpu_jax_runner.py index fb182d15b..585c4f56c 100644 --- a/tpu_commons/runner/tpu_jax_runner.py +++ b/tpu_commons/runner/tpu_jax_runner.py @@ -206,7 +206,10 @@ def _init_inputs(self) -> None: self.sliding_window = model_config.get_sliding_window() self.block_size = cache_config.block_size - self.max_model_len = model_config.max_model_len + # Note: https://github.com/vllm-project/vllm/pull/26168 allows + # prefill_length to be max-model-len, and decode for 1 token. + # so the total sequence length becomes max_model_len + 1. + self.max_model_len = model_config.max_model_len + 1 self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size) # InputBatch needs to work with sampling tensors greater than padding # to avoid dynamic shapes. Also, avoid suboptimal alignment.