2222
2323# First Party
2424from lmcache .logging import init_logger
25+ from lmcache .utils import EngineType
2526from lmcache .v1 .gpu_connector import GPUConnectorInterface
27+ from lmcache .v1 .gpu_connector .utils import (
28+ discover_gpu_kv_format ,
29+ get_block_size ,
30+ get_dtype ,
31+ get_head_size ,
32+ get_hidden_dim_size ,
33+ get_num_blocks ,
34+ get_num_heads ,
35+ get_num_layers ,
36+ get_page_buffer_size ,
37+ is_mla ,
38+ )
2639from lmcache .v1 .memory_management import MemoryFormat , MemoryObj
2740from lmcache .v1 .metadata import LMCacheMetadata
2841
@@ -41,13 +54,12 @@ class VLLMPagedMemHPUConnectorV2(GPUConnectorInterface):
4154
4255 def __init__ (
4356 self ,
44- hidden_dim_size : int ,
45- num_layers : int ,
4657 use_gpu : bool = False ,
4758 ** kwargs ,
4859 ):
60+ self ._attributes_initialized = False
4961 self .kvcaches : Optional [List [torch .Tensor ]] = None
50- self .use_mla = "use_mla" in kwargs and kwargs [ "use_mla" ]
62+ self .use_gpu = use_gpu
5163
5264 @classmethod
5365 def from_metadata (
@@ -64,22 +76,8 @@ def from_metadata(
6476 Returns:
6577 A new instance of VLLMPagedMemHPUConnectorV2.
6678 """
67- # Extract parameters from metadata
68- # kv_shape: (num_layer, 2 or 1, chunk_size, num_kv_head, head_size)
69- num_layers = metadata .kv_shape [0 ]
70- chunk_size = metadata .kv_shape [2 ]
71- num_kv_head = metadata .kv_shape [3 ]
72- head_size = metadata .kv_shape [4 ]
73- hidden_dim_size = num_kv_head * head_size
74-
7579 return cls (
76- hidden_dim_size = hidden_dim_size ,
77- num_layers = num_layers ,
7880 use_gpu = use_gpu ,
79- chunk_size = chunk_size ,
80- dtype = metadata .kv_dtype ,
81- device = device ,
82- use_mla = metadata .use_mla ,
8381 )
8482
8583 def to_gpu (self , memory_obj : MemoryObj , start : int , end : int , ** kwargs ):
@@ -101,19 +99,6 @@ def to_gpu(self, memory_obj: MemoryObj, start: int, end: int, **kwargs):
10199 """
102100 assert memory_obj .tensor is not None
103101
104- if self .use_mla :
105- if memory_obj .metadata .fmt != MemoryFormat .KV_MLA_FMT :
106- raise ValueError (
107- "The memory object should be in KV_MLA_FMT format in"
108- " order to be processed by VLLMPagedMemHPUConnectorV2"
109- )
110- else :
111- if memory_obj .metadata .fmt != MemoryFormat .KV_2LTD :
112- raise ValueError (
113- "The memory object should be in KV_2LTD format in"
114- " order to be processed by VLLMPagedMemHPUConnectorV2"
115- )
116-
117102 self .initialize_kvcaches_ptr (** kwargs )
118103
119104 assert self .kvcaches is not None , (
@@ -125,6 +110,8 @@ def to_gpu(self, memory_obj: MemoryObj, start: int, end: int, **kwargs):
125110
126111 slot_mapping : torch .Tensor = kwargs ["slot_mapping" ]
127112 slices = slot_mapping [start :end ]
113+ self ._initialize_attributes (self .kvcaches )
114+ self ._validate_memory_format (memory_obj )
128115
129116 # Flush the HPU lazy-mode op graph so the slot_mapping slice is
130117 # materialized before downstream ops consume it. This also keeps
@@ -134,17 +121,17 @@ def to_gpu(self, memory_obj: MemoryObj, start: int, end: int, **kwargs):
134121
135122 if self .use_mla :
136123 tmp = memory_obj .tensor [0 ].to (slot_mapping .device )
137- num_blocks , block_size , head_size = self .kvcaches [0 ].shape
138- total_blocks = num_blocks * block_size
124+ total_blocks = self .num_blocks * self .block_size
139125 for i , kvcache in enumerate (self .kvcaches ):
140- kvcache .view (total_blocks , head_size ).index_copy_ (0 , slices , tmp [i ])
126+ kvcache .view (total_blocks , self .head_size ).index_copy_ (
127+ 0 , slices , tmp [i ]
128+ )
141129 htorch .core .mark_step ()
142130 else :
143131 tmp_k = memory_obj .tensor [0 ].to (slot_mapping .device )
144132 tmp_v = memory_obj .tensor [1 ].to (slot_mapping .device )
145- num_blocks , block_size , num_heads , head_size = self .kvcaches [0 ][0 ].shape
146- total_blocks = num_blocks * block_size
147- d = num_heads * head_size
133+ total_blocks = self .num_blocks * self .block_size
134+ d = self .num_heads * self .head_size
148135 for i , (kcache , vcache ) in enumerate (self .kvcaches ):
149136 kcache .view (total_blocks , d ).index_copy_ (0 , slices , tmp_k [i ])
150137 vcache .view (total_blocks , d ).index_copy_ (0 , slices , tmp_v [i ])
@@ -183,22 +170,22 @@ def from_gpu(self, memory_obj: MemoryObj, start: int, end: int, **kwargs):
183170
184171 slot_mapping : torch .Tensor = kwargs ["slot_mapping" ]
185172 slices = slot_mapping [start :end ]
173+ self ._initialize_attributes (self .kvcaches )
174+ self ._validate_memory_format (memory_obj )
186175
187176 htorch .core .mark_step ()
188177
189178 if self .use_mla :
190- num_blocks , block_size , head_size = self .kvcaches [0 ].shape
191- total_blocks = num_blocks * block_size
179+ total_blocks = self .num_blocks * self .block_size
192180 tmp = torch .stack (
193181 [
194- kvcache .view (total_blocks , head_size ).index_select (0 , slices )
182+ kvcache .view (total_blocks , self . head_size ).index_select (0 , slices )
195183 for kvcache in self .kvcaches
196184 ]
197185 )
198186 else :
199- num_blocks , block_size , num_heads , head_size = self .kvcaches [0 ][0 ].shape
200- total_blocks = num_blocks * block_size
201- d = num_heads * head_size
187+ total_blocks = self .num_blocks * self .block_size
188+ d = self .num_heads * self .head_size
202189 tmp_k = torch .stack (
203190 [
204191 kvcache [0 ].view (total_blocks , d ).index_select (0 , slices )
@@ -229,5 +216,111 @@ def batched_from_gpu(self, memory_objs, starts, ends, **kwargs):
229216 self .from_gpu (memory_obj , start , end , ** kwargs )
230217
231218 def get_shape (self , num_tokens : int ) -> torch .Size :
232- """Get the shape of the data given the number of tokens."""
233- raise NotImplementedError
219+ """Get the shape of the data given the number of tokens.
220+
221+ Args:
222+ num_tokens: The number of tokens in the data.
223+
224+ Returns:
225+ The shape of the KV cache data.
226+
227+ Raises:
228+ RuntimeError: If attributes have not been initialized yet
229+ (i.e., no kv_caches have been seen).
230+ """
231+ if not self ._attributes_initialized :
232+ raise RuntimeError (
233+ "Cannot determine shape before attributes are initialized. "
234+ "Call to_gpu or from_gpu first so that _initialize_attributes "
235+ "can discover the KV cache layout."
236+ )
237+ kv_size = 1 if self .use_mla else 2
238+ return torch .Size ([kv_size , self .num_layers , num_tokens , self .hidden_dim_size ])
239+
240+ def _validate_memory_format (self , memory_obj : MemoryObj ) -> None :
241+ """Validate that the memory object has the expected format.
242+
243+ Args:
244+ memory_obj: The memory object to validate.
245+
246+ Raises:
247+ ValueError: If the memory format does not match the expected
248+ format based on whether MLA is in use.
249+ """
250+ if self .use_mla :
251+ if memory_obj .metadata .fmt != MemoryFormat .KV_MLA_FMT :
252+ raise ValueError (
253+ "The memory object should be in KV_MLA_FMT format in"
254+ " order to be processed by VLLMPagedMemHPUConnectorV2"
255+ )
256+ else :
257+ if memory_obj .metadata .fmt != MemoryFormat .KV_2LTD :
258+ raise ValueError (
259+ "The memory object should be in KV_2LTD format in"
260+ " order to be processed by VLLMPagedMemHPUConnectorV2"
261+ )
262+
263+ def _initialize_attributes (self , kv_caches : List [torch .Tensor ]):
264+ if self ._attributes_initialized :
265+ return
266+
267+ self .device = kv_caches [0 ].device
268+ assert self .device .type == "hpu" , "The device should be HPU."
269+
270+ # HPU vLLM provides kv_caches as List[TensorTuple(k_tensor, v_tensor)],
271+ # where each TensorTuple contains two 4D tensors of shape
272+ # (num_blocks, block_size, num_heads, head_size).
273+ # We create a lightweight proxy List[Tensor(2, ...)] to match the
274+ # standard vLLM format (NL_X_TWO_NB_BS_NH_HS) for format discovery.
275+ if (
276+ isinstance (kv_caches , (list , tuple ))
277+ and len (kv_caches ) > 0
278+ and len (kv_caches [0 ]) == 2
279+ and not isinstance (kv_caches [0 ], torch .Tensor )
280+ and isinstance (kv_caches [0 ][0 ], torch .Tensor )
281+ and isinstance (kv_caches [0 ][1 ], torch .Tensor )
282+ ):
283+ # kv_caches[i][0].shape = (num_blocks, block_size, num_heads, head_size)
284+ # We need shape (2, num_blocks, block_size, num_heads, head_size)
285+ inner_shape = kv_caches [0 ][0 ].shape
286+ fake_shape = (2 , * inner_shape )
287+ kv_caches = [
288+ torch .empty (fake_shape , dtype = kv_caches [0 ][0 ].dtype , device = "meta" )
289+ for _ in range (len (kv_caches ))
290+ ]
291+ logger .info (
292+ "HPU: created lightweight kv_caches proxy with shape %s "
293+ "for format discovery" ,
294+ fake_shape ,
295+ )
296+
297+ self .gpu_kv_format = discover_gpu_kv_format (kv_caches , EngineType .VLLM )
298+ self .num_layers = get_num_layers (kv_caches , self .gpu_kv_format )
299+ self .num_blocks = get_num_blocks (kv_caches , self .gpu_kv_format )
300+ self .block_size = get_block_size (kv_caches , self .gpu_kv_format )
301+ self .page_buffer_size = get_page_buffer_size (kv_caches , self .gpu_kv_format )
302+ self .hidden_dim_size = get_hidden_dim_size (kv_caches , self .gpu_kv_format )
303+ self .head_size = get_head_size (kv_caches , self .gpu_kv_format )
304+ self .use_mla = is_mla (self .gpu_kv_format )
305+ self .dtype = get_dtype (kv_caches , self .gpu_kv_format )
306+ self .num_heads = (
307+ 1 if self .use_mla else get_num_heads (kv_caches , self .gpu_kv_format )
308+ )
309+
310+ self ._attributes_initialized = True
311+ logger .info (
312+ "HPU: attributes initialized - format: %s, "
313+ "num_layers: %d, num_blocks: %d, block_size: %d, "
314+ "page_buffer_size: %d, hidden_dim_size: %d, head_size: %d, "
315+ "use_mla: %s, dtype: %s, num_heads: %d" ,
316+ self .gpu_kv_format ,
317+ self .num_layers ,
318+ self .num_blocks ,
319+ self .block_size ,
320+ self .page_buffer_size ,
321+ self .hidden_dim_size ,
322+ self .head_size ,
323+ self .use_mla ,
324+ self .dtype ,
325+ self .num_heads ,
326+ )
0 commit comments