
Commit e0ab650

[Core] Async scheduling + structured outputs compatibility (#26866)
Signed-off-by: leo-pony <[email protected]>
1 parent b739ae4 commit e0ab650

File tree

2 files changed (+69 -8 lines)


vllm_ascend/worker/model_runner_v1.py

Lines changed: 59 additions & 5 deletions
@@ -180,7 +180,7 @@

 if TYPE_CHECKING:
     import xgrammar as xgr  # type: ignore[import-untyped]
-    from vllm.v1.core.sched.output import SchedulerOutput
+    from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
 else:
     xgr = LazyLoader("xgr", globals(), "xgrammar")

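Note on the hunk above: `GrammarOutput` is only referenced in type annotations, so it lives behind the `TYPE_CHECKING` guard and is never imported at runtime. A minimal, self-contained sketch of that pattern (the module and class names here are made up for illustration):

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only type checkers execute this import; the interpreter never does.
    from heavy_optional_dep import HeavyType  # hypothetical dependency


def describe(obj: "HeavyType") -> str:
    # The quoted annotation is never evaluated at runtime, so this function
    # works even when heavy_optional_dep is not installed.
    return type(obj).__name__
```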
@@ -279,6 +279,17 @@ def get_output(self) -> ModelRunnerOutput:
         output.sampled_token_ids = valid_sampled_token_ids
         return output

+class ExecuteModelState(NamedTuple):
+    """Ephemeral cached state transferred between execute_model() and
+    sample_tokens(), after execute_model() returns None."""
+
+    scheduler_output: "SchedulerOutput"
+    logits: torch.Tensor
+    spec_decode_metadata: SpecDecodeMetadata | None
+    hidden_states: torch.Tensor
+    sample_hidden_states: torch.Tensor
+    aux_hidden_states: list[torch.Tensor] | None
+    kv_connector_output: KVConnectorOutput | None

 class NPUModelRunner(LoRAModelRunnerMixin):

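The `ExecuteModelState` NamedTuple is the hand-off container between the two halves of a step: `execute_model()` fills it and returns `None`, and `sample_tokens()` later unpacks and clears it. A minimal sketch of the same pattern with made-up fields (not the real vLLM types), using the same Python 3.10+ `X | None` syntax as the diff:

```python
from typing import NamedTuple


class StepState(NamedTuple):
    """Stand-in for ExecuteModelState: values produced by the first half of a
    step and consumed by the second half."""
    request_id: str
    logits_shape: tuple[int, int]


class ToyRunner:
    def __init__(self) -> None:
        # Single slot of carry-over state; None means "no step in flight".
        self.step_state: StepState | None = None

    def first_half(self, request_id: str) -> None:
        # Stash everything the second half will need, then return early.
        self.step_state = StepState(request_id, (4, 32000))

    def second_half(self) -> str:
        assert self.step_state is not None
        request_id, logits_shape = self.step_state  # plain tuple unpacking
        self.step_state = None  # clear so the next step starts clean
        return f"{request_id}: logits {logits_shape}"


runner = ToyRunner()
runner.first_half("req-0")
print(runner.second_half())  # req-0: logits (4, 32000)
```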
@@ -600,6 +611,9 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         # TODO: EVS Support (Video tokens pruning) (see vllm#22980)
         self.is_multimodal_pruning_enabled = False

+        # Ephemeral state transferred between execute_model() and sample_tokens().
+        self.execute_model_state: ExecuteModelState | None = None
+
     def _set_up_drafter(self):
         # Set up speculative decoding.
         self.spec_attn_mask = None
@@ -2423,7 +2437,13 @@ def execute_model(
         self,
         scheduler_output: "SchedulerOutput",
         intermediate_tensors: Optional[IntermediateTensors] = None,
-    ) -> Union[ModelRunnerOutput, AsyncModelRunnerOutput, IntermediateTensors]:
+    ) -> Union[ModelRunnerOutput, IntermediateTensors] | None:
+        if self.execute_model_state is not None:
+            raise RuntimeError(
+                "State error: sample_tokens() must be called "
+                "after execute_model() returns None."
+            )
+
         with ProfileExecuteDuration().capture_async("prepare input"):
             self._update_states(scheduler_output)
             if not scheduler_output.total_num_scheduled_tokens:
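The new guard at the top of `execute_model()` enforces the call protocol: once a forward pass has parked its state and returned `None`, `sample_tokens()` must consume it before the next `execute_model()`. A toy sketch of that contract (illustrative names, not the vLLM API):

```python
class TwoPhaseRunner:
    """Illustrative only: enforces an execute -> sample -> execute ordering."""

    def __init__(self) -> None:
        self._pending: dict | None = None

    def execute(self, batch: dict) -> None:
        if self._pending is not None:
            raise RuntimeError("State error: sample() must be called "
                               "after execute() returns None.")
        self._pending = batch  # defer the second phase
        return None

    def sample(self) -> dict:
        assert self._pending is not None, "execute() must run first"
        pending, self._pending = self._pending, None
        return pending


runner = TwoPhaseRunner()
runner.execute({"tokens": [1, 2, 3]})
print(runner.sample())           # {'tokens': [1, 2, 3]}
runner.execute({"tokens": [4]})  # legal again once sample() has run
```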
@@ -2537,9 +2557,43 @@ def execute_model(
                 logits = self.apply_grammar_bitmask(
                     scheduler_output, logits)
         else:
-            if scheduler_output.structured_output_request_ids:
-                logits = self.apply_grammar_bitmask(
-                    scheduler_output, logits)
+            self.execute_model_state = ExecuteModelState(
+                scheduler_output,
+                logits,
+                spec_decode_metadata,
+                hidden_states,
+                sample_hidden_states,
+                aux_hidden_states,
+                kv_connector_output,
+            )
+            return None
+
+    @torch.inference_mode
+    def sample_tokens(
+        self, grammar_output: "GrammarOutput | None"
+    ) -> ModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors:
+        if self.execute_model_state is None:
+            # Nothing to do (PP non-final rank case), output isn't used.
+            return None  # noqa
+
+        # Unpack ephemeral state.
+        (
+            scheduler_output,
+            logits,
+            spec_decode_metadata,
+            hidden_states,
+            sample_hidden_states,
+            aux_hidden_states,
+            kv_connector_output,
+        ) = self.execute_model_state
+        # Clear ephemeral state.
+        self.execute_model_state = None
+
+        # Apply structured output bitmasks if present.
+        if grammar_output is not None:
+            logits = self.apply_grammar_bitmask(
+                scheduler_output, grammar_output, self.input_batch, logits
+            )

         with ProfileExecuteDuration().capture_async("Sample"):
             # Sample the next token and get logprobs if needed.
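The effect of this hunk is that the grammar bitmask is no longer pulled out of `scheduler_output` inside `execute_model()`; the caller passes a `GrammarOutput` into `sample_tokens()` after the forward pass, which is what lets structured outputs coexist with async scheduling. A rough sketch of the resulting call sequence from a driver's point of view (hypothetical helper names, not the actual vLLM executor code):

```python
def run_one_step(runner, scheduler_output, build_grammar_output):
    """Hypothetical driver loop for the two-phase runner API."""
    out = runner.execute_model(scheduler_output)
    if out is not None:
        # Nothing left to sample here: e.g. an empty batch, or this rank only
        # produced IntermediateTensors for the next pipeline stage.
        return out

    # execute_model() parked its state and returned None; the grammar
    # bitmasks for structured-output requests can be prepared in the
    # meantime (or be None when no such requests are scheduled).
    grammar_output = build_grammar_output(scheduler_output)

    # Second phase: apply the bitmasks (if any) and sample the next tokens.
    return runner.sample_tokens(grammar_output)
```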

vllm_ascend/worker/worker_v1.py

Lines changed: 10 additions & 3 deletions
@@ -19,6 +19,7 @@

 import copy
 from typing import Optional, Union
+from types import NoneType

 import torch
 import torch.nn as nn
@@ -35,7 +36,7 @@
 from vllm.lora.request import LoRARequest
 from vllm.sequence import IntermediateTensors
 from vllm.tasks import SupportedTask
-from vllm.v1.core.sched.output import SchedulerOutput
+from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
 from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
 from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
                              DraftTokenIds, ModelRunnerOutput)
@@ -274,7 +275,7 @@ def determine_available_memory(self) -> int:
     def execute_model(
         self,
         scheduler_output: "SchedulerOutput",
-    ) -> Optional[Union[ModelRunnerOutput, AsyncModelRunnerOutput]]:
+    ) -> ModelRunnerOutput | None:
         # enable msMonitor to monitor the performance of vllm-ascend
         if envs_ascend.MSMONITOR_USE_DAEMON:
             dp.step()
@@ -288,7 +289,7 @@ def execute_model(

         output = self.model_runner.execute_model(scheduler_output,
                                                  intermediate_tensors)
-        if isinstance(output, (ModelRunnerOutput, AsyncModelRunnerOutput)):
+        if isinstance(output, (ModelRunnerOutput, NoneType)):
            return output

        assert isinstance(output, IntermediateTensors)
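`types.NoneType` (exposed in the stdlib since Python 3.10) is simply `type(None)`, so the updated check accepts both a finished `ModelRunnerOutput` and the new `None` return from the deferred-sampling path in a single `isinstance` call. A quick illustration (using `dict` as a stand-in for `ModelRunnerOutput`):

```python
from types import NoneType

assert NoneType is type(None)


def is_final_output(output: object) -> bool:
    # Equivalent to: isinstance(output, dict) or output is None
    return isinstance(output, (dict, NoneType))


print(is_final_output({"sampled": [1]}))  # True
print(is_final_output(None))              # True
print(is_final_output(42))                # False
```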
@@ -312,6 +313,12 @@ def execute_model(
         output.kv_connector_output = kv_connector_output
         return output

+    @torch.inference_mode()
+    def sample_tokens(
+        self, grammar_output: "GrammarOutput"
+    ) -> ModelRunnerOutput | AsyncModelRunnerOutput:
+        return self.model_runner.sample_tokens(grammar_output)
+
     def load_model(self) -> None:
         if self.vllm_config.model_config.enable_sleep_mode:
             allocator = CaMemAllocator.get_instance()
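The worker-level `sample_tokens()` is a thin pass-through to the model runner, wrapped in `torch.inference_mode()` so nothing on the sampling path records autograd state. A small standalone sketch of what that decorator does (generic PyTorch, not vLLM-specific):

```python
import torch


@torch.inference_mode()
def sample_step(logits: torch.Tensor) -> torch.Tensor:
    # Tensors created here are inference tensors: no gradient tracking and
    # lower per-op overhead, which is all sampling needs.
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)


logits = torch.randn(2, 8)
tokens = sample_step(logits)
print(tokens.shape, tokens.requires_grad)  # torch.Size([2, 1]) False
```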
