Skip to content

Commit a6666ca

Browse files
sixiang-googlesierraisland
authored andcommitted
[Disagg] Use pathways resharding api to handle transfer (#935)
1 parent 6cdabdf commit a6666ca

File tree

1 file changed

+18
-3
lines changed

1 file changed

+18
-3
lines changed

tpu_inference/runner/kv_cache_manager.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import jax
55
import jax.numpy as jnp
6+
import vllm.envs as envs
67
from jax.sharding import NamedSharding, PartitionSpec
78
from torchax.ops.mappings import t2j_dtype
89
from vllm.attention import Attention
@@ -375,9 +376,23 @@ def transfer_kv_cache(self,
375376
)
376377
sharding = NamedSharding(self.runner.mesh,
377378
PartitionSpec(None, "model"))
378-
transferred_kv_cache = jax.device_put(kv_cache_slices, sharding)
379-
for cache in transferred_kv_cache:
380-
cache.block_until_ready()
379+
if envs.VLLM_TPU_USING_PATHWAYS:
380+
from pathwaysutils.experimental import \
381+
reshard as experimental_reshard
382+
383+
def get_sharding(x):
384+
return sharding
385+
386+
sharding_spec_pytree = jax.tree.map(get_sharding, kv_cache_slices)
387+
transferred_kv_cache = experimental_reshard.reshard(
388+
tuple(kv_cache_slices),
389+
tuple(sharding_spec_pytree),
390+
donate=False,
391+
)
392+
else:
393+
transferred_kv_cache = jax.device_put(kv_cache_slices, sharding)
394+
395+
jax.block_until_ready(transferred_kv_cache)
381396
return transferred_kv_cache
382397

383398
def insert_request_with_kv_cache(

0 commit comments

Comments
 (0)