8 changes: 5 additions & 3 deletions tests/distributed/cpu_offloading_kv_roundtrip_test.py
@@ -139,7 +139,7 @@ def test_tpu_connector_d2h_h2d_roundtrip(self, model_axis_size: int,
return None

# 1. Setup
os.environ['TPU_KV_OFFLOADING_SWAP_OP_TYPE'] = swap_op_type
os.environ['TPU_OFFLOADING_SWAP_OP_TYPE'] = swap_op_type
mesh = self.create_mesh((1, model_axis_size), ("data", "model"))
if mesh is None:
return None
@@ -211,6 +211,7 @@ def create_on_device(key):
assert len(
cached_value
) == num_layers, f"cache_value layer: {len(cached_value)} != {num_layers}"
assert cached_value[0].sharding.memory_kind == "pinned_host"
retrieved_chunks.append(cached_value[0]) # Get first layer

# Assemble on CPU and compare with original
@@ -243,5 +244,6 @@ def create_on_device(key):
jax.block_until_ready(worker.runner.kv_caches)

# 5. Verify TPU Reloaded Content
self.assertArraysEqual(source_kv_cache[0][target_block_ids, ...],
dest_kv_cache[0][target_block_ids, ...])
self.assertArraysEqual(
source_kv_cache[0][target_block_ids, ...],
worker.runner.kv_caches[0][target_block_ids, ...])
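The changes above route the reload verification through worker.runner.kv_caches and assert that offloaded chunks carry the pinned_host memory kind. A minimal sketch of that memory-kind roundtrip, assuming a backend with host memory support (e.g. TPU); the mesh, shapes, and names below are illustrative and not taken from the test:

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Single-device mesh for illustration; the test shards over ("data", "model").
mesh = Mesh(jax.devices()[:1], ("model",))
device_sharding = NamedSharding(mesh, P(), memory_kind="device")
host_sharding = NamedSharding(mesh, P(), memory_kind="pinned_host")

x = jax.device_put(jnp.ones((4, 8)), device_sharding)
x_host = jax.device_put(x, host_sharding)             # d2h copy into pinned host memory
assert x_host.sharding.memory_kind == "pinned_host"   # same property the test asserts
x_back = jax.device_put(x_host, device_sharding)      # h2d reload
assert (x_back == x).all()
```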
3 changes: 2 additions & 1 deletion tests/distributed/host_offloading_accuracy_test.py
@@ -51,9 +51,10 @@ def test_kv_cache_cpu_offloading_accuracy(
):
with monkeypatch.context():
os.environ['SKIP_JAX_PRECOMPILE'] = '1'
os.environ['TPU_OFFLOADING_SWAP_OP_TYPE'] = "pallas"
llm = LLM(model="meta-llama/Llama-3.2-3B",
max_model_len=1024,
tensor_parallel_size=1,
tensor_parallel_size=8,
task="generate",
kv_transfer_config=kv_transfer_config)
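The fixture's kv_transfer_config is defined outside this hunk. A hedged sketch of how the full setup might look, assuming vLLM's KVTransferConfig with kv_connector / kv_role fields; the connector name is a placeholder and not taken from this PR:

```python
import os
from vllm import LLM
from vllm.config import KVTransferConfig

os.environ["SKIP_JAX_PRECOMPILE"] = "1"
os.environ["TPU_OFFLOADING_SWAP_OP_TYPE"] = "pallas"  # select the pallas DMA swap path

# Placeholder connector name; the actual connector used by the fixture is defined elsewhere.
kv_transfer_config = KVTransferConfig(kv_connector="TPUOffloadingConnector", kv_role="kv_both")

llm = LLM(model="meta-llama/Llama-3.2-3B",
          max_model_len=1024,
          tensor_parallel_size=8,
          task="generate",
          kv_transfer_config=kv_transfer_config)
```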

129 changes: 104 additions & 25 deletions tpu_inference/distributed/cache_util.py
@@ -1,9 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the LMCache project

import functools
import hashlib
from dataclasses import dataclass
from typing import Iterable, List, Literal, Optional, Tuple
from typing import Callable, Iterable, List, Literal, Optional, Tuple

import jax
from vllm.config import get_current_vllm_config
@@ -103,31 +104,109 @@ def get_kv_connector_cache_layout():
return None


def swap_ops(
src_kv_cache: jax.Array,
out_sharding: Optional[jax.sharding.NamedSharding],
SwapFn = Callable[
[
List[jax.Array], # src_kv_caches
jax.sharding.NamedSharding, # src_sharding
jax.sharding.NamedSharding, # dst_sharding
Literal["h2d", "d2h"], # direction
],
List[jax.Array], # return value
]

JittedKVCacheSwapFn = Callable[[List[jax.Array]], List[jax.Array]]


# NOTE(jcgu): keep the same interface as the pallas one
def jax_swap_kv_caches(
src_kv_caches: List[jax.Array],
src_sharding: jax.sharding.NamedSharding,
dst_sharding: jax.sharding.NamedSharding,
direction: Literal["h2d", "d2h"],
op_type: CPU_OFFLOADING_SWAP_OP_TYPE,
) -> jax.Array:
if op_type == "jax":
return jax_swap_kv_cache(src_kv_cache, out_sharding, direction)
return dma_kv_cache(src_kv_cache, out_sharding, direction)
) -> List[jax.Array]:
"""Swap in / out multi-layer kv_cache using jax device_put

Args:
src_kv_caches: [kv_cache of each layer]
src_sharding: kv_caches' original sharding
dst_sharding: kv_caches' target sharding (different memory_kind)
direction: h2d -> swap_in, d2h -> swap_out
Returns:
a list of jax.Array objects with the dst_sharding
"""

def _jax_device_put(input_array):
return jax.device_put(input_array, dst_sharding)

return jax.tree.map(_jax_device_put, src_kv_caches)
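A short usage sketch of this device_put path with a toy two-layer cache; the mesh, shardings, and shapes are illustrative stand-ins for the runner's real kv_caches:

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Illustrative single-device mesh; real callers pass the runner's mesh and shardings.
mesh = Mesh(jax.devices()[:1], ("model",))
device_sharding = NamedSharding(mesh, P(), memory_kind="device")
host_sharding = NamedSharding(mesh, P(), memory_kind="pinned_host")

# Toy two-layer cache standing in for the runner's kv_caches list.
kv_caches = [jax.device_put(jnp.zeros((2, 16, 8)), device_sharding) for _ in range(2)]

# d2h: offload every layer to pinned host memory, then h2d: restore to HBM.
host_copies = jax_swap_kv_caches(kv_caches, device_sharding, host_sharding, "d2h")
restored = jax_swap_kv_caches(host_copies, host_sharding, device_sharding, "h2d")
assert all(c.sharding.memory_kind == "pinned_host" for c in host_copies)
```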

def jax_swap_kv_cache(
src_kv_cache: jax.Array,
out_sharding: Optional[jax.sharding.NamedSharding],

def pallas_swap_kv_caches(
src_kv_caches: List[jax.Array],
src_sharding: jax.sharding.NamedSharding,
dst_sharding: jax.sharding.NamedSharding,
direction: Literal["h2d", "d2h"],
) -> jax.Array:
cpu_device = jax.devices("cpu")[0]
return jax.device_put(src_kv_cache,
cpu_device if direction == "d2h" else out_sharding)


def dma_kv_cache(
src_kv_cache: jax.Array,
out_sharding: jax.sharding.NamedSharding,
direction: CPU_OFFLOADING_SWAP_OP_TYPE,
) -> jax.Array:
dma_fn = d2h_dma if direction == "d2h" else h2d_dma
return dma_fn(src_kv_cache, out_sharding)
) -> List[jax.Array]:
"""Swap in / out multi-layer kv_cache using pallas dma kernel

Args:
src_kv_caches: [kv_cache of each layer]
src_sharding: kv_caches' original sharding
dst_sharding: kv_caches' target sharding (different memory_kind)
direction: h2d -> swap_in, d2h -> swap_out
Returns:
a list of jax.Array objects with the dst_sharding
"""

def swap_in_fn(inputs, input_sharding, out_sharding):

def _swap_in(hbm_sharded_array):
return h2d_dma(hbm_sharded_array, input_sharding, out_sharding)

return jax.tree.map(_swap_in, inputs)

def swap_out_fn(inputs, input_sharding, out_sharding):

def _swap_out(hbm_sharded_array):
return d2h_dma(hbm_sharded_array, input_sharding, out_sharding)

return jax.tree.map(_swap_out, inputs)

if direction == "d2h":
return swap_out_fn(src_kv_caches, src_sharding, dst_sharding)
elif direction == "h2d":
return swap_in_fn(src_kv_caches, src_sharding, dst_sharding)
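The pallas path reduces to one DMA-kernel call per layer, with the argument order shown in the wrappers above; a single-layer sketch reusing the toy kv_caches and shardings from the previous sketch:

```python
# One layer offloaded and restored via the DMA kernels this module imports.
host_layer = d2h_dma(kv_caches[0], device_sharding, host_sharding)  # swap out
hbm_layer = h2d_dma(host_layer, host_sharding, device_sharding)     # swap in
```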


def get_jitted_kv_cache_swap_fn(
swap_op_type: CPU_OFFLOADING_SWAP_OP_TYPE,
host_sharding: jax.sharding.NamedSharding,
device_sharding: jax.sharding.NamedSharding
) -> Tuple[JittedKVCacheSwapFn, JittedKVCacheSwapFn]:
"""jit compile the swap_in and swap_out functions

Args:
swap_op_type: (str) "pallas" or "jax", selecting the DMA-kernel or device_put path
host_sharding: NamedSharding with pinned_host memory kind (CPU side)
device_sharding: NamedSharding with device memory kind (HBM side)

Returns:
A tuple containing the jitted swap-in and swap-out functions.
"""
_swap_fn: SwapFn = pallas_swap_kv_caches if swap_op_type == "pallas" else jax_swap_kv_caches
# swap_in (host_sharding), swap_out (device_sharding)
swap_in_fn = functools.partial(jax.jit(
_swap_fn,
static_argnames=["src_sharding", "dst_sharding", "direction"],
out_shardings=device_sharding),
src_sharding=host_sharding,
dst_sharding=device_sharding,
direction="h2d")
swap_out_fn = functools.partial(jax.jit(
_swap_fn,
static_argnames=["src_sharding", "dst_sharding", "direction"],
out_shardings=host_sharding),
src_sharding=device_sharding,
dst_sharding=host_sharding,
direction="d2h")
return swap_in_fn, swap_out_fn
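A usage sketch of the jitted pair, again reusing the toy shardings and kv_caches from the earlier sketch; "jax" selects the device_put path and "pallas" the DMA kernels:

```python
swap_in_fn, swap_out_fn = get_jitted_kv_cache_swap_fn(
    swap_op_type="jax",
    host_sharding=host_sharding,
    device_sharding=device_sharding)

host_kv = swap_out_fn(kv_caches)    # d2h: offload per-layer caches to pinned host memory
restored_kv = swap_in_fn(host_kv)   # h2d: reload into HBM
jax.block_until_ready(restored_kv)
```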