Commit 163cd94

[CI] Fixes to catchup with vllm changes (#912)
Signed-off-by: Hongmin Fan <[email protected]>
1 parent f7dad5d commit 163cd94

File tree

9 files changed: +48 -38 lines changed

tests/layers/vllm/test_unquantized.py

Lines changed: 2 additions & 0 deletions

@@ -1,3 +1,4 @@
+import os
 import tempfile
 
 import jax
@@ -415,6 +416,7 @@ def test_merged_column_parallel_linear(model, bias, mesh, fuse_matmuls,
 @pytest.mark.parametrize("topk", [2])
 def test_fused_moe(use_ep, mesh, num_tokens, intermediate_size, hidden_size,
                    num_experts, topk):
+    os.environ['VLLM_DISABLE_SHARED_EXPERTS_STREAM'] = '1'
     torch.manual_seed(42)
     dtype = torch.bfloat16
 
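Side note, hedged: setting os.environ directly inside the test leaks the variable into every test that runs afterwards in the same process. A minimal alternative sketch using pytest's built-in monkeypatch fixture (not what this commit does) would scope the change to this one test:

def test_fused_moe(monkeypatch, use_ep, mesh, num_tokens, intermediate_size,
                   hidden_size, num_experts, topk):
    # monkeypatch.setenv undoes the environment change when the test ends.
    monkeypatch.setenv('VLLM_DISABLE_SHARED_EXPERTS_STREAM', '1')
    torch.manual_seed(42)
    ...
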
tpu_inference/core/disagg_executor.py

Lines changed: 3 additions & 2 deletions

@@ -6,8 +6,9 @@
 from vllm.logger import init_logger
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.cache import worker_receiver_cache_from_config
-from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
-                        run_method)
+from vllm.utils import run_method
+from vllm.utils.network_utils import (get_distributed_init_method, get_ip,
+                                      get_open_port)
 from vllm.v1.executor.abstract import Executor
 from vllm.v1.outputs import AsyncModelRunnerOutput
 from vllm.v1.worker.worker_base import WorkerWrapperBase
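Several files in this commit are pure import relocations: helpers that used to live in vllm.utils now come from split submodules such as vllm.utils.network_utils, vllm.utils.func_utils, and vllm.utils.collection_utils. If a single tree ever needed to tolerate both vLLM layouts, a hedged sketch of the usual fallback pattern would look like this (the commit itself simply adopts the new paths):

try:
    # Newer vLLM: networking helpers live in a dedicated submodule.
    from vllm.utils.network_utils import (get_distributed_init_method, get_ip,
                                          get_open_port)
except ImportError:
    # Older vLLM: the same helpers were exported from vllm.utils directly.
    from vllm.utils import get_distributed_init_method, get_ip, get_open_port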

tpu_inference/distributed/utils.py

Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 import os
 
-from vllm.utils import get_ip
+from vllm.utils.network_utils import get_ip
 
 from tpu_inference.logger import init_logger
 
tpu_inference/executors/ray_distributed_executor.py

Lines changed: 23 additions & 6 deletions

@@ -1,18 +1,22 @@
 import os
-from typing import Dict, List, Optional
+from array import array
+from typing import Any, Dict, List, Optional
 
 import ray
 import vllm.envs as envs
 from ray.util.placement_group import PlacementGroup
 from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
 from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
-from vllm.executor.ray_distributed_executor import RayWorkerMetaData
-from vllm.executor.ray_utils import RayWorkerWrapper, _wait_until_pg_ready
+from vllm.multimodal.inputs import MultiModalKwargs
 from vllm.platforms import current_platform
 from vllm.ray.ray_env import get_env_vars_to_copy
-from vllm.utils import get_distributed_init_method, get_ip, get_open_port
+from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE
+from vllm.utils.network_utils import (get_distributed_init_method, get_ip,
+                                      get_open_port)
 from vllm.v1.executor.ray_distributed_executor import \
     RayDistributedExecutor as RayDistributedExecutorV1
+from vllm.v1.executor.ray_executor import RayWorkerMetaData
+from vllm.v1.executor.ray_utils import RayWorkerWrapper, _wait_until_pg_ready
 
 from tpu_inference.logger import init_logger
 
@@ -27,14 +31,27 @@
 from collections import defaultdict
 
 import msgspec
-from vllm.executor.msgspec_utils import encode_hook
 from vllm.v1.outputs import SamplerOutput
 
 from tpu_inference.distributed.utils import set_node_kv_ip_port
 
 logger = init_logger(__name__)
 
 
+def _encode_hook(obj: Any) -> Any:
+    """Custom msgspec enc hook that supports array types and MultiModalKwargs.
+
+    See https://jcristharif.com/msgspec/api.html#msgspec.msgpack.Encoder
+    """
+    if isinstance(obj, array):
+        assert obj.typecode == VLLM_TOKEN_ID_ARRAY_TYPE, (
+            f"vLLM array type should use '{VLLM_TOKEN_ID_ARRAY_TYPE}' type. "
+            f"Given array has a type code of {obj.typecode}.")
+        return obj.tobytes()
+    if isinstance(obj, MultiModalKwargs):
+        return dict(obj)
+
+
 class RayDistributedExecutor(RayDistributedExecutorV1):
     """Ray-based distributed executor for TPU.
 
@@ -82,7 +99,7 @@ def _init_executor(self) -> None:
         # Create the parallel GPU workers.
         self._init_workers_ray(placement_group)
 
-        self.input_encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook)
+        self.input_encoder = msgspec.msgpack.Encoder(enc_hook=_encode_hook)
         self.output_decoder = msgspec.msgpack.Decoder(
             Optional[List[SamplerOutput]])
         self.use_v1 = envs.VLLM_USE_V1
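The vendored _encode_hook above stands in for the encode_hook that was previously imported from vllm.executor.msgspec_utils. A minimal, self-contained sketch of the same technique, assuming only msgspec and the standard library ('l' is used here as a stand-in for VLLM_TOKEN_ID_ARRAY_TYPE):

from array import array
from typing import Any

import msgspec

TOKEN_ID_ARRAY_TYPE = "l"  # assumed stand-in for VLLM_TOKEN_ID_ARRAY_TYPE


def encode_hook(obj: Any) -> Any:
    # msgspec calls this only for types it cannot serialize natively.
    if isinstance(obj, array):
        assert obj.typecode == TOKEN_ID_ARRAY_TYPE
        return obj.tobytes()  # ship the raw buffer
    raise NotImplementedError(f"Unsupported type: {type(obj)}")


encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook)
decoder = msgspec.msgpack.Decoder()

payload = {"prompt_token_ids": array(TOKEN_ID_ARRAY_TYPE, [1, 2, 3])}
decoded = decoder.decode(encoder.encode(payload))
# The array arrives as raw bytes; the receiving side rebuilds it explicitly.
assert list(array(TOKEN_ID_ARRAY_TYPE, decoded["prompt_token_ids"])) == [1, 2, 3]

MultiModalKwargs is handled analogously in the real hook: it is converted to a plain dict that msgpack can encode natively.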

tpu_inference/layers/vllm/quantization/awq.py

Lines changed: 4 additions & 5 deletions

@@ -11,12 +11,11 @@
 from vllm.model_executor.layers.quantization import \
     register_quantization_config
 from vllm.model_executor.layers.quantization.awq import (AWQConfig,
-                                                         AWQLinearMethod,
-                                                         is_layer_skipped_awq)
+                                                         AWQLinearMethod)
 from vllm.model_executor.layers.quantization.base_config import \
     QuantizeMethodBase
-from vllm.model_executor.layers.quantization.utils.quant_utils import \
-    unpack_quantized_values_into_int32
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    is_layer_skipped, unpack_quantized_values_into_int32)
 from vllm.scalar_type import scalar_types
 
 from tpu_inference.layers.vllm.linear_common import (
@@ -48,7 +47,7 @@ def get_quant_method(
     ) -> Optional[Union["LinearMethodBase", "QuantizeMethodBase"]]:
         if isinstance(layer, LinearBase):
             linear_config = self.get_linear_config(layer)
-            if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
+            if is_layer_skipped(prefix, self.modules_to_not_convert):
                 return VllmUnquantizedLinearMethod(linear_config)
             return VllmAWQLinearMethod(self, linear_config)
         elif isinstance(layer, FusedMoE):

tpu_inference/models/common/model_loader.py

Lines changed: 1 addition & 1 deletion

@@ -9,7 +9,7 @@
 from torchax.ops.mappings import j2t_dtype
 from transformers import PretrainedConfig
 from vllm.config import VllmConfig
-from vllm.utils import supports_kw
+from vllm.utils.func_utils import supports_kw
 
 from tpu_inference.logger import init_logger
 from tpu_inference.models.jax.utils.quantization.quantization_utils import (

tpu_inference/models/vllm/vllm_model_wrapper.py

Lines changed: 9 additions & 18 deletions

@@ -40,24 +40,15 @@ def __init__(self, vllm_model: torch.nn.Module):
         self.vllm_model = vllm_model
 
     def forward(self, **kwargs) -> torch.Tensor:
-        # We don't support multimodal input in Gemma3, but we need patch it to
-        # None to workaround vLLM Gemma3 model bug that
-        # `get_multimodal_embeddings` returns empty list but it's caller checks
-        # for None.
-        with patch(
-                "vllm.model_executor.models.gemma3_mm."
-                "Gemma3ForConditionalGeneration."
-                "get_multimodal_embeddings",
-                return_value=None):
-            if "hidden_state" in kwargs:
-                return self.compute_logits(kwargs["hidden_state"])
-            else:
-                return self.compute_hidden_state(
-                    kwargs["input_ids"],
-                    kwargs["positions"],
-                    kwargs["intermediate_tensors"],
-                    kwargs["inputs_embeds"],
-                )
+        if "hidden_state" in kwargs:
+            return self.compute_logits(kwargs["hidden_state"])
+        else:
+            return self.compute_hidden_state(
+                kwargs["input_ids"],
+                kwargs["positions"],
+                kwargs["intermediate_tensors"],
+                kwargs["inputs_embeds"],
+            )
 
     def compute_hidden_state(
         self,

tpu_inference/platforms/tpu_jax.py

Lines changed: 4 additions & 4 deletions

@@ -122,18 +122,18 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
                 "VLLM_ENABLE_V1_MULTIPROCESSING must be 0 when using Pathways(JAX_PLATFORMS=proxy)"
             )
 
-        from vllm.config import CompilationLevel
+        from vllm.config import CompilationMode
 
         cache_config = vllm_config.cache_config
         # For v0, the default block size is 16.
         if cache_config and cache_config.block_size is None:
             cache_config.block_size = cast(BlockSize, 16)
         compilation_config = vllm_config.compilation_config
 
-        # TPU only supports DYNAMO_ONCE compilation level
+        # TPU only supports DYNAMO_TRACE_ONCE compilation level
         # NOTE(xiang): the compilation_config is not used by jax.
-        if compilation_config.level != CompilationLevel.DYNAMO_ONCE:
-            compilation_config.level = CompilationLevel.DYNAMO_ONCE
+        if compilation_config.mode != CompilationMode.DYNAMO_TRACE_ONCE:
+            compilation_config.mode = CompilationMode.DYNAMO_TRACE_ONCE
 
         if compilation_config.backend == "":
             compilation_config.backend = "openxla"
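The change here is mechanical: vLLM renamed CompilationLevel to CompilationMode, DYNAMO_ONCE to DYNAMO_TRACE_ONCE, and the config field from level to mode. If both vLLM generations had to be supported at once, a hedged sketch of a version-tolerant helper might look like this (the commit itself targets only the new API):

def _force_dynamo_trace_once(compilation_config) -> None:
    """Pin the TPU-supported compilation setting under either vLLM API."""
    try:
        from vllm.config import CompilationMode  # newer vLLM
        compilation_config.mode = CompilationMode.DYNAMO_TRACE_ONCE
    except ImportError:
        from vllm.config import CompilationLevel  # older vLLM
        compilation_config.level = CompilationLevel.DYNAMO_ONCE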

tpu_inference/runner/input_batch_jax.py

Lines changed: 1 addition & 1 deletion

@@ -9,7 +9,7 @@
 import numpy as np
 from vllm.lora.request import LoRARequest
 from vllm.sampling_params import SamplingType
-from vllm.utils import swap_dict_values
+from vllm.utils.collection_utils import swap_dict_values
 from vllm.v1.core.sched.output import NewRequestData
 from vllm.v1.spec_decode.utils import is_spec_decode_unsupported
 