Commit f15cacc

[Feature][TPU host offload] Revise KV Cache slicing, gathering, ... operations (#931)
Signed-off-by: Juncheng Gu <[email protected]>
1 parent 5e8c42c commit f15cacc

5 files changed: 211 additions, 126 deletions

tests/distributed/cpu_offloading_kv_roundtrip_test.py

Lines changed: 7 additions & 3 deletions
@@ -139,7 +139,7 @@ def test_tpu_connector_d2h_h2d_roundtrip(self, model_axis_size: int,
             return None
 
         # 1. Setup
-        os.environ['TPU_KV_OFFLOADING_SWAP_OP_TYPE'] = swap_op_type
+        os.environ['TPU_OFFLOADING_SWAP_OP_TYPE'] = swap_op_type
         mesh = self.create_mesh((1, model_axis_size), ("data", "model"))
         if mesh is None:
             return None
@@ -211,6 +211,8 @@ def create_on_device(key):
             assert len(
                 cached_value
             ) == num_layers, f"cache_value layer: {len(cached_value)} != {num_layers}"
+            # NOTE(jcgu): comment out this assertion since we've reverted back to using SingleDeviceSharding
+            # assert cached_value[0].sharding.memory_kind == "pinned_host"
             retrieved_chunks.append(cached_value[0])  # Get first layer
 
         # Assemble on CPU and compare with original
@@ -243,5 +245,7 @@ def create_on_device(key):
         jax.block_until_ready(worker.runner.kv_caches)
 
         # 5. Verify TPU Reloaded Content
-        self.assertArraysEqual(source_kv_cache[0][target_block_ids, ...],
-                               dest_kv_cache[0][target_block_ids, ...])
+        for i in range(num_layers):
+            self.assertArraysEqual(
+                source_kv_cache[i][target_block_ids, ...],
+                worker.runner.kv_caches[i][target_block_ids, ...])
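The new loop verifies every layer of the reloaded cache against the source, not just layer 0, so a layer-indexing bug in the swap path can no longer slip through. For readers outside the test harness, the sketch below reproduces the d2h/h2d round trip in isolation; the mesh, shapes, dtype, and replicated PartitionSpec are illustrative assumptions, and "pinned_host" requires a backend (e.g. TPU) that exposes host memory kinds:

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Build device (HBM) and pinned-host shardings that differ only in memory_kind.
mesh = Mesh(jax.devices(), ("model",))
spec = PartitionSpec()  # replicated, for simplicity
device_sharding = NamedSharding(mesh, spec, memory_kind="device")
host_sharding = NamedSharding(mesh, spec, memory_kind="pinned_host")

num_layers = 4
kv_caches = [
    jax.device_put(jnp.full((16, 128), i, dtype=jnp.bfloat16), device_sharding)
    for i in range(num_layers)
]

# Swap out (d2h), swap back in (h2d), then check each layer round-trips.
on_host = [jax.device_put(x, host_sharding) for x in kv_caches]
reloaded = [jax.device_put(x, device_sharding) for x in on_host]
for src, dst in zip(kv_caches, reloaded):
    assert jnp.array_equal(src, dst)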

tests/distributed/host_offloading_accuracy_test.py

Lines changed: 2 additions & 1 deletion
@@ -51,9 +51,10 @@ def test_kv_cache_cpu_offloading_accuracy(
 ):
     with monkeypatch.context():
         os.environ['SKIP_JAX_PRECOMPILE'] = '1'
+        os.environ['TPU_OFFLOADING_SWAP_OP_TYPE'] = "pallas"
         llm = LLM(model="meta-llama/Llama-3.2-3B",
                   max_model_len=1024,
-                  tensor_parallel_size=1,
+                  tensor_parallel_size=8,
                   task="generate",
                   kv_transfer_config=kv_transfer_config)
 
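One editorial note on the setup, not part of the commit: since the test already enters monkeypatch.context(), the environment variables could be set through the context's setenv so they are restored automatically when the test exits, e.g.:

with monkeypatch.context() as m:
    # Both variables are undone when the context exits.
    m.setenv('SKIP_JAX_PRECOMPILE', '1')
    m.setenv('TPU_OFFLOADING_SWAP_OP_TYPE', 'pallas')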

tpu_inference/distributed/cache_util.py

Lines changed: 104 additions & 25 deletions
@@ -1,9 +1,10 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the LMCache project
 
+import functools
 import hashlib
 from dataclasses import dataclass
-from typing import Iterable, List, Literal, Optional, Tuple
+from typing import Callable, Iterable, List, Literal, Optional, Tuple
 
 import jax
 from vllm.config import get_current_vllm_config
@@ -103,31 +104,109 @@ def get_kv_connector_cache_layout():
     return None
 
 
-def swap_ops(
-    src_kv_cache: jax.Array,
-    out_sharding: Optional[jax.sharding.NamedSharding],
+SwapFn = Callable[
+    [
+        List[jax.Array],  # src_kv_caches
+        jax.sharding.NamedSharding,  # src_sharding
+        jax.sharding.NamedSharding,  # dst_sharding
+        Literal["h2d", "d2h"],  # direction
+    ],
+    List[jax.Array],  # return value
+]
+
+JittedKVCacheSwapFn = Callable[[List[jax.Array]], List[jax.Array]]
+
+
+# NOTE(jcgu): keep the same interface as the pallas one
+def jax_swap_kv_caches(
+    src_kv_caches: List[jax.Array],
+    src_sharding: jax.sharding.NamedSharding,
+    dst_sharding: jax.sharding.NamedSharding,
     direction: Literal["h2d", "d2h"],
-    op_type: CPU_OFFLOADING_SWAP_OP_TYPE,
-) -> jax.Array:
-    if op_type == "jax":
-        return jax_swap_kv_cache(src_kv_cache, out_sharding, direction)
-    return dma_kv_cache(src_kv_cache, out_sharding, direction)
+) -> List[jax.Array]:
+    """Swap in / out multi-layer kv_cache using jax device_put
+
+    Args:
+        src_kv_caches: [kv_cache of each layer]
+        src_sharding: kv_caches' original sharding
+        dst_sharding: kv_caches' target sharding (different memory_kind)
+        direction: h2d -> swap_in, d2h -> swap_out
+    Returns:
+        a list of jax.Array objects with the dst_sharding
+    """
+
+    def _jax_device_put(input_array):
+        return jax.device_put(input_array, dst_sharding)
 
+    return jax.tree.map(_jax_device_put, src_kv_caches)
 
-def jax_swap_kv_cache(
-    src_kv_cache: jax.Array,
-    out_sharding: Optional[jax.sharding.NamedSharding],
+
+def pallas_swap_kv_caches(
+    src_kv_caches: List[jax.Array],
+    src_sharding: jax.sharding.NamedSharding,
+    dst_sharding: jax.sharding.NamedSharding,
     direction: Literal["h2d", "d2h"],
-) -> jax.Array:
-    cpu_device = jax.devices("cpu")[0]
-    return jax.device_put(src_kv_cache,
-                          cpu_device if direction == "d2h" else out_sharding)
-
-
-def dma_kv_cache(
-    src_kv_cache: jax.Array,
-    out_sharding: jax.sharding.NamedSharding,
-    direction: CPU_OFFLOADING_SWAP_OP_TYPE,
-) -> jax.Array:
-    dma_fn = d2h_dma if direction == "d2h" else h2d_dma
-    return dma_fn(src_kv_cache, out_sharding)
+) -> List[jax.Array]:
+    """Swap in / out multi-layer kv_cache using pallas dma kernel
+
+    Args:
+        src_kv_caches: [kv_cache of each layer]
+        src_sharding: kv_caches' original sharding
+        dst_sharding: kv_caches' target sharding (different memory_kind)
+        direction: h2d -> swap_in, d2h -> swap_out
+    Returns:
+        a list of jax.Array objects with the dst_sharding
+    """
+
+    def swap_in_fn(inputs, input_sharding, out_sharding):
+
+        def _swap_in(hbm_sharded_array):
+            return h2d_dma(hbm_sharded_array, input_sharding, out_sharding)
+
+        return jax.tree.map(_swap_in, inputs)
+
+    def swap_out_fn(inputs, input_sharding, out_sharding):
+
+        def _swap_out(hbm_sharded_array):
+            return d2h_dma(hbm_sharded_array, input_sharding, out_sharding)
+
+        return jax.tree.map(_swap_out, inputs)
+
+    if direction == "d2h":
+        return swap_out_fn(src_kv_caches, src_sharding, dst_sharding)
+    elif direction == "h2d":
+        return swap_in_fn(src_kv_caches, src_sharding, dst_sharding)
+
+
+def get_jitted_kv_cache_swap_fn(
+    swap_op_type: CPU_OFFLOADING_SWAP_OP_TYPE,
+    host_sharding: jax.sharding.NamedSharding,
+    device_sharding: jax.sharding.NamedSharding
+) -> Tuple[JittedKVCacheSwapFn, JittedKVCacheSwapFn]:
+    """jit compile the swap_in and swap_out functions
+
+    Args:
+        swap_op_type: (str) pallas or jax
+        host_sharding:
+        device_sharding:
+
+    Returns:
+        A tuple containing the jitted swap-in and swap-out functions.
+    """
+    _swap_fn: SwapFn = pallas_swap_kv_caches if swap_op_type == "pallas" else jax_swap_kv_caches
+    # swap_in (host_sharding), swap_out (device_sharding)
+    swap_in_fn = functools.partial(jax.jit(
+        _swap_fn,
+        static_argnames=["src_sharding", "dst_sharding", "direction"],
+        out_shardings=device_sharding),
+                                   src_sharding=host_sharding,
+                                   dst_sharding=device_sharding,
+                                   direction="h2d")
+    swap_out_fn = functools.partial(jax.jit(
+        _swap_fn,
+        static_argnames=["src_sharding", "dst_sharding", "direction"],
+        out_shardings=host_sharding),
+                                    src_sharding=device_sharding,
+                                    dst_sharding=host_sharding,
+                                    direction="d2h")
+    return swap_in_fn, swap_out_fn
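To show how the jitted pair is meant to be consumed, here is a usage sketch; the mesh, PartitionSpec, shapes, and layer count are invented for illustration, and the import path is assumed from this diff:

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

from tpu_inference.distributed.cache_util import get_jitted_kv_cache_swap_fn

# Host and device shardings differ only in memory_kind.
mesh = Mesh(jax.devices(), ("model",))
spec = PartitionSpec(None, "model")  # illustrative KV-cache partitioning
device_sharding = NamedSharding(mesh, spec, memory_kind="device")
host_sharding = NamedSharding(mesh, spec, memory_kind="pinned_host")

swap_in_fn, swap_out_fn = get_jitted_kv_cache_swap_fn(
    swap_op_type="jax",  # or "pallas"; cf. TPU_OFFLOADING_SWAP_OP_TYPE
    host_sharding=host_sharding,
    device_sharding=device_sharding)

# Two illustrative per-layer caches living in HBM.
kv_caches = [
    jax.device_put(jnp.zeros((64, 8, 128), dtype=jnp.bfloat16), device_sharding)
    for _ in range(2)
]
host_copies = swap_out_fn(kv_caches)  # d2h: HBM -> pinned host
restored = swap_in_fn(host_copies)    # h2d: pinned host -> HBM

Because src_sharding, dst_sharding, and direction are declared static, jax.jit compiles once per distinct sharding/direction pair; functools.partial then pins those statics so callers only pass the list of per-layer caches.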
