Skip to content

Commit b73b69b

Browse files
authored
Refactor: Generalize utils.py for all devices by lifting the CUDA limitation (LMCache#2848)
* Refactor: Generalize utils.py for all devices by lifting the CUDA limitation Previously, these utilities were restricted to CUDA devices. This commit breaks that limitation and generalizes the implementation so it can be seamlessly applied to any general device. Signed-off-by: Tony Lin <tony.lin@intel.com> * xpu: get attributes from utils.py Signed-off-by: Tony Lin <tony.lin@intel.com> * hpu: get attributes from utils.py Signed-off-by: Tony Lin <tony.lin@intel.com> * address gemini's review comments Signed-off-by: Tony Lin <tony.lin@intel.com> * chore: fix code formatting according to pre-commit Signed-off-by: Tony Lin <tony.lin@intel.com> * UT change according to new signature. Signed-off-by: Tony Lin <tony.lin@intel.com> * add new formats to align with latest c ops NL_X_TWO_NB_NH_BS_HS = 6 NL_X_NB_TWO_NH_BS_HS = 7 Signed-off-by: Tony Lin <tony.lin@intel.com> --------- Signed-off-by: Tony Lin <tony.lin@intel.com>
1 parent 0f51fab commit b73b69b

5 files changed

Lines changed: 310 additions & 133 deletions

File tree

lmcache/non_cuda_equivalents.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# CUDA-specific operations.
55
#
66
# Standard
7+
from enum import Enum, IntEnum
78
from multiprocessing import shared_memory
89
import ctypes
910

@@ -17,6 +18,41 @@
1718
_buf_registry: dict[int, ctypes.Array] = {}
1819

1920

21+
class TransferDirection(Enum):
    """Specifies the direction of a memory transfer."""

    # Host-to-device copy.
    H2D = 0
    # Device-to-host copy.
    D2H = 1


class GPUKVFormat(IntEnum):
    """Enumeration of different GPU KV cache memory layouts.

    Member names encode the axis order of the KV cache tensor(s):
    NL = num_layers, NB = num_blocks, BS = block_size, NH = num_heads,
    HS = head_size, TWO = the K/V pair axis, ONE = a singleton axis,
    NBBS = fused num_blocks * block_size axis; ``_X_`` separates the
    list/tuple nesting level from the tensor dimensions.
    """

    # used by: vLLM CROSS_LAYER mode
    NB_NL_TWO_BS_NH_HS = 0

    # used by: vLLM non-MLA flash attention
    NL_X_TWO_NB_BS_NH_HS = 1

    # used by: vLLM non-MLA flash infer
    NL_X_NB_TWO_BS_NH_HS = 2

    # used by: vLLM MLA
    NL_X_NB_BS_HS = 3

    # used by: SGLang MHA (flash attention and flash infer)
    TWO_X_NL_X_NBBS_NH_HS = 4

    # used by: SGLang MLA
    NL_X_NBBS_ONE_HS = 5

    # used by: vLLM non-MLA flash attention (HND layout)
    NL_X_TWO_NB_NH_BS_HS = 6

    # used by: vLLM non-MLA flash infer (HND layout)
    NL_X_NB_TWO_NH_BS_HS = 7
2056
# On XPU (Intel GPU), PyTorch 2.4+ supports pin_memory=True via SYCL USM
2157
# host allocation, enabling fast DMA for XPU<->CPU transfers.
2258
_XPU_PIN_MEMORY = hasattr(torch, "xpu") and torch.xpu.is_available()

lmcache/v1/gpu_connector/hpu_connector.py

Lines changed: 137 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,20 @@
2222

2323
# First Party
2424
from lmcache.logging import init_logger
25+
from lmcache.utils import EngineType
2526
from lmcache.v1.gpu_connector import GPUConnectorInterface
27+
from lmcache.v1.gpu_connector.utils import (
28+
discover_gpu_kv_format,
29+
get_block_size,
30+
get_dtype,
31+
get_head_size,
32+
get_hidden_dim_size,
33+
get_num_blocks,
34+
get_num_heads,
35+
get_num_layers,
36+
get_page_buffer_size,
37+
is_mla,
38+
)
2639
from lmcache.v1.memory_management import MemoryFormat, MemoryObj
2740
from lmcache.v1.metadata import LMCacheMetadata
2841

@@ -41,13 +54,12 @@ class VLLMPagedMemHPUConnectorV2(GPUConnectorInterface):
4154

4255
def __init__(
4356
self,
44-
hidden_dim_size: int,
45-
num_layers: int,
4657
use_gpu: bool = False,
4758
**kwargs,
4859
):
60+
self._attributes_initialized = False
4961
self.kvcaches: Optional[List[torch.Tensor]] = None
50-
self.use_mla = "use_mla" in kwargs and kwargs["use_mla"]
62+
self.use_gpu = use_gpu
5163

5264
@classmethod
5365
def from_metadata(
@@ -64,22 +76,8 @@ def from_metadata(
6476
Returns:
6577
A new instance of VLLMPagedMemHPUConnectorV2.
6678
"""
67-
# Extract parameters from metadata
68-
# kv_shape: (num_layer, 2 or 1, chunk_size, num_kv_head, head_size)
69-
num_layers = metadata.kv_shape[0]
70-
chunk_size = metadata.kv_shape[2]
71-
num_kv_head = metadata.kv_shape[3]
72-
head_size = metadata.kv_shape[4]
73-
hidden_dim_size = num_kv_head * head_size
74-
7579
return cls(
76-
hidden_dim_size=hidden_dim_size,
77-
num_layers=num_layers,
7880
use_gpu=use_gpu,
79-
chunk_size=chunk_size,
80-
dtype=metadata.kv_dtype,
81-
device=device,
82-
use_mla=metadata.use_mla,
8381
)
8482

8583
def to_gpu(self, memory_obj: MemoryObj, start: int, end: int, **kwargs):
@@ -101,19 +99,6 @@ def to_gpu(self, memory_obj: MemoryObj, start: int, end: int, **kwargs):
10199
"""
102100
assert memory_obj.tensor is not None
103101

104-
if self.use_mla:
105-
if memory_obj.metadata.fmt != MemoryFormat.KV_MLA_FMT:
106-
raise ValueError(
107-
"The memory object should be in KV_MLA_FMT format in"
108-
" order to be processed by VLLMPagedMemHPUConnectorV2"
109-
)
110-
else:
111-
if memory_obj.metadata.fmt != MemoryFormat.KV_2LTD:
112-
raise ValueError(
113-
"The memory object should be in KV_2LTD format in"
114-
" order to be processed by VLLMPagedMemHPUConnectorV2"
115-
)
116-
117102
self.initialize_kvcaches_ptr(**kwargs)
118103

119104
assert self.kvcaches is not None, (
@@ -125,6 +110,8 @@ def to_gpu(self, memory_obj: MemoryObj, start: int, end: int, **kwargs):
125110

126111
slot_mapping: torch.Tensor = kwargs["slot_mapping"]
127112
slices = slot_mapping[start:end]
113+
self._initialize_attributes(self.kvcaches)
114+
self._validate_memory_format(memory_obj)
128115

129116
# Flush the HPU lazy-mode op graph so the slot_mapping slice is
130117
# materialized before downstream ops consume it. This also keeps
@@ -134,17 +121,17 @@ def to_gpu(self, memory_obj: MemoryObj, start: int, end: int, **kwargs):
134121

135122
if self.use_mla:
136123
tmp = memory_obj.tensor[0].to(slot_mapping.device)
137-
num_blocks, block_size, head_size = self.kvcaches[0].shape
138-
total_blocks = num_blocks * block_size
124+
total_blocks = self.num_blocks * self.block_size
139125
for i, kvcache in enumerate(self.kvcaches):
140-
kvcache.view(total_blocks, head_size).index_copy_(0, slices, tmp[i])
126+
kvcache.view(total_blocks, self.head_size).index_copy_(
127+
0, slices, tmp[i]
128+
)
141129
htorch.core.mark_step()
142130
else:
143131
tmp_k = memory_obj.tensor[0].to(slot_mapping.device)
144132
tmp_v = memory_obj.tensor[1].to(slot_mapping.device)
145-
num_blocks, block_size, num_heads, head_size = self.kvcaches[0][0].shape
146-
total_blocks = num_blocks * block_size
147-
d = num_heads * head_size
133+
total_blocks = self.num_blocks * self.block_size
134+
d = self.num_heads * self.head_size
148135
for i, (kcache, vcache) in enumerate(self.kvcaches):
149136
kcache.view(total_blocks, d).index_copy_(0, slices, tmp_k[i])
150137
vcache.view(total_blocks, d).index_copy_(0, slices, tmp_v[i])
@@ -183,22 +170,22 @@ def from_gpu(self, memory_obj: MemoryObj, start: int, end: int, **kwargs):
183170

184171
slot_mapping: torch.Tensor = kwargs["slot_mapping"]
185172
slices = slot_mapping[start:end]
173+
self._initialize_attributes(self.kvcaches)
174+
self._validate_memory_format(memory_obj)
186175

187176
htorch.core.mark_step()
188177

189178
if self.use_mla:
190-
num_blocks, block_size, head_size = self.kvcaches[0].shape
191-
total_blocks = num_blocks * block_size
179+
total_blocks = self.num_blocks * self.block_size
192180
tmp = torch.stack(
193181
[
194-
kvcache.view(total_blocks, head_size).index_select(0, slices)
182+
kvcache.view(total_blocks, self.head_size).index_select(0, slices)
195183
for kvcache in self.kvcaches
196184
]
197185
)
198186
else:
199-
num_blocks, block_size, num_heads, head_size = self.kvcaches[0][0].shape
200-
total_blocks = num_blocks * block_size
201-
d = num_heads * head_size
187+
total_blocks = self.num_blocks * self.block_size
188+
d = self.num_heads * self.head_size
202189
tmp_k = torch.stack(
203190
[
204191
kvcache[0].view(total_blocks, d).index_select(0, slices)
@@ -229,5 +216,111 @@ def batched_from_gpu(self, memory_objs, starts, ends, **kwargs):
229216
self.from_gpu(memory_obj, start, end, **kwargs)
230217

231218
def get_shape(self, num_tokens: int) -> torch.Size:
    """Get the shape of the data given the number of tokens.

    Args:
        num_tokens: The number of tokens in the data.

    Returns:
        The shape of the KV cache data:
        (1 or 2, num_layers, num_tokens, hidden_dim_size),
        where the leading axis is 1 for MLA and 2 for K/V pairs.

    Raises:
        RuntimeError: If attributes have not been initialized yet
            (i.e., no kv_caches have been seen).
    """
    if not self._attributes_initialized:
        raise RuntimeError(
            "Cannot determine shape before attributes are initialized. "
            "Call to_gpu or from_gpu first so that _initialize_attributes "
            "can discover the KV cache layout."
        )
    # MLA stores a single latent cache; non-MLA stores separate K and V.
    kv_size = 1 if self.use_mla else 2
    return torch.Size([kv_size, self.num_layers, num_tokens, self.hidden_dim_size])
239+
240+
def _validate_memory_format(self, memory_obj: MemoryObj) -> None:
    """Validate that the memory object has the expected format.

    Args:
        memory_obj: The memory object to validate.

    Raises:
        ValueError: If the memory format does not match the expected
            format based on whether MLA is in use (KV_MLA_FMT for MLA,
            KV_2LTD otherwise).
    """
    if self.use_mla:
        if memory_obj.metadata.fmt != MemoryFormat.KV_MLA_FMT:
            raise ValueError(
                "The memory object should be in KV_MLA_FMT format in"
                " order to be processed by VLLMPagedMemHPUConnectorV2"
            )
    else:
        if memory_obj.metadata.fmt != MemoryFormat.KV_2LTD:
            raise ValueError(
                "The memory object should be in KV_2LTD format in"
                " order to be processed by VLLMPagedMemHPUConnectorV2"
            )
263+
def _initialize_attributes(self, kv_caches: List[torch.Tensor]):
    """Discover and cache the KV cache layout attributes from *kv_caches*.

    Runs once; subsequent calls are no-ops. Populates device, format,
    num_layers, num_blocks, block_size, page_buffer_size, hidden_dim_size,
    head_size, use_mla, dtype, and num_heads on ``self``.

    Args:
        kv_caches: The KV caches provided by the serving engine. On HPU
            vLLM this may be a List of (k_tensor, v_tensor) tuples rather
            than the standard List of stacked tensors.
    """
    if self._attributes_initialized:
        return

    self.device = kv_caches[0].device
    assert self.device.type == "hpu", "The device should be HPU."

    # HPU vLLM provides kv_caches as List[TensorTuple(k_tensor, v_tensor)],
    # where each TensorTuple contains two 4D tensors of shape
    # (num_blocks, block_size, num_heads, head_size).
    # We create a lightweight proxy List[Tensor(2, ...)] to match the
    # standard vLLM format (NL_X_TWO_NB_BS_NH_HS) for format discovery.
    # NOTE: the isinstance check runs before len() so that an unsized
    # first element cannot raise TypeError.
    if (
        isinstance(kv_caches, (list, tuple))
        and len(kv_caches) > 0
        and not isinstance(kv_caches[0], torch.Tensor)
        and len(kv_caches[0]) == 2
        and isinstance(kv_caches[0][0], torch.Tensor)
        and isinstance(kv_caches[0][1], torch.Tensor)
    ):
        # kv_caches[i][0].shape = (num_blocks, block_size, num_heads, head_size)
        # We need shape (2, num_blocks, block_size, num_heads, head_size).
        # Allocated on the "meta" device: shape/dtype only, no real memory.
        inner_shape = kv_caches[0][0].shape
        fake_shape = (2, *inner_shape)
        kv_caches = [
            torch.empty(fake_shape, dtype=kv_caches[0][0].dtype, device="meta")
            for _ in range(len(kv_caches))
        ]
        logger.info(
            "HPU: created lightweight kv_caches proxy with shape %s "
            "for format discovery",
            fake_shape,
        )

    self.gpu_kv_format = discover_gpu_kv_format(kv_caches, EngineType.VLLM)
    self.num_layers = get_num_layers(kv_caches, self.gpu_kv_format)
    self.num_blocks = get_num_blocks(kv_caches, self.gpu_kv_format)
    self.block_size = get_block_size(kv_caches, self.gpu_kv_format)
    self.page_buffer_size = get_page_buffer_size(kv_caches, self.gpu_kv_format)
    self.hidden_dim_size = get_hidden_dim_size(kv_caches, self.gpu_kv_format)
    self.head_size = get_head_size(kv_caches, self.gpu_kv_format)
    self.use_mla = is_mla(self.gpu_kv_format)
    self.dtype = get_dtype(kv_caches, self.gpu_kv_format)
    # MLA has a single latent "head"; otherwise query the layout helpers.
    self.num_heads = (
        1 if self.use_mla else get_num_heads(kv_caches, self.gpu_kv_format)
    )

    self._attributes_initialized = True
    logger.info(
        "HPU: attributes initialized - format: %s, "
        "num_layers: %d, num_blocks: %d, block_size: %d, "
        "page_buffer_size: %d, hidden_dim_size: %d, head_size: %d, "
        "use_mla: %s, dtype: %s, num_heads: %d",
        self.gpu_kv_format,
        self.num_layers,
        self.num_blocks,
        self.block_size,
        self.page_buffer_size,
        self.hidden_dim_size,
        self.head_size,
        self.use_mla,
        self.dtype,
        self.num_heads,
    )

lmcache/v1/gpu_connector/utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@
2727
if torch.cuda.is_available():
2828
# First Party
2929
import lmcache.c_ops as lmc_ops
30+
else:
31+
# First Party
32+
import lmcache.non_cuda_equivalents as lmc_ops
3033

3134
logger = init_logger(__name__)
3235

0 commit comments

Comments (0)