From dcfb0b2e07a9c95107dba7cb8e82ce7a6ba8c264 Mon Sep 17 00:00:00 2001 From: Juncheng Gu Date: Wed, 22 Oct 2025 23:11:07 +0000 Subject: [PATCH 01/10] update gather and insert (also reshape) functions Signed-off-by: Juncheng Gu --- .../cpu_offloading_kv_roundtrip_test.py | 5 +- .../distributed/tpu_connector_local.py | 70 +++++++++---------- 2 files changed, 37 insertions(+), 38 deletions(-) diff --git a/tests/distributed/cpu_offloading_kv_roundtrip_test.py b/tests/distributed/cpu_offloading_kv_roundtrip_test.py index b9047a5d1..7cb1b76ab 100644 --- a/tests/distributed/cpu_offloading_kv_roundtrip_test.py +++ b/tests/distributed/cpu_offloading_kv_roundtrip_test.py @@ -243,5 +243,6 @@ def create_on_device(key): jax.block_until_ready(worker.runner.kv_caches) # 5. Verify TPU Reloaded Content - self.assertArraysEqual(source_kv_cache[0][target_block_ids, ...], - dest_kv_cache[0][target_block_ids, ...]) + self.assertArraysEqual( + source_kv_cache[0][target_block_ids, ...], + worker.runner.kv_caches[0][target_block_ids, ...]) diff --git a/tpu_inference/distributed/tpu_connector_local.py b/tpu_inference/distributed/tpu_connector_local.py index 0dc682e24..ce7733982 100644 --- a/tpu_inference/distributed/tpu_connector_local.py +++ b/tpu_inference/distributed/tpu_connector_local.py @@ -105,6 +105,7 @@ from vllm.forward_context import ForwardContext from tpu_inference.logger import init_logger +from tpu_inference.runner.kv_cache_manager import KVCacheManager from tpu_inference.runner.tpu_jax_runner import TPUModelRunner from .cache_util import CPU_OFFLOADING_SWAP_OP_TYPE, TokenProcessor, swap_ops @@ -829,16 +830,19 @@ def _save_blocks_to_cpu(self, req_id: ReqId, full_block_ids: list[int], try: start_time = time.time() + blocks_to_process = jnp.array(blocks_to_process) + # gather and reshape blocks on TPU first: output_shape: [process_blocks * block_size, num_heads, 2, head_dim] + extracted_blocks_tpu = KVCacheManager._jitted_gather_kv_cache( + self.runner.kv_caches, blocks_to_process) - # Extract blocks on TPU first - extracted_blocks_tpu = [ - layer_cache_tpu[blocks_to_process, ...] - for layer_cache_tpu in self.runner.kv_caches - ] + jax.block_until_ready(extracted_blocks_tpu) + logger.info( + f"extracted_blocks_tpu: {extracted_blocks_tpu[0].shape}, {extracted_blocks_tpu[0].sharding}" + ) # Initiate non-blocking copy to CPU kv_caches_on_cpu = [ - swap_ops(extracted_blocks, self.host_sharding, "d2h", + swap_ops(extracted_blocks, self.flatten_host_sharding, "d2h", self.swap_op_type) for extracted_blocks in extracted_blocks_tpu ] @@ -857,26 +861,17 @@ def _save_blocks_to_cpu(self, req_id: ReqId, full_block_ids: list[int], f"Shape of a single layer on CPU before reshape (num_blocks, block_size, ...): {kv_caches_on_cpu[0].shape}" ) - post_transfer_start_time = time.time() - # Reshape per-layer data from (num_blocks, block_size, ...) to - # a flat (total_tokens, ...) array for easy slicing. 
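The jitted gather introduced above replaces per-layer fancy indexing with a single compiled gather-and-flatten. As a rough sketch only — assuming the (num_blocks, block_size, num_heads, 2, head_dim) layout implied by the comment, and not KVCacheManager's actual _jitted_gather_kv_cache:

    import jax
    import jax.numpy as jnp

    @jax.jit
    def gather_kv_blocks(kv_caches, block_ids):
        # kv_caches: per-layer arrays of shape (num_blocks, block_size, num_heads, 2, head_dim)
        # returns per-layer arrays of shape (len(block_ids) * block_size, num_heads, 2, head_dim)
        def _gather(layer):
            picked = jnp.take(layer, block_ids, axis=0)   # (n_blocks, block_size, ...)
            return picked.reshape(-1, *picked.shape[2:])  # (n_blocks * block_size, ...)
        return jax.tree.map(_gather, kv_caches)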
- flat_kv_caches_on_cpu = [ - layer_cache.reshape(-1, *layer_cache.shape[2:]) - for layer_cache in kv_caches_on_cpu - ] - - jax.block_until_ready(flat_kv_caches_on_cpu) - - if flat_kv_caches_on_cpu: total_size_bytes = sum(layer.nbytes - for layer in flat_kv_caches_on_cpu) + for layer in kv_caches_on_cpu) logger.info( - f"Total size of flat_kv_caches_on_cpu: {total_size_bytes / 1024**2:.2f} MB" + f"Total size of kv_caches_on_cpu: {total_size_bytes / 1024**2:.2f} MB" ) logger.info( - f"Shape of a single layer after reshape (total_tokens, ...): {flat_kv_caches_on_cpu[0].shape}" + f"Shape of a single layer after reshape (total_tokens, ...): {kv_caches_on_cpu[0].shape}" ) + post_transfer_start_time = time.time() + # Generate keys for the entire token sequence to get absolute positions. This to ensure that the delta # tokens that is about to be captured in the cache are correctly mapped. These keys will be recreated # during get_finished() to unpin the correct keys. @@ -904,7 +899,7 @@ def _save_blocks_to_cpu(self, req_id: ReqId, full_block_ids: list[int], rel_start_idx, rel_end_idx, axis=0) - for flat_layer_cache in flat_kv_caches_on_cpu + for flat_layer_cache in kv_caches_on_cpu ] jax.block_until_ready(value_for_key) self.cpu_backend.add(key, value_for_key) @@ -1019,11 +1014,22 @@ def register_runner(self, runner: TPUModelRunner): spec=self.device_sharding.spec, memory_kind="pinned_host") + self.flatten_device_sharding = jax.sharding.NamedSharding( + mesh=self.device_sharding.mesh, + spec=jax.sharding.PartitionSpec(None, "model"), + memory_kind="device") + self.flatten_host_sharding = jax.sharding.NamedSharding( + mesh=self.device_sharding.mesh, + spec=jax.sharding.PartitionSpec(None, "model"), + memory_kind="pinned_host") + logger.info("KV Cache details registered in TPUConnectorWorker:") logger.info(f" - Num layers: {self.num_layers}") logger.info(f" - Shape per layer: {self.shape}") logger.info(f" - DType: {self.dtype}") logger.info(f" - Device sharding: {self.device_sharding}") + logger.info( + f" - Flatten Device sharding: {self.flatten_device_sharding}") logger.info(f" - Layout: {self.kv_cache_layout}") else: logger.warning("TPUConnectorWorker registered with no KV caches.") @@ -1155,23 +1161,15 @@ def start_load_kv(self, fwd_ctx: "ForwardContext") -> None: else: padded_kv_on_cpu = final_kv_on_cpu - # 4. Reshape data back to block format for the update operation. - block_shaped_kv_on_cpu = [ - layer_data.reshape(num_blocks_to_load, self.block_size, - *layer_data.shape[1:]) - for layer_data in padded_kv_on_cpu - ] - - jax.block_until_ready(block_shaped_kv_on_cpu) + jax.block_until_ready(padded_kv_on_cpu) logger.info( - f"Request {meta.req_id}: Reshaped data for transfer to TPU. Shape for one layer: {block_shaped_kv_on_cpu[0].shape}." + f"Request {meta.req_id}: Reshaped data for transfer to TPU. Shape for one layer: {padded_kv_on_cpu[0].shape}." ) # 5. Transfer to TPU, applying the correct sharding. loaded_kv_sharded_on_tpu = [ - swap_ops(layer_data, self.device_sharding, "h2d", - self.swap_op_type) - for layer_data in block_shaped_kv_on_cpu + swap_ops(layer_data, self.flatten_device_sharding, "h2d", + self.swap_op_type) for layer_data in padded_kv_on_cpu ] jax.block_until_ready(loaded_kv_sharded_on_tpu) logger.info( @@ -1180,9 +1178,9 @@ def start_load_kv(self, fwd_ctx: "ForwardContext") -> None: # 6. Update the runner's KV cache with the correctly sharded data. 
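That update is likewise collapsed into one jitted scatter just below. A minimal sketch of such an insert, assuming the same argument order as the call in the hunk — not the real KVCacheManager._jitted_insert_kv_cache:

    import functools
    import jax
    import jax.numpy as jnp

    @functools.partial(jax.jit, static_argnums=0)
    def insert_kv_blocks(block_size, kv_caches, flat_updates, block_ids):
        # flat_updates: per-layer arrays of shape (num_blocks * block_size, ...)
        def _insert(layer, flat_layer):
            blocked = flat_layer.reshape(block_ids.shape[0], block_size,
                                         *flat_layer.shape[1:])
            return layer.at[block_ids].set(blocked)
        return jax.tree.map(_insert, kv_caches, flat_updates)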
destination_blocks = meta.local_block_ids - for i in range(len(self.runner.kv_caches)): - self.runner.kv_caches[i] = self.runner.kv_caches[i].at[ - destination_blocks, ...].set(loaded_kv_sharded_on_tpu[i]) + self.runner.kv_caches = KVCacheManager._jitted_insert_kv_cache( + self.block_size, self.runner.kv_caches, + loaded_kv_sharded_on_tpu, jnp.array(destination_blocks)) jax.block_until_ready(self.runner.kv_caches) logger.info( f"Successfully loaded {len(destination_blocks)} blocks into TPU KV cache for request {meta.req_id}" From 1050fae453349b491e54f37ef48120cc7799b108 Mon Sep 17 00:00:00 2001 From: Juncheng Gu Date: Thu, 23 Oct 2025 08:16:51 +0000 Subject: [PATCH 02/10] revise swap_op Signed-off-by: Juncheng Gu --- tpu_inference/distributed/cache_util.py | 109 ++++++++++++++---- .../distributed/tpu_connector_local.py | 108 +++++++++-------- tpu_inference/kernels/dma/host_dma.py | 26 ++--- 3 files changed, 151 insertions(+), 92 deletions(-) diff --git a/tpu_inference/distributed/cache_util.py b/tpu_inference/distributed/cache_util.py index 5915264df..4c149d059 100644 --- a/tpu_inference/distributed/cache_util.py +++ b/tpu_inference/distributed/cache_util.py @@ -1,9 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the LMCache project +import functools import hashlib from dataclasses import dataclass -from typing import Iterable, List, Literal, Optional, Tuple +from typing import Any, Iterable, List, Literal, Optional, Tuple import jax from vllm.config import get_current_vllm_config @@ -103,31 +104,89 @@ def get_kv_connector_cache_layout(): return None -def swap_ops( - src_kv_cache: jax.Array, - out_sharding: Optional[jax.sharding.NamedSharding], +# NOTE(jcgu): keep the same interface as the pallas one +def jax_swap_kv_caches( + src_kv_caches: List[jax.Array], + src_sharding: jax.sharding.NamedSharding, + dst_sharding: jax.sharding.NamedSharding, direction: Literal["h2d", "d2h"], - op_type: CPU_OFFLOADING_SWAP_OP_TYPE, -) -> jax.Array: - if op_type == "jax": - return jax_swap_kv_cache(src_kv_cache, out_sharding, direction) - return dma_kv_cache(src_kv_cache, out_sharding, direction) +) -> List[jax.Array]: + """Swap in / out multi-layer kv_cache using jax device_put + + Args: + src_kv_caches: [kv_cache of each layer] + src_sharding: kv_caches' original sharding + dst_sharding: kv_caches' target sharding (different memory_kind) + direction: h2d -> swap_in, d2h -> swap_out + Returns: + a list of jax.Array objects with the dst_sharding + """ + + def _jax_device_put(input_array): + return jax.device_put(input_array, dst_sharding) + return jax.tree.map(_jax_device_put, src_kv_caches) -def jax_swap_kv_cache( - src_kv_cache: jax.Array, - out_sharding: Optional[jax.sharding.NamedSharding], + +def pallas_swap_kv_caches( + src_kv_caches: List[jax.Array], + src_sharding: jax.sharding.NamedSharding, + dst_sharding: jax.sharding.NamedSharding, direction: Literal["h2d", "d2h"], -) -> jax.Array: - cpu_device = jax.devices("cpu")[0] - return jax.device_put(src_kv_cache, - cpu_device if direction == "d2h" else out_sharding) - - -def dma_kv_cache( - src_kv_cache: jax.Array, - out_sharding: jax.sharding.NamedSharding, - direction: CPU_OFFLOADING_SWAP_OP_TYPE, -) -> jax.Array: - dma_fn = d2h_dma if direction == "d2h" else h2d_dma - return dma_fn(src_kv_cache, out_sharding) +) -> List[jax.Array]: + """Swap in / out multi-layer kv_cache using pallas dma kernel + + Args: + src_kv_caches: [kv_cache of each layer] + src_sharding: kv_caches' original sharding + 
dst_sharding: kv_caches' target sharding (different memory_kind) + direction: h2d -> swap_in, d2h -> swap_out + Returns: + a list of jax.Array objects with the dst_sharding + """ + + def swap_in_fn(inputs, input_sharding, out_sharding): + + def _swap_in(hbm_sharded_array): + return h2d_dma(hbm_sharded_array, input_sharding, out_sharding) + + return jax.tree.map(_swap_in, inputs) + + def swap_out_fn(inputs, input_sharding, out_sharding): + + def _swap_out(hbm_sharded_array): + return d2h_dma(hbm_sharded_array, input_sharding, out_sharding) + + return jax.tree.map(_swap_out, inputs) + + if direction == "d2h": + return swap_out_fn(src_kv_caches, src_sharding, dst_sharding) + elif direction == "h2d": + return swap_in_fn(src_kv_caches, src_sharding, dst_sharding) + + +def get_jitted_swap_fn( + swap_op_type: CPU_OFFLOADING_SWAP_OP_TYPE, + host_sharding: jax.sharding.NamedSharding, + device_sharding: jax.sharding.NamedSharding) -> List[Any]: + """jit compile the swap_in and swap_out functions + + Args: + swap_op_type : (str) pallas or jax + host_sharding: + device_sharding: + + Returns: + [jitted_swap_in_fn, jitted_swap_out_fn] + """ + _swap_fn = pallas_swap_kv_caches if swap_op_type == "pallas" else jax_swap_kv_caches + # swap_in (host_sharding), swap_out (device_sharding) + return functools.partial( + jax.jit(_swap_fn, + static_argnames=["src_sharding", "dst_sharding", "direction"], + out_shardings=device_sharding), + direction="h2d"), functools.partial(jax.jit( + _swap_fn, + static_argnames=["src_sharding", "dst_sharding", "direction"], + out_shardings=host_sharding), + direction="d2h") diff --git a/tpu_inference/distributed/tpu_connector_local.py b/tpu_inference/distributed/tpu_connector_local.py index ce7733982..e7f184818 100644 --- a/tpu_inference/distributed/tpu_connector_local.py +++ b/tpu_inference/distributed/tpu_connector_local.py @@ -108,7 +108,8 @@ from tpu_inference.runner.kv_cache_manager import KVCacheManager from tpu_inference.runner.tpu_jax_runner import TPUModelRunner -from .cache_util import CPU_OFFLOADING_SWAP_OP_TYPE, TokenProcessor, swap_ops +from .cache_util import (CPU_OFFLOADING_SWAP_OP_TYPE, TokenProcessor, + get_jitted_swap_fn) from .local_cpu_backend import LocalCPUBackend EngineId = str @@ -760,6 +761,9 @@ def __init__(self, vllm_config: VllmConfig, connector: "TPUConnector"): logger.info( f"(cpu offloading) swap operation type is {self.swap_op_type}") + self.swap_in_fn = None + self.swap_out_fn = None + self.host = self.config.kv_ip self.kv_transfer_port = self.config.kv_port @@ -785,6 +789,52 @@ def __del__(self): logger.info("TPUConnectorWorker: Entering __del__") self.save_executor.shutdown(wait=True) + def register_runner(self, runner: TPUModelRunner): + logger.info("TPUConnectorWorker: Entering register_runner") + self.runner = runner + self.mesh = runner.mesh + # Get the spec of the kv_caches + kv_caches = runner.kv_caches + if kv_caches: + self.kv_cache_layout = runner.get_kv_cache_layout() + kv_layer = kv_caches[0] + self.num_layers = len(kv_caches) + self.shape = list(kv_layer.shape) + self.dtype = kv_layer.dtype + self.device_sharding = kv_layer.sharding + # TODO(jcgu): handle SingleDeviceSharding + self.host_sharding = jax.sharding.NamedSharding( + mesh=self.device_sharding.mesh, + spec=self.device_sharding.spec, + memory_kind="pinned_host") + + # NOTE(jcgu): needed when sliced-kv is [num_tokens, num_head, head_dim] + self.flatten_device_sharding = jax.sharding.NamedSharding( + mesh=self.device_sharding.mesh, + spec=jax.sharding.PartitionSpec(None, 
"model"), + memory_kind="device") + self.flatten_host_sharding = jax.sharding.NamedSharding( + mesh=self.device_sharding.mesh, + spec=jax.sharding.PartitionSpec(None, "model"), + memory_kind="pinned_host") + + self.swap_in_fn, self.swap_out_fn = get_jitted_swap_fn( + self.swap_op_type, + host_sharding=self.flatten_host_sharding, + device_sharding=self.flatten_device_sharding) + + logger.info("KV Cache details registered in TPUConnectorWorker:") + logger.info(f" - Num layers: {self.num_layers}") + logger.info(f" - Shape per layer: {self.shape}") + logger.info(f" - DType: {self.dtype}") + logger.info(f" - Device sharding: {self.device_sharding}") + logger.info( + f" - Flatten Device sharding: {self.flatten_device_sharding}") + logger.info(f" - Layout: {self.kv_cache_layout}") + else: + raise ValueError( + "TPUConnectorWorker registered with no KV caches.") + def _save_blocks_to_cpu(self, req_id: ReqId, full_block_ids: list[int], full_token_ids: list[int], save_spec: SaveSpec) -> ReqId: @@ -840,13 +890,9 @@ def _save_blocks_to_cpu(self, req_id: ReqId, full_block_ids: list[int], f"extracted_blocks_tpu: {extracted_blocks_tpu[0].shape}, {extracted_blocks_tpu[0].sharding}" ) - # Initiate non-blocking copy to CPU - kv_caches_on_cpu = [ - swap_ops(extracted_blocks, self.flatten_host_sharding, "d2h", - self.swap_op_type) - for extracted_blocks in extracted_blocks_tpu - ] - + kv_caches_on_cpu = self.swap_out_fn(extracted_blocks_tpu, + self.flatten_device_sharding, + self.flatten_host_sharding) # Block until the transfer is complete if kv_caches_on_cpu: jax.block_until_ready(kv_caches_on_cpu) @@ -995,45 +1041,6 @@ def wait_for_save(self): f"completed in {duration:.4f} seconds.") self._processed_save_for_step = True - def register_runner(self, runner: TPUModelRunner): - logger.info("TPUConnectorWorker: Entering register_runner") - self.runner = runner - self.mesh = runner.mesh - # Get the spec of the kv_caches - kv_caches = runner.kv_caches - if kv_caches: - self.kv_cache_layout = runner.get_kv_cache_layout() - kv_layer = kv_caches[0] - self.num_layers = len(kv_caches) - self.shape = list(kv_layer.shape) - self.dtype = kv_layer.dtype - self.device_sharding = kv_layer.sharding - # TODO(jcgu): handle SingleDeviceSharding - self.host_sharding = jax.sharding.NamedSharding( - mesh=self.device_sharding.mesh, - spec=self.device_sharding.spec, - memory_kind="pinned_host") - - self.flatten_device_sharding = jax.sharding.NamedSharding( - mesh=self.device_sharding.mesh, - spec=jax.sharding.PartitionSpec(None, "model"), - memory_kind="device") - self.flatten_host_sharding = jax.sharding.NamedSharding( - mesh=self.device_sharding.mesh, - spec=jax.sharding.PartitionSpec(None, "model"), - memory_kind="pinned_host") - - logger.info("KV Cache details registered in TPUConnectorWorker:") - logger.info(f" - Num layers: {self.num_layers}") - logger.info(f" - Shape per layer: {self.shape}") - logger.info(f" - DType: {self.dtype}") - logger.info(f" - Device sharding: {self.device_sharding}") - logger.info( - f" - Flatten Device sharding: {self.flatten_device_sharding}") - logger.info(f" - Layout: {self.kv_cache_layout}") - else: - logger.warning("TPUConnectorWorker registered with no KV caches.") - def start_load_kv(self, fwd_ctx: "ForwardContext") -> None: """ This function is the worker-side entry point for loading data from the @@ -1167,10 +1174,9 @@ def start_load_kv(self, fwd_ctx: "ForwardContext") -> None: ) # 5. Transfer to TPU, applying the correct sharding. 
- loaded_kv_sharded_on_tpu = [ - swap_ops(layer_data, self.flatten_device_sharding, "h2d", - self.swap_op_type) for layer_data in padded_kv_on_cpu - ] + loaded_kv_sharded_on_tpu = self.swap_in_fn( + padded_kv_on_cpu, self.flatten_host_sharding, + self.flatten_device_sharding) #, "h2d") jax.block_until_ready(loaded_kv_sharded_on_tpu) logger.info( f"loaded_kv_on_tpu[0]: {loaded_kv_sharded_on_tpu[0].shape}, {loaded_kv_sharded_on_tpu[0].sharding}" diff --git a/tpu_inference/kernels/dma/host_dma.py b/tpu_inference/kernels/dma/host_dma.py index cbcfc7249..68a53f9d0 100644 --- a/tpu_inference/kernels/dma/host_dma.py +++ b/tpu_inference/kernels/dma/host_dma.py @@ -22,20 +22,17 @@ def body(sem): # NOTE(jcgu): only support NamedSharding, does not support SingleDeviceSharding def d2h_dma( input_array: jax.Array, + input_sharding: jax.sharding.NamedSharding, out_sharding: jax.sharding.NamedSharding, ) -> jax.Array: """ DMA a device jax array to host memory. Args: input_array: input jax array on device hbm + input_sharding: input's device sharding out_sharding: output's host sharding Returns: jax array on host memory with the same sharding """ - device_sharding = input_array.sharding - assert isinstance(device_sharding, jax.sharding.NamedSharding) - assert isinstance(out_sharding, jax.sharding.NamedSharding) - assert device_sharding.memory_kind == "device" - assert out_sharding.memory_kind == "pinned_host" @jax.jit def _d2h_dma_call(x): @@ -52,9 +49,10 @@ def _d2h_dma_call(x): d2h_dma_kernel = jax.jit( jax.shard_map( _d2h_dma_call, - mesh=device_sharding.mesh, - in_specs=device_sharding.spec, + mesh=input_sharding.mesh, + in_specs=input_sharding.spec, out_specs=out_sharding.spec, + check_vma=False, ), out_shardings=out_sharding, ) @@ -66,23 +64,18 @@ def _d2h_dma_call(x): # NOTE(jcgu): only support NamedSharding, does not support SingleDeviceSharding def h2d_dma( input_array: jax.Array, + input_sharding: jax.sharding.NamedSharding, out_sharding: jax.sharding.NamedSharding, ) -> jax.Array: """ DMA a host jax array to device hbm. 
Args: input_array: input jax array on host memory + input_sharding: the host sharding for input out_sharding: the device sharding for output Returns: jax array on device hbm with the assigned sharding """ - host_sharding = input_array.sharding - - assert isinstance(host_sharding, jax.sharding.NamedSharding) - assert isinstance(out_sharding, jax.sharding.NamedSharding) - assert host_sharding.memory_kind == "pinned_host" - assert out_sharding.memory_kind == "device" - @jax.jit def _h2d_dma_call(x): return pl.pallas_call( @@ -98,11 +91,12 @@ def _h2d_dma_call(x): h2d_dma_kernel = jax.jit( jax.shard_map( _h2d_dma_call, - mesh=host_sharding.mesh, - in_specs=host_sharding.spec, + mesh=input_sharding.mesh, + in_specs=input_sharding.spec, out_specs=out_sharding.spec, check_vma=False, ), out_shardings=out_sharding, ) + return h2d_dma_kernel(input_array) From 433291644874e642c7657fec66c0833b005c88e6 Mon Sep 17 00:00:00 2001 From: Juncheng Gu Date: Thu, 23 Oct 2025 17:24:25 +0000 Subject: [PATCH 03/10] formalize swap fn interface Signed-off-by: Juncheng Gu --- tpu_inference/distributed/cache_util.py | 52 +++++++++++++------ .../distributed/tpu_connector_local.py | 18 +++---- 2 files changed, 43 insertions(+), 27 deletions(-) diff --git a/tpu_inference/distributed/cache_util.py b/tpu_inference/distributed/cache_util.py index 4c149d059..b4831ca2b 100644 --- a/tpu_inference/distributed/cache_util.py +++ b/tpu_inference/distributed/cache_util.py @@ -4,7 +4,7 @@ import functools import hashlib from dataclasses import dataclass -from typing import Any, Iterable, List, Literal, Optional, Tuple +from typing import Callable, Iterable, List, Literal, Optional, Tuple import jax from vllm.config import get_current_vllm_config @@ -104,6 +104,19 @@ def get_kv_connector_cache_layout(): return None +SwapFn = Callable[ + [ + List[jax.Array], # src_kv_caches + jax.sharding.NamedSharding, # src_sharding + jax.sharding.NamedSharding, # dst_sharding + Literal["h2d", "d2h"], # direction + ], + List[jax.Array], # return value +] + +JittedKVCacheSwapFn = Callable[[List[jax.Array]], List[jax.Array]] + + # NOTE(jcgu): keep the same interface as the pallas one def jax_swap_kv_caches( src_kv_caches: List[jax.Array], @@ -165,10 +178,11 @@ def _swap_out(hbm_sharded_array): return swap_in_fn(src_kv_caches, src_sharding, dst_sharding) -def get_jitted_swap_fn( - swap_op_type: CPU_OFFLOADING_SWAP_OP_TYPE, - host_sharding: jax.sharding.NamedSharding, - device_sharding: jax.sharding.NamedSharding) -> List[Any]: +def get_jitted_kv_cache_swap_fn( + swap_op_type: CPU_OFFLOADING_SWAP_OP_TYPE, + host_sharding: jax.sharding.NamedSharding, + device_sharding: jax.sharding.NamedSharding +) -> Tuple[JittedKVCacheSwapFn, JittedKVCacheSwapFn]: """jit compile the swap_in and swap_out functions Args: @@ -177,16 +191,22 @@ def get_jitted_swap_fn( device_sharding: Returns: - [jitted_swap_in_fn, jitted_swap_out_fn] + A tuple containing the jitted swap-in and swap-out functions. 
""" - _swap_fn = pallas_swap_kv_caches if swap_op_type == "pallas" else jax_swap_kv_caches + _swap_fn: SwapFn = pallas_swap_kv_caches if swap_op_type == "pallas" else jax_swap_kv_caches # swap_in (host_sharding), swap_out (device_sharding) - return functools.partial( - jax.jit(_swap_fn, - static_argnames=["src_sharding", "dst_sharding", "direction"], - out_shardings=device_sharding), - direction="h2d"), functools.partial(jax.jit( - _swap_fn, - static_argnames=["src_sharding", "dst_sharding", "direction"], - out_shardings=host_sharding), - direction="d2h") + swap_in_fn = functools.partial(jax.jit( + _swap_fn, + static_argnames=["src_sharding", "dst_sharding", "direction"], + out_shardings=device_sharding), + src_sharding=host_sharding, + dst_sharding=device_sharding, + direction="h2d") + swap_out_fn = functools.partial(jax.jit( + _swap_fn, + static_argnames=["src_sharding", "dst_sharding", "direction"], + out_shardings=host_sharding), + src_sharding=device_sharding, + dst_sharding=host_sharding, + direction="d2h") + return swap_in_fn, swap_out_fn diff --git a/tpu_inference/distributed/tpu_connector_local.py b/tpu_inference/distributed/tpu_connector_local.py index e7f184818..ad5c65fc3 100644 --- a/tpu_inference/distributed/tpu_connector_local.py +++ b/tpu_inference/distributed/tpu_connector_local.py @@ -108,8 +108,8 @@ from tpu_inference.runner.kv_cache_manager import KVCacheManager from tpu_inference.runner.tpu_jax_runner import TPUModelRunner -from .cache_util import (CPU_OFFLOADING_SWAP_OP_TYPE, TokenProcessor, - get_jitted_swap_fn) +from .cache_util import (CPU_OFFLOADING_SWAP_OP_TYPE, JittedKVCacheSwapFn, + TokenProcessor, get_jitted_kv_cache_swap_fn) from .local_cpu_backend import LocalCPUBackend EngineId = str @@ -761,8 +761,8 @@ def __init__(self, vllm_config: VllmConfig, connector: "TPUConnector"): logger.info( f"(cpu offloading) swap operation type is {self.swap_op_type}") - self.swap_in_fn = None - self.swap_out_fn = None + self.swap_in_fn: JittedKVCacheSwapFn = None + self.swap_out_fn: JittedKVCacheSwapFn = None self.host = self.config.kv_ip self.kv_transfer_port = self.config.kv_port @@ -818,7 +818,7 @@ def register_runner(self, runner: TPUModelRunner): spec=jax.sharding.PartitionSpec(None, "model"), memory_kind="pinned_host") - self.swap_in_fn, self.swap_out_fn = get_jitted_swap_fn( + self.swap_in_fn, self.swap_out_fn = get_jitted_kv_cache_swap_fn( self.swap_op_type, host_sharding=self.flatten_host_sharding, device_sharding=self.flatten_device_sharding) @@ -890,9 +890,7 @@ def _save_blocks_to_cpu(self, req_id: ReqId, full_block_ids: list[int], f"extracted_blocks_tpu: {extracted_blocks_tpu[0].shape}, {extracted_blocks_tpu[0].sharding}" ) - kv_caches_on_cpu = self.swap_out_fn(extracted_blocks_tpu, - self.flatten_device_sharding, - self.flatten_host_sharding) + kv_caches_on_cpu = self.swap_out_fn(extracted_blocks_tpu) # Block until the transfer is complete if kv_caches_on_cpu: jax.block_until_ready(kv_caches_on_cpu) @@ -1174,9 +1172,7 @@ def start_load_kv(self, fwd_ctx: "ForwardContext") -> None: ) # 5. Transfer to TPU, applying the correct sharding. 
- loaded_kv_sharded_on_tpu = self.swap_in_fn( - padded_kv_on_cpu, self.flatten_host_sharding, - self.flatten_device_sharding) #, "h2d") + loaded_kv_sharded_on_tpu = self.swap_in_fn(padded_kv_on_cpu) jax.block_until_ready(loaded_kv_sharded_on_tpu) logger.info( f"loaded_kv_on_tpu[0]: {loaded_kv_sharded_on_tpu[0].shape}, {loaded_kv_sharded_on_tpu[0].sharding}" From 1cba04d62fe2e91aa40b3b7a889c1a50fc0d3883 Mon Sep 17 00:00:00 2001 From: Juncheng Gu Date: Thu, 23 Oct 2025 18:35:28 +0000 Subject: [PATCH 04/10] fix env name Signed-off-by: Juncheng Gu --- tests/distributed/cpu_offloading_kv_roundtrip_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/distributed/cpu_offloading_kv_roundtrip_test.py b/tests/distributed/cpu_offloading_kv_roundtrip_test.py index 7cb1b76ab..86dcba1e6 100644 --- a/tests/distributed/cpu_offloading_kv_roundtrip_test.py +++ b/tests/distributed/cpu_offloading_kv_roundtrip_test.py @@ -139,7 +139,7 @@ def test_tpu_connector_d2h_h2d_roundtrip(self, model_axis_size: int, return None # 1. Setup - os.environ['TPU_KV_OFFLOADING_SWAP_OP_TYPE'] = swap_op_type + os.environ['TPU_OFFLOADING_SWAP_OP_TYPE'] = swap_op_type mesh = self.create_mesh((1, model_axis_size), ("data", "model")) if mesh is None: return None From dcbf7ff64a47e0ea97cc1318571282297ad65a59 Mon Sep 17 00:00:00 2001 From: Juncheng Gu Date: Thu, 23 Oct 2025 20:00:15 +0000 Subject: [PATCH 05/10] fix slice Signed-off-by: Juncheng Gu --- tpu_inference/distributed/tpu_connector_local.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tpu_inference/distributed/tpu_connector_local.py b/tpu_inference/distributed/tpu_connector_local.py index ad5c65fc3..c49907a9d 100644 --- a/tpu_inference/distributed/tpu_connector_local.py +++ b/tpu_inference/distributed/tpu_connector_local.py @@ -1132,9 +1132,9 @@ def start_load_kv(self, fwd_ctx: "ForwardContext") -> None: # now truncate to N-1 before padding and loading, to match the # allocation made by the scheduler. if meta.load_spec.is_full_prefix_hit: - final_kv_on_cpu = [ - layer_data[:-1] for layer_data in final_kv_on_cpu - ] + final_kv_on_cpu = jax.tree.map( + lambda x: jax.lax.slice_in_dim(x, 0, x.shape[0] - 1), + final_kv_on_cpu) logger.info( f"Request {meta.req_id}: is_full_prefix_hit = {meta.load_spec.is_full_prefix_hit}" "Truncated fetched cache data by 1 token. 
New shape: " From bb81a6672a5982da404e41e26d86060c75b37e7a Mon Sep 17 00:00:00 2001 From: Juncheng Gu Date: Thu, 23 Oct 2025 23:10:23 +0000 Subject: [PATCH 06/10] check memory_kind Signed-off-by: Juncheng Gu --- tests/distributed/cpu_offloading_kv_roundtrip_test.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/distributed/cpu_offloading_kv_roundtrip_test.py b/tests/distributed/cpu_offloading_kv_roundtrip_test.py index 86dcba1e6..02abdd3ab 100644 --- a/tests/distributed/cpu_offloading_kv_roundtrip_test.py +++ b/tests/distributed/cpu_offloading_kv_roundtrip_test.py @@ -211,6 +211,7 @@ def create_on_device(key): assert len( cached_value ) == num_layers, f"cache_value layer: {len(cached_value)} != {num_layers}" + assert cached_value[0].sharding.memory_kind == "pinned_host" retrieved_chunks.append(cached_value[0]) # Get first layer # Assemble on CPU and compare with original From c1b8ab545fdb2fb11448c4fffb96219fcaa835f2 Mon Sep 17 00:00:00 2001 From: Juncheng Gu Date: Fri, 24 Oct 2025 02:37:26 +0000 Subject: [PATCH 07/10] nit Signed-off-by: Juncheng Gu --- tests/distributed/host_offloading_accuracy_test.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/distributed/host_offloading_accuracy_test.py b/tests/distributed/host_offloading_accuracy_test.py index f2d32e795..1bcffbe79 100644 --- a/tests/distributed/host_offloading_accuracy_test.py +++ b/tests/distributed/host_offloading_accuracy_test.py @@ -51,9 +51,10 @@ def test_kv_cache_cpu_offloading_accuracy( ): with monkeypatch.context(): os.environ['SKIP_JAX_PRECOMPILE'] = '1' + os.environ['TPU_OFFLOADING_SWAP_OP_TYPE'] = "pallas" llm = LLM(model="meta-llama/Llama-3.2-3B", max_model_len=1024, - tensor_parallel_size=1, + tensor_parallel_size=8, task="generate", kv_transfer_config=kv_transfer_config) From 2d196e6f29eaef690452f2d2f54bb45f45fcc745 Mon Sep 17 00:00:00 2001 From: Juncheng Gu Date: Fri, 24 Oct 2025 03:44:33 +0000 Subject: [PATCH 08/10] swap_out, jax device_put, to SingleDeviceSharding(CPUDevice) Signed-off-by: Juncheng Gu --- .../distributed/tpu_connector_local.py | 39 ++++++++++++++----- 1 file changed, 30 insertions(+), 9 deletions(-) diff --git a/tpu_inference/distributed/tpu_connector_local.py b/tpu_inference/distributed/tpu_connector_local.py index c49907a9d..85b888f95 100644 --- a/tpu_inference/distributed/tpu_connector_local.py +++ b/tpu_inference/distributed/tpu_connector_local.py @@ -109,7 +109,7 @@ from tpu_inference.runner.tpu_jax_runner import TPUModelRunner from .cache_util import (CPU_OFFLOADING_SWAP_OP_TYPE, JittedKVCacheSwapFn, - TokenProcessor, get_jitted_kv_cache_swap_fn) + TokenProcessor) from .local_cpu_backend import LocalCPUBackend EngineId = str @@ -813,15 +813,36 @@ def register_runner(self, runner: TPUModelRunner): mesh=self.device_sharding.mesh, spec=jax.sharding.PartitionSpec(None, "model"), memory_kind="device") - self.flatten_host_sharding = jax.sharding.NamedSharding( - mesh=self.device_sharding.mesh, - spec=jax.sharding.PartitionSpec(None, "model"), - memory_kind="pinned_host") + # self.flatten_host_sharding = jax.sharding.NamedSharding( + # mesh=self.device_sharding.mesh, + # spec=jax.sharding.PartitionSpec(None, "model"), + # memory_kind="pinned_host") + + # self.swap_in_fn, self.swap_out_fn = get_jitted_kv_cache_swap_fn( + # self.swap_op_type, + # host_sharding=self.flatten_host_sharding, + # device_sharding=self.flatten_device_sharding) + + self.flatten_host_sharding = jax.devices("cpu")[0] + + def _jax_swap_in(src_kv_caches): + + def _jax_swap_in_(input_array): + 
return jax.device_put(input_array, jax.devices("cpu")[0]) + + return jax.tree.map(_jax_swap_in_, src_kv_caches) + + def _jax_swap_out(src_kv_caches): + + def _jax_swap_out_(input_array): + return jax.device_put(input_array, + self.flatten_device_sharding) + + return jax.tree.map(_jax_swap_out_, src_kv_caches) - self.swap_in_fn, self.swap_out_fn = get_jitted_kv_cache_swap_fn( - self.swap_op_type, - host_sharding=self.flatten_host_sharding, - device_sharding=self.flatten_device_sharding) + self.swap_in_fn = jax.jit( + _jax_swap_in, out_shardings=self.flatten_device_sharding) + self.swap_out_fn = jax.jit(_jax_swap_out) logger.info("KV Cache details registered in TPUConnectorWorker:") logger.info(f" - Num layers: {self.num_layers}") From 6aee938a758b2f3e595f4b81c2289c420f54da47 Mon Sep 17 00:00:00 2001 From: Juncheng Gu Date: Fri, 24 Oct 2025 05:19:27 +0000 Subject: [PATCH 09/10] nit Signed-off-by: Juncheng Gu --- tests/distributed/cpu_offloading_kv_roundtrip_test.py | 3 ++- tpu_inference/distributed/tpu_connector_local.py | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/distributed/cpu_offloading_kv_roundtrip_test.py b/tests/distributed/cpu_offloading_kv_roundtrip_test.py index 02abdd3ab..ac209fe47 100644 --- a/tests/distributed/cpu_offloading_kv_roundtrip_test.py +++ b/tests/distributed/cpu_offloading_kv_roundtrip_test.py @@ -211,7 +211,8 @@ def create_on_device(key): assert len( cached_value ) == num_layers, f"cache_value layer: {len(cached_value)} != {num_layers}" - assert cached_value[0].sharding.memory_kind == "pinned_host" + # NOTE(jcgu): comment out this assertion since we've reverted back to using SingleDeviceSharding + # assert cached_value[0].sharding.memory_kind == "pinned_host" retrieved_chunks.append(cached_value[0]) # Get first layer # Assemble on CPU and compare with original diff --git a/tpu_inference/distributed/tpu_connector_local.py b/tpu_inference/distributed/tpu_connector_local.py index 85b888f95..929c57063 100644 --- a/tpu_inference/distributed/tpu_connector_local.py +++ b/tpu_inference/distributed/tpu_connector_local.py @@ -813,6 +813,8 @@ def register_runner(self, runner: TPUModelRunner): mesh=self.device_sharding.mesh, spec=jax.sharding.PartitionSpec(None, "model"), memory_kind="device") + + # NOTE(jcgu): disable "pallas" swap op / NamedSharding due to core crash # self.flatten_host_sharding = jax.sharding.NamedSharding( # mesh=self.device_sharding.mesh, # spec=jax.sharding.PartitionSpec(None, "model"), @@ -823,8 +825,6 @@ def register_runner(self, runner: TPUModelRunner): # host_sharding=self.flatten_host_sharding, # device_sharding=self.flatten_device_sharding) - self.flatten_host_sharding = jax.devices("cpu")[0] - def _jax_swap_in(src_kv_caches): def _jax_swap_in_(input_array): From 8a77b7533dae5d8549470b338cad0b526962caf1 Mon Sep 17 00:00:00 2001 From: Juncheng Gu Date: Sat, 25 Oct 2025 15:39:37 +0000 Subject: [PATCH 10/10] fix comments Signed-off-by: Juncheng Gu --- .../cpu_offloading_kv_roundtrip_test.py | 7 +-- .../distributed/tpu_connector_local.py | 48 +++++++------------ 2 files changed, 21 insertions(+), 34 deletions(-) diff --git a/tests/distributed/cpu_offloading_kv_roundtrip_test.py b/tests/distributed/cpu_offloading_kv_roundtrip_test.py index ac209fe47..5cc000b03 100644 --- a/tests/distributed/cpu_offloading_kv_roundtrip_test.py +++ b/tests/distributed/cpu_offloading_kv_roundtrip_test.py @@ -245,6 +245,7 @@ def create_on_device(key): jax.block_until_ready(worker.runner.kv_caches) # 5. 
Verify TPU Reloaded Content - self.assertArraysEqual( - source_kv_cache[0][target_block_ids, ...], - worker.runner.kv_caches[0][target_block_ids, ...]) + for i in range(num_layers): + self.assertArraysEqual( + source_kv_cache[i][target_block_ids, ...], + worker.runner.kv_caches[i][target_block_ids, ...]) diff --git a/tpu_inference/distributed/tpu_connector_local.py b/tpu_inference/distributed/tpu_connector_local.py index 929c57063..6627617ca 100644 --- a/tpu_inference/distributed/tpu_connector_local.py +++ b/tpu_inference/distributed/tpu_connector_local.py @@ -802,11 +802,6 @@ def register_runner(self, runner: TPUModelRunner): self.shape = list(kv_layer.shape) self.dtype = kv_layer.dtype self.device_sharding = kv_layer.sharding - # TODO(jcgu): handle SingleDeviceSharding - self.host_sharding = jax.sharding.NamedSharding( - mesh=self.device_sharding.mesh, - spec=self.device_sharding.spec, - memory_kind="pinned_host") # NOTE(jcgu): needed when sliced-kv is [num_tokens, num_head, head_dim] self.flatten_device_sharding = jax.sharding.NamedSharding( @@ -814,34 +809,25 @@ def register_runner(self, runner: TPUModelRunner): spec=jax.sharding.PartitionSpec(None, "model"), memory_kind="device") - # NOTE(jcgu): disable "pallas" swap op / NamedSharding due to core crash - # self.flatten_host_sharding = jax.sharding.NamedSharding( - # mesh=self.device_sharding.mesh, - # spec=jax.sharding.PartitionSpec(None, "model"), - # memory_kind="pinned_host") - - # self.swap_in_fn, self.swap_out_fn = get_jitted_kv_cache_swap_fn( - # self.swap_op_type, - # host_sharding=self.flatten_host_sharding, - # device_sharding=self.flatten_device_sharding) - def _jax_swap_in(src_kv_caches): - + # input_array should exist on HBM def _jax_swap_in_(input_array): return jax.device_put(input_array, jax.devices("cpu")[0]) return jax.tree.map(_jax_swap_in_, src_kv_caches) def _jax_swap_out(src_kv_caches): - + # input_array should exist on CPU def _jax_swap_out_(input_array): return jax.device_put(input_array, self.flatten_device_sharding) return jax.tree.map(_jax_swap_out_, src_kv_caches) + # the output (on device) of swap_in should apply NamedSharding self.swap_in_fn = jax.jit( _jax_swap_in, out_shardings=self.flatten_device_sharding) + # the output (on host) of swap_out should apply SingleDeviceSharding self.swap_out_fn = jax.jit(_jax_swap_out) logger.info("KV Cache details registered in TPUConnectorWorker:") @@ -903,36 +889,36 @@ def _save_blocks_to_cpu(self, req_id: ReqId, full_block_ids: list[int], start_time = time.time() blocks_to_process = jnp.array(blocks_to_process) # gather and reshape blocks on TPU first: output_shape: [process_blocks * block_size, num_heads, 2, head_dim] - extracted_blocks_tpu = KVCacheManager._jitted_gather_kv_cache( + flat_kv_caches_tpu = KVCacheManager._jitted_gather_kv_cache( self.runner.kv_caches, blocks_to_process) - jax.block_until_ready(extracted_blocks_tpu) + jax.block_until_ready(flat_kv_caches_tpu) logger.info( - f"extracted_blocks_tpu: {extracted_blocks_tpu[0].shape}, {extracted_blocks_tpu[0].sharding}" + f"extracted_blocks_tpu: {flat_kv_caches_tpu[0].shape}, {flat_kv_caches_tpu[0].sharding}" ) - kv_caches_on_cpu = self.swap_out_fn(extracted_blocks_tpu) + flat_kv_caches_cpu = self.swap_out_fn(flat_kv_caches_tpu) # Block until the transfer is complete - if kv_caches_on_cpu: - jax.block_until_ready(kv_caches_on_cpu) + if flat_kv_caches_cpu: + jax.block_until_ready(flat_kv_caches_cpu) duration = time.time() - start_time logger.info( f"Successfully saved {len(blocks_to_process)} blocks for " 
f"request {req_id} to CPU in {duration:.4f} seconds.") - if kv_caches_on_cpu: + if flat_kv_caches_cpu: logger.info( - f"Shape of a single layer on CPU before reshape (num_blocks, block_size, ...): {kv_caches_on_cpu[0].shape}" + f"Shape of a single layer on CPU before reshape (num_blocks, block_size, ...): {flat_kv_caches_cpu[0].shape}" ) total_size_bytes = sum(layer.nbytes - for layer in kv_caches_on_cpu) + for layer in flat_kv_caches_cpu) logger.info( - f"Total size of kv_caches_on_cpu: {total_size_bytes / 1024**2:.2f} MB" + f"Total size of flat_kv_caches_cpu: {total_size_bytes / 1024**2:.2f} MB" ) logger.info( - f"Shape of a single layer after reshape (total_tokens, ...): {kv_caches_on_cpu[0].shape}" + f"Shape of a single layer after reshape (total_tokens, ...): {flat_kv_caches_cpu[0].shape}" ) post_transfer_start_time = time.time() @@ -951,7 +937,7 @@ def _save_blocks_to_cpu(self, req_id: ReqId, full_block_ids: list[int], relevant_keys.append((abs_start_idx, abs_end_idx, key)) if relevant_keys: - # The flat_kv_caches_on_cpu array corresponds to the new tokens, + # The flat_kv_caches_cpu array corresponds to the new tokens, # so its indexing is relative to the start of the new data. for abs_start_idx, abs_end_idx, key in relevant_keys: # Calculate indices relative to the start of our new data slice. @@ -964,7 +950,7 @@ def _save_blocks_to_cpu(self, req_id: ReqId, full_block_ids: list[int], rel_start_idx, rel_end_idx, axis=0) - for flat_layer_cache in kv_caches_on_cpu + for flat_layer_cache in flat_kv_caches_cpu ] jax.block_until_ready(value_for_key) self.cpu_backend.add(key, value_for_key)