Commit 8c7e7bb

jrplatin and Jacob Platin authored
[Misc] Fix various vLLM import issues (#900)
Signed-off-by: Jacob Platin <[email protected]>
Co-authored-by: Jacob Platin <[email protected]>
1 parent 8d856cf · commit 8c7e7bb

3 files changed: +6 −6 lines changed

tpu_inference/models/common/model_loader.py

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@
 from torchax.ops.mappings import j2t_dtype
 from transformers import PretrainedConfig
 from vllm.config import VllmConfig
-from vllm.utils import supports_kw
+from vllm.utils.functools import supports_kw
 
 from tpu_inference.logger import init_logger
 from tpu_inference.models.jax.utils.quantization.quantization_utils import (
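Both import paths for supports_kw appear in this hunk, so code that has to load against vLLM releases from either side of the utils reorganization could guard the import. This is only a minimal sketch, not part of the commit; it assumes each path exists in its respective vLLM release.

    # Illustrative compatibility shim (not part of this commit): take the helper
    # from whichever location the installed vLLM provides.
    try:
        from vllm.utils.functools import supports_kw  # newer vLLM layout
    except ImportError:
        from vllm.utils import supports_kw  # older vLLM layout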

tpu_inference/platforms/tpu_jax.py

Lines changed: 4 additions & 4 deletions
@@ -122,18 +122,18 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
                 "VLLM_ENABLE_V1_MULTIPROCESSING must be 0 when using Pathways(JAX_PLATFORMS=proxy)"
             )
 
-        from vllm.config import CompilationLevel
+        from vllm.config import CompilationMode
 
         cache_config = vllm_config.cache_config
         # For v0, the default block size is 16.
         if cache_config and cache_config.block_size is None:
            cache_config.block_size = cast(BlockSize, 16)
         compilation_config = vllm_config.compilation_config
 
-        # TPU only supports DYNAMO_ONCE compilation level
+        # TPU only supports DYNAMO_TRACE_ONCE compilation level
         # NOTE(xiang): the compilation_config is not used by jax.
-        if compilation_config.level != CompilationLevel.DYNAMO_ONCE:
-            compilation_config.level = CompilationLevel.DYNAMO_ONCE
+        if compilation_config.level != CompilationMode.DYNAMO_TRACE_ONCE:
+            compilation_config.level = CompilationMode.DYNAMO_TRACE_ONCE
 
         if compilation_config.backend == "":
             compilation_config.backend = "openxla"
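Because both enum spellings are visible in this hunk, a platform shim that must run under either vLLM release could resolve the enum once and reuse it. The sketch below is illustrative only; the names _CompilationEnum, _DYNAMO_ONCE, and pin_tpu_compilation_level are invented for this example, while the imported names and members come straight from the two sides of the diff.

    # Illustrative only: pick whichever enum the installed vLLM exposes.
    try:
        from vllm.config import CompilationMode as _CompilationEnum   # newer vLLM
        _DYNAMO_ONCE = _CompilationEnum.DYNAMO_TRACE_ONCE
    except ImportError:
        from vllm.config import CompilationLevel as _CompilationEnum  # older vLLM
        _DYNAMO_ONCE = _CompilationEnum.DYNAMO_ONCE

    def pin_tpu_compilation_level(compilation_config) -> None:
        # TPU only supports the DYNAMO_ONCE / DYNAMO_TRACE_ONCE setting, so force it.
        if compilation_config.level != _DYNAMO_ONCE:
            compilation_config.level = _DYNAMO_ONCE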

tpu_inference/runner/input_batch_jax.py

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@
 import numpy as np
 from vllm.lora.request import LoRARequest
 from vllm.sampling_params import SamplingType
-from vllm.utils import swap_dict_values
+from vllm.utils.collections import swap_dict_values
 from vllm.v1.core.sched.output import NewRequestData
 from vllm.v1.spec_decode.utils import is_spec_decode_unsupported
 
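For readers unfamiliar with the relocated helper, a rough local stand-in is sketched below. It is not vLLM's implementation; the name _swap_dict_values and the exact semantics (exchange the entries of two keys, dropping a key whose counterpart is absent) are assumptions made purely for illustration.

    # Hypothetical stand-in for the relocated helper; the real one is imported
    # from vllm.utils.collections after this change.
    def _swap_dict_values(obj: dict, key1, key2) -> None:
        # Assumed behaviour: swap the values stored under key1 and key2,
        # removing a key entirely when the other key is missing.
        v1, v2 = obj.pop(key1, None), obj.pop(key2, None)
        if v2 is not None:
            obj[key1] = v2
        if v1 is not None:
            obj[key2] = v1

    # Example: two request ids trading per-slot state in a batch-keyed dict.
    state = {"req_a": 0, "req_b": 1}
    _swap_dict_values(state, "req_a", "req_b")
    assert state == {"req_a": 1, "req_b": 0}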

0 commit comments