【main】kvpool sync load #3653
Changes from 1 commit
@@ -37,6 +37,8 @@ def __init__(
         self.tp_rank = parallel_config.rank
         self.tp_size = parallel_config.tensor_parallel_size
         self.kv_role = vllm_config.kv_transfer_config.kv_role
+        self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
+            "load_async", False)
         self.block_size = vllm_config.cache_config.block_size
         self.current_layer = 0
         # self.use_mla = first_kv_cache_tuple[0].size(
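Since the flag is read from `kv_connector_extra_config` with a default of `False`, asynchronous loading stays off unless the deployment opts in. A minimal sketch of opting in, assuming vLLM's `KVTransferConfig` is used to configure the connector (the connector name below is a placeholder, not taken from this PR):

```python
# Sketch: enabling asynchronous KV loading via kv_connector_extra_config.
# The connector name is a placeholder; only the "load_async" key comes from this PR.
from vllm.config import KVTransferConfig

kv_transfer_config = KVTransferConfig(
    kv_connector="MooncakeConnector",                 # placeholder connector name
    kv_role="kv_both",                                # producer + consumer, as in get_finished()
    kv_connector_extra_config={"load_async": True},   # default is False (synchronous load)
)
```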
@@ -142,13 +144,14 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
             self.kv_caches_base_addr, self.token_database,
             self.block_len, self.block_size, ready_event_sending)
         self.kv_send_thread.start()
-        ready_event = threading.Event()
-        self.kv_recv_thread = KVCacheStoreRecvingThread(
-            self.tp_rank, self.tp_size, self.m_store,
-            self.kv_caches_base_addr, self.token_database, self.block_len,
-            self.block_size, ready_event)
-        self.kv_recv_thread.start()
-        ready_event.wait()
+        if self.load_async:
+            ready_event = threading.Event()
+            self.kv_recv_thread = KVCacheStoreRecvingThread(
+                self.tp_rank, self.tp_size, self.m_store,
+                self.kv_caches_base_addr, self.token_database, self.block_len,
+                self.block_size, ready_event)
+            self.kv_recv_thread.start()
+            ready_event.wait()

     def start_load_kv(self, metadata: MooncakeConnectorMetadata):
         self.current_layer = 0
@@ -179,12 +182,47 @@ def start_load_kv(self, metadata: MooncakeConnectorMetadata):
                 next(layerwise_retriever)  # first layer load
                 self.layerwise_retrievers.append(layerwise_retriever)
             else:
-                self.kv_recv_thread.add_request(  # type: ignore[union-attr]
-                    req_id,
-                    tokens,
-                    request.block_ids,
-                    token_mask,
-                )
+                if self.load_async:
+                    self.kv_recv_thread.add_request(  # type: ignore[union-attr]
+                        req_id,
+                        tokens,
+                        request.block_ids,
+                        token_mask,
+                    )
+                else:
+                    if self.m_store.config.use_ascend_direct:
+                        addr_list = []
+                        size_list = []
+                        key_list = []
+                        blockIds = []
+                        for start, end, key in self.token_database.process_tokens(
+                                tokens, token_mask):
+                            addr, size, block_id = self.prepare_value(
+                                start, end, request.block_ids)
+                            key_list.append(key.to_string())
+                            addr_list.append(addr)
+                            size_list.append(size)
+                            blockIds.append(block_id)
+                        self.m_store.get_batch(key_list, addr_list, size_list, blockIds)
+                    else:
+                        for start, end, key in self.token_database.process_tokens(
+                                tokens, token_mask):
+                            addr, size, _ = self.prepare_value(start, end, request.block_ids)
+                            self.m_store.get(key, addr, size)
+
+    def prepare_value(self, start: int, end: int, block_ids: list[int]):
+        addr_list = []
+        size_list = []
+        block_id = block_ids[start // self.block_size]
+        for index, base_addr in enumerate(self.kv_caches_base_addr):
+            block_len = (self.block_len[index % 2]
+                         if self.use_mla else self.block_len[0])
+            addr = base_addr + block_id * block_len
+            length = int(block_len / self.block_size * (end - start))
+            addr_list.append(addr)
+            size_list.append(length)
+        return addr_list, size_list, block_id
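For reference, a small worked example of the address arithmetic in `prepare_value`, with made-up numbers (block size, per-block byte length, and base address are illustrative only):

```python
# Illustrative numbers only: how prepare_value maps a token range to a copy
# address and byte length for one layer's KV cache.
block_size = 128                    # tokens per KV cache block
block_len = 262_144                 # bytes per block in this layer (assumed)
base_addr = 0x7F00_0000_0000        # base address of this layer's cache (assumed)
block_ids = [5, 9, 12]              # request's block table
start, end = 256, 384               # token range covered by one key

block_id = block_ids[start // block_size]              # 256 // 128 = 2 -> block 12
addr = base_addr + block_id * block_len                # start of block 12 in this layer
length = int(block_len / block_size * (end - start))   # 128 tokens -> one full block of bytes
```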
Comment on lines +214 to +226:

It would be better to refactor this into a common utility function that can be called from both …
     def wait_for_layer_load(self) -> None:
         """MooncakeConnector does not do layerwise saving."""
@@ -430,8 +468,11 @@ def get_finished(self) -> tuple[set[str], set[str]]:
             self.kv_send_thread.
             get_and_clear_finished_requests(  # type: ignore[union-attr]
             ) if self.kv_role in ['kv_producer', 'kv_both'] else set())
-        done_recving = self.kv_recv_thread.get_and_clear_finished_requests(  # type: ignore[union-attr]
-        )
+
+        done_recving = (
+            self.kv_recv_thread.
+            get_and_clear_finished_requests(  # type: ignore[union-attr]
+            ) if self.load_async else set())

         logger.debug(
             "Number of completed KV cache send requests: %d, receive "
This new block for synchronous loading is almost identical to the logic in `KVCacheStoreRecvingThread._handle_request`. This code duplication can lead to maintenance issues where a bug fix or change in one place might be missed in the other.

To improve maintainability, I recommend extracting this logic into a new private helper method within the `MooncakeEngine` class, for example `_load_kv_sync`. The synchronous path can then call this method.

Ideally, `KVCacheStoreRecvingThread` should also be refactored to use this new helper method to completely eliminate the duplication. This would likely involve passing the `MooncakeEngine` instance to the thread's constructor.
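For illustration, a rough sketch of the suggested extraction; the name `_load_kv_sync` comes from this comment, and the body simply mirrors the synchronous branch added in `start_load_kv` above rather than being the actual refactor:

```python
# Sketch of the reviewer's suggestion: one helper that both the synchronous
# path in start_load_kv and KVCacheStoreRecvingThread._handle_request could call.
def _load_kv_sync(self, tokens, token_mask, block_ids):
    if self.m_store.config.use_ascend_direct:
        key_list, addr_list, size_list, block_id_list = [], [], [], []
        for start, end, key in self.token_database.process_tokens(tokens, token_mask):
            addr, size, block_id = self.prepare_value(start, end, block_ids)
            key_list.append(key.to_string())
            addr_list.append(addr)
            size_list.append(size)
            block_id_list.append(block_id)
        # Batched retrieval on the Ascend direct path.
        self.m_store.get_batch(key_list, addr_list, size_list, block_id_list)
    else:
        # Fallback: fetch each key individually.
        for start, end, key in self.token_database.process_tokens(tokens, token_mask):
            addr, size, _ = self.prepare_value(start, end, block_ids)
            self.m_store.get(key, addr, size)
```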