Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 56 additions & 15 deletions vllm_ascend/distributed/mooncake/mooncake_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ def __init__(
self.tp_rank = parallel_config.rank
self.tp_size = parallel_config.tensor_parallel_size
self.kv_role = vllm_config.kv_transfer_config.kv_role
self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"load_async", False)
self.block_size = vllm_config.cache_config.block_size
self.current_layer = 0
# self.use_mla = first_kv_cache_tuple[0].size(
Expand Down Expand Up @@ -142,13 +144,14 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
self.kv_caches_base_addr, self.token_database,
self.block_len, self.block_size, ready_event_sending)
self.kv_send_thread.start()
ready_event = threading.Event()
self.kv_recv_thread = KVCacheStoreRecvingThread(
self.tp_rank, self.tp_size, self.m_store,
self.kv_caches_base_addr, self.token_database, self.block_len,
self.block_size, ready_event)
self.kv_recv_thread.start()
ready_event.wait()
if self.load_async:
ready_event = threading.Event()
self.kv_recv_thread = KVCacheStoreRecvingThread(
self.tp_rank, self.tp_size, self.m_store,
self.kv_caches_base_addr, self.token_database, self.block_len,
self.block_size, ready_event)
self.kv_recv_thread.start()
ready_event.wait()

def start_load_kv(self, metadata: MooncakeConnectorMetadata):
self.current_layer = 0
Expand Down Expand Up @@ -179,12 +182,47 @@ def start_load_kv(self, metadata: MooncakeConnectorMetadata):
next(layerwise_retriever) # first layer load
self.layerwise_retrievers.append(layerwise_retriever)
else:
self.kv_recv_thread.add_request( # type: ignore[union-attr]
req_id,
tokens,
request.block_ids,
token_mask,
)
if self.load_async:
self.kv_recv_thread.add_request( # type: ignore[union-attr]
req_id,
tokens,
request.block_ids,
token_mask,
)
else:
if self.m_store.config.use_ascend_direct:
addr_list = []
size_list = []
key_list = []
blockIds = []
for start, end, key in self.token_database.process_tokens(
tokens, token_mask):
addr, size, block_id = self.prepare_value(
start, end, request.block_ids)
key_list.append(key.to_string())
addr_list.append(addr)
size_list.append(size)
blockIds.append(block_id)
self.m_store.get_batch(key_list, addr_list, size_list, blockIds)
else:
for start, end, key in self.token_database.process_tokens(
tokens, token_mask):
addr, size, _ = self.prepare_value(start, end, request.block_ids)
self.m_store.get(key, addr, size)
Comment on lines 191 to 212
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

This new block for synchronous loading is almost identical to the logic in KVCacheStoreRecvingThread._handle_request. This code duplication can lead to maintenance issues where a bug fix or change in one place might be missed in the other.

To improve maintainability, I recommend extracting this logic into a new private helper method within the MooncakeEngine class, for example, _load_kv_sync. The synchronous path can then call this method.

Ideally, KVCacheStoreRecvingThread should also be refactored to use this new helper method to completely eliminate the duplication. This would likely involve passing the MooncakeEngine instance to the thread's constructor.


def prepare_value(self, start: int, end: int, block_ids: list[int]):
    """Compute transfer addresses/sizes for the token span [start, end).

    Returns a tuple ``(addr_list, size_list, block_id)`` with one
    address/length pair per entry of ``self.kv_caches_base_addr``, plus
    the vLLM block id that the span maps onto.
    """
    # The span is assumed to fall within a single block; pick the block
    # id by the block index of `start`.
    block_id = block_ids[start // self.block_size]
    num_tokens = end - start
    addrs: list[int] = []
    sizes: list[int] = []
    for idx, base in enumerate(self.kv_caches_base_addr):
        # MLA layouts alternate between two block lengths by index
        # parity; otherwise a single block length applies to all caches.
        if self.use_mla:
            cur_len = self.block_len[idx % 2]
        else:
            cur_len = self.block_len[0]
        addrs.append(base + block_id * cur_len)
        # Per-token length scaled to the number of tokens in the span.
        sizes.append(int(cur_len / self.block_size * num_tokens))
    return addrs, sizes, block_id
Comment on lines +214 to +226
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

This prepare_value method is a direct copy of the method with the same name in KVTransferThread from kv_transfer.py. Duplicating code like this makes the codebase harder to maintain.

It would be better to refactor this into a common utility function that can be called from both MooncakeEngine and KVTransferThread. This function would take necessary context (like block_size, kv_caches_base_addr, block_len, and use_mla) as arguments, ensuring the logic is defined in a single place.


def wait_for_layer_load(self) -> None:
"""MooncakeConnector does not do layerwise saving."""
Expand Down Expand Up @@ -430,8 +468,11 @@ def get_finished(self) -> tuple[set[str], set[str]]:
self.kv_send_thread.
get_and_clear_finished_requests( # type: ignore[union-attr]
) if self.kv_role in ['kv_producer', 'kv_both'] else set())
done_recving = self.kv_recv_thread.get_and_clear_finished_requests( # type: ignore[union-attr]
)

done_recving = (
self.kv_recv_thread.
get_and_clear_finished_requests( # type: ignore[union-attr]
) if self.load_async else set())

logger.debug(
"Number of completed KV cache send requests: %d, receive "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,8 @@ def __init__(self, vllm_config: "VllmConfig", use_layerwise):
self.kv_role = vllm_config.kv_transfer_config.kv_role
self.consumer_is_to_load = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"consumer_is_to_load", False)
self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"load_async", False)
# request_id -> (vllm cached tokens, mooncake cached tokens)
self.load_specs: dict[str, LoadSpec] = {}
self._block_size = vllm_config.cache_config.block_size
Expand Down Expand Up @@ -229,7 +231,7 @@ def get_num_new_matched_tokens(
can_load=False,
)

return need_to_allocate, not self.use_layerwise
return need_to_allocate, self.load_async

def update_state_after_alloc(self, request: "Request",
blocks: "KVCacheBlocks",
Expand Down
Loading