Skip to content

Commit cb170bc

Browse files
authored
Fix the jax device ordering. (#915)
Signed-off-by: Lance Wang <[email protected]>
1 parent 5b0c3d2 commit cb170bc

File tree

2 files changed

+34
-5
lines changed

2 files changed

+34
-5
lines changed

tpu_inference/runner/tpu_jax_runner.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,15 @@ def _init_mesh(self) -> None:
154154
except KeyError:
155155
sharding_strategy = {"tensor_parallelism": len(self.devices)}
156156

157+
try:
158+
enforce_device_order = self.vllm_config.additional_config[
159+
"sharding"]["sharding_strategy"]["device_indexes"] is not None
160+
161+
except KeyError:
162+
enforce_device_order = False
163+
164+
logger.info(f"Device sequence enforced: {enforce_device_order}")
165+
157166
if os.getenv("NEW_MODEL_DESIGN", False):
158167
self.mesh = build_mesh(self.devices, sharding_strategy)
159168
else:
@@ -169,9 +178,14 @@ def _init_mesh(self) -> None:
169178
axis_names = ("data", "model")
170179
mesh_shape = (dp, tp)
171180

172-
self.mesh = make_optimized_mesh(mesh_shape,
173-
axis_names,
174-
devices=self.devices)
181+
if enforce_device_order:
182+
self.mesh = jax.make_mesh(mesh_shape,
183+
axis_names,
184+
devices=self.devices)
185+
else:
186+
self.mesh = make_optimized_mesh(mesh_shape,
187+
axis_names,
188+
devices=self.devices)
175189
logger.info(f"Init mesh | mesh={self.mesh}")
176190

177191
def _init_phased_profiling(self) -> None:

tpu_inference/worker/tpu_worker_jax.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,14 +121,29 @@ def initialize_cache(self, num_gpu_blocks: int,
121121

122122
def init_device(self):
123123
if not self.devices:
124+
device_indexes = []
125+
tp = self.parallel_config.tensor_parallel_size
124126
try:
125127
device_indexes = self.vllm_config.additional_config[
126128
"sharding"]["sharding_strategy"]["device_indexes"]
127-
self.devices = [jax.devices()[i] for i in device_indexes]
128129
except KeyError:
129-
tp = self.parallel_config.tensor_parallel_size
130130
self.devices = jax.devices()[:tp]
131131

132+
        # Enforce the device sequence to be consistent with the specified device indexes.
        # NOTE(review): below, `device_dict[device_index]` raises KeyError on a missing id,
        # so the `if device is None` branch looks unreachable — `device_dict.get(...)` was
        # presumably intended; verify.
133+
if not self.devices:
134+
all_devices = jax.devices()
135+
device_dict = {device.id: device for device in all_devices}
136+
self.devices = []
137+
for device_index in device_indexes:
138+
device = device_dict[device_index]
139+
if device is None:
140+
raise KeyError(
141+
f"Device index {device_index} not found in "
142+
f"jax.devices() with IDs {list(device_dict.keys())}!"
143+
)
144+
self.devices.append(device)
145+
self.devices = self.devices[:tp]
146+
132147
# Initialize the vLLM distribution layer as a single chip environment,
133148
# we'll swap the model's parallel modules with TPU SPMD equivalents.
134149
with set_current_vllm_config(self.vllm_config):

0 commit comments

Comments
 (0)