File tree Expand file tree Collapse file tree 1 file changed +18
-3
lines changed Expand file tree Collapse file tree 1 file changed +18
-3
lines changed Original file line number Diff line number Diff line change 33
44import jax
55import jax .numpy as jnp
6+ import vllm .envs as envs
67from jax .sharding import NamedSharding , PartitionSpec
78from torchax .ops .mappings import t2j_dtype
89from vllm .attention import Attention
@@ -375,9 +376,23 @@ def transfer_kv_cache(self,
375376 )
376377 sharding = NamedSharding (self .runner .mesh ,
377378 PartitionSpec (None , "model" ))
378- transferred_kv_cache = jax .device_put (kv_cache_slices , sharding )
379- for cache in transferred_kv_cache :
380- cache .block_until_ready ()
379+ if envs .VLLM_TPU_USING_PATHWAYS :
380+ from pathwaysutils .experimental import \
381+ reshard as experimental_reshard
382+
383+ def get_sharding (x ):
384+ return sharding
385+
386+ sharding_spec_pytree = jax .tree .map (get_sharding , kv_cache_slices )
387+ transferred_kv_cache = experimental_reshard .reshard (
388+ tuple (kv_cache_slices ),
389+ tuple (sharding_spec_pytree ),
390+ donate = False ,
391+ )
392+ else :
393+ transferred_kv_cache = jax .device_put (kv_cache_slices , sharding )
394+
395+ jax .block_until_ready (transferred_kv_cache )
381396 return transferred_kv_cache
382397
383398 def insert_request_with_kv_cache (
You can’t perform that action at this time.
0 commit comments