@@ -1,18 +1,22 @@
1 | 1 | import os |
2 | | -from typing import Dict, List, Optional |
| 2 | +from array import array |
| 3 | +from typing import Any, Dict, List, Optional |
3 | 4 |
4 | 5 | import ray |
5 | 6 | import vllm.envs as envs |
6 | 7 | from ray.util.placement_group import PlacementGroup |
7 | 8 | from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy |
8 | 9 | from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator |
9 | | -from vllm.executor.ray_distributed_executor import RayWorkerMetaData |
10 | | -from vllm.executor.ray_utils import RayWorkerWrapper, _wait_until_pg_ready |
| 10 | +from vllm.multimodal.inputs import MultiModalKwargs |
11 | 11 | from vllm.platforms import current_platform |
12 | 12 | from vllm.ray.ray_env import get_env_vars_to_copy |
13 | | -from vllm.utils import get_distributed_init_method, get_ip, get_open_port |
| 13 | +from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE |
| 14 | +from vllm.utils.network_utils import (get_distributed_init_method, get_ip, |
| 15 | + get_open_port) |
14 | 16 | from vllm.v1.executor.ray_distributed_executor import \ |
15 | 17 | RayDistributedExecutor as RayDistributedExecutorV1 |
| 18 | +from vllm.v1.executor.ray_executor import RayWorkerMetaData |
| 19 | +from vllm.v1.executor.ray_utils import RayWorkerWrapper, _wait_until_pg_ready |
16 | 20 |
17 | 21 | from tpu_inference.logger import init_logger |
18 | 22 |
@@ -27,14 +31,27 @@
27 | 31 | from collections import defaultdict |
28 | 32 |
29 | 33 | import msgspec |
30 | | -from vllm.executor.msgspec_utils import encode_hook |
31 | 34 | from vllm.v1.outputs import SamplerOutput |
32 | 35 |
33 | 36 | from tpu_inference.distributed.utils import set_node_kv_ip_port |
34 | 37 |
35 | 38 | logger = init_logger(__name__) |
36 | 39 |
37 | 40 |
| 41 | +def _encode_hook(obj: Any) -> Any: |
| 42 | + """Custom msgspec enc hook that supports array types and MultiModalKwargs. |
| 43 | +
| 44 | + See https://jcristharif.com/msgspec/api.html#msgspec.msgpack.Encoder |
| 45 | + """ |
| 46 | + if isinstance(obj, array): |
| 47 | + assert obj.typecode == VLLM_TOKEN_ID_ARRAY_TYPE, ( |
| 48 | + f"vLLM array type should use '{VLLM_TOKEN_ID_ARRAY_TYPE}' type. " |
| 49 | + f"Given array has a type code of {obj.typecode}.") |
| 50 | + return obj.tobytes() |
| 51 | + if isinstance(obj, MultiModalKwargs): |
| 52 | + return dict(obj) |
| 53 | + |
| 54 | + |
38 | 55 | class RayDistributedExecutor(RayDistributedExecutorV1): |
39 | 56 | """Ray-based distributed executor for TPU. |
40 | 57 |
@@ -82,7 +99,7 @@ def _init_executor(self) -> None: |
82 | 99 | # Create the parallel GPU workers. |
83 | 100 | self._init_workers_ray(placement_group) |
84 | 101 |
85 | | - self.input_encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook) |
| 102 | + self.input_encoder = msgspec.msgpack.Encoder(enc_hook=_encode_hook) |
86 | 103 | self.output_decoder = msgspec.msgpack.Decoder( |
87 | 104 | Optional[List[SamplerOutput]]) |
88 | 105 | self.use_v1 = envs.VLLM_USE_V1 |
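For reference, a minimal round-trip sketch of the path this commit wires up (illustration only, not part of the diff): it assumes _encode_hook is in scope as defined above, reuses the VLLM_TOKEN_ID_ARRAY_TYPE constant the diff imports, and hand-rolls the decode side, which this commit does not show. FakeOutput is a toy stand-in for vLLM's SamplerOutput, whose real fields are elided here. The helper is presumably inlined because the removed vllm.executor.msgspec_utils import is no longer available at this path.

    from array import array
    from typing import List, Optional

    import msgspec
    from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE

    # Encode side: msgpack has no native array.array type, so the
    # enc_hook serializes token-id arrays to raw bytes via tobytes().
    token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3])
    encoder = msgspec.msgpack.Encoder(enc_hook=_encode_hook)
    payload = encoder.encode({"prompt_token_ids": token_ids})

    # Decode side (hand-rolled for this sketch): the hook's output
    # comes back as bytes, and the receiver rebuilds the array with
    # frombytes().
    decoded = msgspec.msgpack.decode(payload)
    restored = array(VLLM_TOKEN_ID_ARRAY_TYPE)
    restored.frombytes(decoded["prompt_token_ids"])
    assert restored.tolist() == [1, 2, 3]

    # Typed decoding, mirroring self.output_decoder above; FakeOutput
    # is a hypothetical stand-in for SamplerOutput.
    class FakeOutput(msgspec.Struct):
        token_ids: List[int]

    out_decoder = msgspec.msgpack.Decoder(Optional[List[FakeOutput]])
    buf = msgspec.msgpack.encode([FakeOutput(token_ids=[7, 8])])
    outs = out_decoder.decode(buf)
    assert outs is not None and outs[0].token_ids == [7, 8]
    assert out_decoder.decode(msgspec.msgpack.encode(None)) is None

The typed Decoder is why the executor can hand workers' serialized outputs straight back as Optional[List[SamplerOutput]] without a manual dec_hook for that path; only the array-valued inputs need the custom encode hook.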