|
1 | 1 | # SPDX-License-Identifier: Apache-2.0 |
2 | 2 | # SPDX-FileCopyrightText: Copyright contributors to the LMCache project |
3 | 3 |
|
| 4 | +import functools |
4 | 5 | import hashlib |
5 | 6 | from dataclasses import dataclass |
6 | | -from typing import Iterable, List, Literal, Optional, Tuple |
| 7 | +from typing import Callable, Iterable, List, Literal, Optional, Tuple |
7 | 8 |
|
8 | 9 | import jax |
9 | 10 | from vllm.config import get_current_vllm_config |
@@ -103,31 +104,109 @@ def get_kv_connector_cache_layout(): |
103 | 104 | return None |
104 | 105 |
|
105 | 106 |
|
106 | | -def swap_ops( |
107 | | - src_kv_cache: jax.Array, |
108 | | - out_sharding: Optional[jax.sharding.NamedSharding], |
# Common signature of the kv_cache swap implementations:
#   (src_kv_caches, src_sharding, dst_sharding, direction) -> swapped kv_caches
SwapFn = Callable[
    [
        List[jax.Array],
        jax.sharding.NamedSharding,
        jax.sharding.NamedSharding,
        Literal["h2d", "d2h"],
    ],
    List[jax.Array],
]

# Swap callable that only takes the per-layer kv_caches and returns the
# swapped ones.
JittedKVCacheSwapFn = Callable[[List[jax.Array]], List[jax.Array]]
| 118 | + |
| 119 | + |
# NOTE(jcgu): keep the same interface as the pallas one
def jax_swap_kv_caches(
    src_kv_caches: List[jax.Array],
    src_sharding: jax.sharding.NamedSharding,
    dst_sharding: jax.sharding.NamedSharding,
    direction: Literal["h2d", "d2h"],
) -> List[jax.Array]:
    """Move every layer's kv_cache to ``dst_sharding`` via ``jax.device_put``.

    ``src_sharding`` and ``direction`` are accepted only so that this
    function has the same signature as the pallas variant; this
    implementation does not read them.

    Args:
        src_kv_caches: [kv_cache of each layer]
        src_sharding: kv_caches' original sharding (unused here)
        dst_sharding: kv_caches' target sharding (different memory_kind)
        direction: h2d -> swap_in, d2h -> swap_out (unused here)
    Returns:
        the result of ``jax.device_put(..., dst_sharding)`` for every array,
        in the same container structure as ``src_kv_caches``
    """
    return jax.tree.map(
        lambda layer_kv_cache: jax.device_put(layer_kv_cache, dst_sharding),
        src_kv_caches,
    )
116 | 142 |
|
117 | | -def jax_swap_kv_cache( |
118 | | - src_kv_cache: jax.Array, |
119 | | - out_sharding: Optional[jax.sharding.NamedSharding], |
| 143 | + |
def pallas_swap_kv_caches(
    src_kv_caches: List[jax.Array],
    src_sharding: jax.sharding.NamedSharding,
    dst_sharding: jax.sharding.NamedSharding,
    direction: Literal["h2d", "d2h"],
) -> List[jax.Array]:
    """Swap in / out multi-layer kv_cache using pallas dma kernel

    Args:
        src_kv_caches: [kv_cache of each layer]
        src_sharding: kv_caches' original sharding
        dst_sharding: kv_caches' target sharding (different memory_kind)
        direction: h2d -> swap_in, d2h -> swap_out
    Returns:
        a list of jax.Array objects with the dst_sharding
    Raises:
        ValueError: if ``direction`` is neither "h2d" nor "d2h"
    """
    # Reject unknown directions up front; without this check the function
    # would fall through and silently return None instead of a list.
    if direction not in ("h2d", "d2h"):
        raise ValueError(
            f"Unsupported swap direction: {direction!r} "
            "(expected 'h2d' or 'd2h')")

    def _dma_one_layer(layer_kv_cache):
        # d2h -> swap_out kernel, h2d -> swap_in kernel; both take the
        # array followed by its source and destination shardings.
        if direction == "d2h":
            return d2h_dma(layer_kv_cache, src_sharding, dst_sharding)
        return h2d_dma(layer_kv_cache, src_sharding, dst_sharding)

    return jax.tree.map(_dma_one_layer, src_kv_caches)
| 179 | + |
| 180 | + |
def get_jitted_kv_cache_swap_fn(
    swap_op_type: CPU_OFFLOADING_SWAP_OP_TYPE,
    host_sharding: jax.sharding.NamedSharding,
    device_sharding: jax.sharding.NamedSharding
) -> Tuple[JittedKVCacheSwapFn, JittedKVCacheSwapFn]:
    """Build the jitted swap_in and swap_out functions

    Args:
        swap_op_type : (str) "pallas" selects the pallas implementation;
            any other value selects the jax one
        host_sharding: sharding on the host side (source of swap_in,
            destination of swap_out)
        device_sharding: sharding on the device side (destination of
            swap_in, source of swap_out)

    Returns:
        A tuple containing the jitted swap-in and swap-out functions.
    """
    if swap_op_type == "pallas":
        swap_impl: SwapFn = pallas_swap_kv_caches
    else:
        swap_impl = jax_swap_kv_caches

    def _bind(src, dst, direction):
        # Shardings and direction are static under jit; they are bound here
        # so callers only pass the kv_caches. Outputs use the dst sharding.
        jitted = jax.jit(
            swap_impl,
            static_argnames=["src_sharding", "dst_sharding", "direction"],
            out_shardings=dst)
        return functools.partial(jitted,
                                 src_sharding=src,
                                 dst_sharding=dst,
                                 direction=direction)

    # swap_in: host -> device, swap_out: device -> host
    swap_in_fn = _bind(host_sharding, device_sharding, "h2d")
    swap_out_fn = _bind(device_sharding, host_sharding, "d2h")
    return swap_in_fn, swap_out_fn
0 commit comments