From de744f6a675be94a2b5800f74b0f3c7c1c0168ed Mon Sep 17 00:00:00 2001 From: fems14 <1804143737@qq.com> Date: Thu, 23 Oct 2025 11:52:16 +0800 Subject: [PATCH 1/3] kvpool sync load Signed-off-by: fems14 <1804143737@qq.com> --- .../distributed/mooncake/mooncake_engine.py | 71 +++++++++++++++---- .../mooncake/mooncake_store_connector_v1.py | 4 +- 2 files changed, 59 insertions(+), 16 deletions(-) diff --git a/vllm_ascend/distributed/mooncake/mooncake_engine.py b/vllm_ascend/distributed/mooncake/mooncake_engine.py index d89dcd7a7a..14ee48f8da 100644 --- a/vllm_ascend/distributed/mooncake/mooncake_engine.py +++ b/vllm_ascend/distributed/mooncake/mooncake_engine.py @@ -37,6 +37,8 @@ def __init__( self.tp_rank = parallel_config.rank self.tp_size = parallel_config.tensor_parallel_size self.kv_role = vllm_config.kv_transfer_config.kv_role + self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get( + "load_async", False) self.block_size = vllm_config.cache_config.block_size self.current_layer = 0 # self.use_mla = first_kv_cache_tuple[0].size( @@ -142,13 +144,14 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): self.kv_caches_base_addr, self.token_database, self.block_len, self.block_size, ready_event_sending) self.kv_send_thread.start() - ready_event = threading.Event() - self.kv_recv_thread = KVCacheStoreRecvingThread( - self.tp_rank, self.tp_size, self.m_store, - self.kv_caches_base_addr, self.token_database, self.block_len, - self.block_size, ready_event) - self.kv_recv_thread.start() - ready_event.wait() + if self.load_async: + ready_event = threading.Event() + self.kv_recv_thread = KVCacheStoreRecvingThread( + self.tp_rank, self.tp_size, self.m_store, + self.kv_caches_base_addr, self.token_database, self.block_len, + self.block_size, ready_event) + self.kv_recv_thread.start() + ready_event.wait() def start_load_kv(self, metadata: MooncakeConnectorMetadata): self.current_layer = 0 @@ -179,12 +182,47 @@ def start_load_kv(self, metadata: MooncakeConnectorMetadata): next(layerwise_retriever) # first layer load self.layerwise_retrievers.append(layerwise_retriever) else: - self.kv_recv_thread.add_request( # type: ignore[union-attr] - req_id, - tokens, - request.block_ids, - token_mask, - ) + if self.load_async: + self.kv_recv_thread.add_request( # type: ignore[union-attr] + req_id, + tokens, + request.block_ids, + token_mask, + ) + else: + if self.m_store.config.use_ascend_direct: + addr_list = [] + size_list = [] + key_list = [] + blockIds = [] + for start, end, key in self.token_database.process_tokens( + tokens, token_mask): + addr, size, block_id = self.prepare_value( + start, end, request.block_ids) + key_list.append(key.to_string()) + addr_list.append(addr) + size_list.append(size) + blockIds.append(block_id) + self.m_store.get_batch(key_list, addr_list, size_list, blockIds) + else: + for start, end, key in self.token_database.process_tokens( + tokens, token_mask): + addr, size, _ = self.prepare_value(start, end, request.block_ids) + self.m_store.get(key, addr, size) + + def prepare_value(self, start: int, end: int, block_ids: list[int]): + addr_list = [] + size_list = [] + block_id = block_ids[start // self.block_size] + for index, base_addr in enumerate(self.kv_caches_base_addr): + block_len = (self.block_len[index % 2] + if self.use_mla else self.block_len[0]) + + addr = base_addr + block_id * block_len + length = int(block_len / self.block_size * (end - start)) + addr_list.append(addr) + size_list.append(length) + return addr_list, size_list, 
block_id def wait_for_layer_load(self) -> None: """MooncakeConnector does not do layerwise saving.""" @@ -430,8 +468,11 @@ def get_finished(self) -> tuple[set[str], set[str]]: self.kv_send_thread. get_and_clear_finished_requests( # type: ignore[union-attr] ) if self.kv_role in ['kv_producer', 'kv_both'] else set()) - done_recving = self.kv_recv_thread.get_and_clear_finished_requests( # type: ignore[union-attr] - ) + + done_recving = ( + self.kv_recv_thread. + get_and_clear_finished_requests( # type: ignore[union-attr] + ) if self.load_async else set()) logger.debug( "Number of completed KV cache send requests: %d, receive " diff --git a/vllm_ascend/distributed/mooncake/mooncake_store_connector_v1.py b/vllm_ascend/distributed/mooncake/mooncake_store_connector_v1.py index 3a7169a4cc..f8a6b3d8bb 100644 --- a/vllm_ascend/distributed/mooncake/mooncake_store_connector_v1.py +++ b/vllm_ascend/distributed/mooncake/mooncake_store_connector_v1.py @@ -165,6 +165,8 @@ def __init__(self, vllm_config: "VllmConfig", use_layerwise): self.kv_role = vllm_config.kv_transfer_config.kv_role self.consumer_is_to_load = vllm_config.kv_transfer_config.kv_connector_extra_config.get( "consumer_is_to_load", False) + self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get( + "load_async", False) # request_id -> (vllm cached tokes, mooncake cached tokens) self.load_specs: dict[str, LoadSpec] = {} self._block_size = vllm_config.cache_config.block_size @@ -229,7 +231,7 @@ def get_num_new_matched_tokens( can_load=False, ) - return need_to_allocate, not self.use_layerwise + return need_to_allocate, self.load_async def update_state_after_alloc(self, request: "Request", blocks: "KVCacheBlocks", From 507a3ebadc36ad2b4a6bb3799eb023e51968e7d0 Mon Sep 17 00:00:00 2001 From: fems14 <1804143737@qq.com> Date: Thu, 23 Oct 2025 20:30:47 +0800 Subject: [PATCH 2/3] fix lint Signed-off-by: fems14 <1804143737@qq.com> --- vllm_ascend/distributed/mooncake/mooncake_engine.py | 11 ++++++----- vllm_ascend/distributed/mooncake/mooncake_store.py | 3 --- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/vllm_ascend/distributed/mooncake/mooncake_engine.py b/vllm_ascend/distributed/mooncake/mooncake_engine.py index 14ee48f8da..423a6e63d8 100644 --- a/vllm_ascend/distributed/mooncake/mooncake_engine.py +++ b/vllm_ascend/distributed/mooncake/mooncake_engine.py @@ -104,7 +104,6 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): self.use_mla, first_kv_cache.shape) self.kv_caches = kv_caches - self.m_store.set_kv_caches(kv_caches.values()) self.kv_caches_base_addr = [] for cache_or_caches in kv_caches.values(): # Normalize to always be a list of caches @@ -148,8 +147,8 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): ready_event = threading.Event() self.kv_recv_thread = KVCacheStoreRecvingThread( self.tp_rank, self.tp_size, self.m_store, - self.kv_caches_base_addr, self.token_database, self.block_len, - self.block_size, ready_event) + self.kv_caches_base_addr, self.token_database, + self.block_len, self.block_size, ready_event) self.kv_recv_thread.start() ready_event.wait() @@ -203,11 +202,13 @@ def start_load_kv(self, metadata: MooncakeConnectorMetadata): addr_list.append(addr) size_list.append(size) blockIds.append(block_id) - self.m_store.get_batch(key_list, addr_list, size_list, blockIds) + self.m_store.get_batch(key_list, addr_list, size_list, + blockIds) else: for start, end, key in self.token_database.process_tokens( tokens, token_mask): - addr, size, _ = 
self.prepare_value(start, end, request.block_ids)
+                        addr, size, _ = self.prepare_value(
+                            start, end, request.block_ids)
                         self.m_store.get(key, addr, size)
 
     def prepare_value(self, start: int, end: int, block_ids: list[int]):
diff --git a/vllm_ascend/distributed/mooncake/mooncake_store.py b/vllm_ascend/distributed/mooncake/mooncake_store.py
index bf522f7acc..397c2c85a0 100644
--- a/vllm_ascend/distributed/mooncake/mooncake_store.py
+++ b/vllm_ascend/distributed/mooncake/mooncake_store.py
@@ -65,9 +65,6 @@ def __init__(self, parallel_config: ParallelConfig):
             logger.error(msg)
             raise RuntimeError(msg)
 
-    def set_kv_caches(self, kvcache):
-        self.kvcache = list(kvcache)
-
     def exists(self, key: MooncakeEngineKey) -> bool:
         return self.store.is_exist(key.to_string()) == 1
 

From 7fc7ed8704cba71bf3b0eace67131b4b58459dee Mon Sep 17 00:00:00 2001
From: fems14 <1804143737@qq.com>
Date: Fri, 24 Oct 2025 09:50:13 +0800
Subject: [PATCH 3/3] add readme

Signed-off-by: fems14 <1804143737@qq.com>
---
 .../mooncake_connector_store_deployment_guide.md  | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/examples/disaggregated_prefill_v1/mooncake_connector_store_deployment_guide.md b/examples/disaggregated_prefill_v1/mooncake_connector_store_deployment_guide.md
index 3bf9240d99..783251730c 100644
--- a/examples/disaggregated_prefill_v1/mooncake_connector_store_deployment_guide.md
+++ b/examples/disaggregated_prefill_v1/mooncake_connector_store_deployment_guide.md
@@ -10,7 +10,12 @@
 
 * vLLM-Ascend:main branch
 * Mooncake:[AscendTransport/Mooncake at pooling-async-memcpy](https://github.com/AscendTransport/Mooncake/tree/pooling-async-memcpy)(Currently available branch code, continuously updated.)
 Installation and Compilation Guide:https://github.com/AscendTransport/Mooncake/tree/pooling-async-memcpy?tab=readme-ov-file#build-and-use-binaries
-
+
+### KV Pooling Parameter Description
+**kv_connector_extra_config**: additional configurable parameters for KV pooling.
+  * **mooncake_rpc_port**: port for RPC communication between the pooling scheduler process and the worker processes; each instance requires a unique port.
+  * **load_async**: whether to load the KV cache asynchronously. The default value is `false`.
+
 ## run mooncake master
 ### 1.Configure mooncake.json
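
For illustration, a minimal sketch of the configuration shape that the new option is read from (`kv_transfer_config.kv_connector_extra_config.get("load_async", False)`). The connector name, role, and port value below are assumed placeholders, not values taken from this patch or the deployment guide:

```python
# Hypothetical kv-transfer configuration sketch; the connector name, role, and
# port are assumptions -- adapt them to the actual deployment settings.
kv_transfer_config = {
    "kv_connector": "MooncakeConnectorStoreV1",  # assumed registration name
    "kv_role": "kv_both",
    "kv_connector_extra_config": {
        "mooncake_rpc_port": 17777,  # assumed value; must be unique per instance
        "load_async": False,         # default; True enables the background recv thread
    },
}
```

With `load_async` left at its default (`False`), `register_kv_caches` does not start `KVCacheStoreRecvingThread`, `start_load_kv` fetches KV blocks synchronously (batched via `get_batch` when `use_ascend_direct` is set, otherwise per-key `get`), and `get_finished` reports an empty `done_recving` set; setting it to `True` keeps the asynchronous receiving-thread path.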