diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index a332ac797b..7e2a63fe40 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -45,6 +45,8 @@
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.monitor import set_cudagraph_capturing_enabled
 from vllm.config import CUDAGraphMode, VllmConfig, get_layers_from_vllm_config
+from vllm.distributed.ec_transfer import (get_ec_transfer,
+                                          has_ec_transfer)
 from vllm.distributed import tensor_model_parallel_all_gather
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                           has_kv_transfer_group)
@@ -96,9 +98,10 @@
                                         MambaSpec, MLAAttentionSpec,
                                         UniformTypeKVCacheSpecs)
 # yapf: enable
-from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
-                             DraftTokenIds, LogprobsTensors, ModelRunnerOutput,
-                             PoolerOutput)
+from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, ECConnectorOutput,
+                             AsyncModelRunnerOutput, DraftTokenIds, LogprobsTensors,
+                             ModelRunnerOutput,
+                             PoolerOutput, make_empty_encoder_model_runner_output)
 from vllm.v1.pool.metadata import PoolingMetadata
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
@@ -110,6 +113,7 @@
                                   gather_mm_placeholders,
                                   sanity_check_mm_encoder_outputs,
                                   scatter_mm_placeholders)
+from vllm.v1.worker.ec_connector_model_runner_mixin import ECConnectorModelRunnerMixin
 
 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config
@@ -280,7 +284,7 @@ def get_output(self) -> ModelRunnerOutput:
         return output
 
 
-class NPUModelRunner(LoRAModelRunnerMixin):
+class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
 
     def __init__(self, vllm_config: VllmConfig, device: torch.device):
         self.vllm_config = vllm_config
@@ -816,6 +820,11 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
 
             req_ids_to_add.append(req_id)
 
+        # If this rank is an EC transfer producer,
+        # skip updating the states of KV cache blocks.
+        if has_ec_transfer() and get_ec_transfer().is_producer:
+            return
+
         # Update the states of the running/resumed requests.
         is_last_rank = get_pp_group().is_last_rank
         req_data = scheduler_output.scheduled_cached_reqs
@@ -1774,8 +1783,12 @@ def _prepare_inputs(
         # _prepare_inputs may reorder the batch, so we must gather
         # multi-modal outputs after that to ensure the correct order
         if self.is_multimodal_model:
-            # Run the multimodal encoder if any.
-            self._execute_mm_encoder(scheduler_output)
+            with self.maybe_get_ec_connector_output(
+                scheduler_output,
+                encoder_cache=self.encoder_cache,
+            ) as ec_connector_output:
+                # Run the multimodal encoder if any.
+                self._execute_mm_encoder(scheduler_output)
 
             # NOTE(woosuk): To unify token ids and soft tokens (vision
             # embeddings), we always use embeddings (rather than token ids)
@@ -2447,6 +2460,14 @@ def execute_model(
     ) -> Union[ModelRunnerOutput, AsyncModelRunnerOutput, IntermediateTensors]:
         with ProfileExecuteDuration().capture_async("prepare input"):
             self._update_states(scheduler_output)
+            if has_ec_transfer() and get_ec_transfer().is_producer:
+                with self.maybe_get_ec_connector_output(
+                    scheduler_output,
+                    encoder_cache=self.encoder_cache,
+                ) as ec_connector_output:
+                    self._execute_mm_encoder(scheduler_output)
+                return make_empty_encoder_model_runner_output(scheduler_output)
+
             if not scheduler_output.total_num_scheduled_tokens:
                 if not has_kv_transfer_group():
                     logger.debug(
@@ -3873,6 +3894,9 @@ def get_kv_cache_spec_v0110(self) -> dict[str, KVCacheSpec]:
             format. Layers that do not need KV cache are not included.
         """
 
+        if has_ec_transfer() and get_ec_transfer().is_producer:
+            return {}
+
         block_size = self.vllm_config.cache_config.block_size
         use_mla = self.vllm_config.model_config.use_mla
         use_sparse = self.use_sparse
diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py
index 58ac27a0d2..ebbac7fc11 100644
--- a/vllm_ascend/worker/worker_v1.py
+++ b/vllm_ascend/worker/worker_v1.py
@@ -29,6 +29,7 @@
 from vllm.config import VllmConfig
 from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment)
+from vllm.distributed.ec_transfer import ensure_ec_transfer_initialized
 from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
 from vllm.distributed.parallel_state import get_pp_group, get_tp_group
 from vllm.logger import logger
@@ -413,6 +414,7 @@ def _init_worker_distributed_environment(self) -> None:
             self.parallel_config.decode_context_parallel_size)
         init_ascend_model_parallel(self.parallel_config)
         ensure_kv_transfer_initialized(self.vllm_config)
+        ensure_ec_transfer_initialized(self.vllm_config)
 
     def _init_profiler(self):
         # Torch profiler. Enabled and configured through env vars: