Commit ae349dc

[Core] Async scheduling + structured outputs compatibility (#26866)
Signed-off-by: leo-pony <[email protected]>
1 parent fdeb9d9 commit ae349dc

File tree

vllm_ascend/worker/model_runner_v1.py
vllm_ascend/worker/worker_v1.py

2 files changed: +69 -8 lines


vllm_ascend/worker/model_runner_v1.py

Lines changed: 59 additions & 5 deletions
@@ -180,7 +180,7 @@
 
 if TYPE_CHECKING:
     import xgrammar as xgr  # type: ignore[import-untyped]
-    from vllm.v1.core.sched.output import SchedulerOutput
+    from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
 else:
     xgr = LazyLoader("xgr", globals(), "xgrammar")
 
@@ -279,6 +279,17 @@ def get_output(self) -> ModelRunnerOutput:
         output.sampled_token_ids = valid_sampled_token_ids
         return output
 
+class ExecuteModelState(NamedTuple):
+    """Ephemeral cached state transferred between execute_model() and
+    sample_tokens(), after execute_model() returns None."""
+
+    scheduler_output: "SchedulerOutput"
+    logits: torch.Tensor
+    spec_decode_metadata: SpecDecodeMetadata | None
+    hidden_states: torch.Tensor
+    sample_hidden_states: torch.Tensor
+    aux_hidden_states: list[torch.Tensor] | None
+    kv_connector_output: KVConnectorOutput | None
 
 class NPUModelRunner(LoRAModelRunnerMixin):
 
@@ -610,6 +621,9 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         # TODO: EVS Support (Video tokens pruning) (see vllm#22980)
         self.is_multimodal_pruning_enabled = False
 
+        # Ephemeral state transferred between execute_model() and sample_tokens().
+        self.execute_model_state: ExecuteModelState | None = None
+
     def _set_up_drafter(self):
         # Set up speculative decoding.
         self.spec_attn_mask = None
@@ -2444,7 +2458,13 @@ def execute_model(
         self,
         scheduler_output: "SchedulerOutput",
         intermediate_tensors: Optional[IntermediateTensors] = None,
-    ) -> Union[ModelRunnerOutput, AsyncModelRunnerOutput, IntermediateTensors]:
+    ) -> Union[ModelRunnerOutput, IntermediateTensors] | None:
+        if self.execute_model_state is not None:
+            raise RuntimeError(
+                "State error: sample_tokens() must be called "
+                "after execute_model() returns None."
+            )
+
         with ProfileExecuteDuration().capture_async("prepare input"):
             self._update_states(scheduler_output)
             if not scheduler_output.total_num_scheduled_tokens:
@@ -2558,9 +2578,43 @@ def execute_model(
                 logits = self.apply_grammar_bitmask(
                     scheduler_output, logits)
         else:
-            if scheduler_output.structured_output_request_ids:
-                logits = self.apply_grammar_bitmask(
-                    scheduler_output, logits)
+            self.execute_model_state = ExecuteModelState(
+                scheduler_output,
+                logits,
+                spec_decode_metadata,
+                hidden_states,
+                sample_hidden_states,
+                aux_hidden_states,
+                kv_connector_output,
+            )
+            return None
+
+    @torch.inference_mode
+    def sample_tokens(
+        self, grammar_output: "GrammarOutput | None"
+    ) -> ModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors:
+        if self.execute_model_state is None:
+            # Nothing to do (PP non-final rank case), output isn't used.
+            return None  # noqa
+
+        # Unpack ephemeral state.
+        (
+            scheduler_output,
+            logits,
+            spec_decode_metadata,
+            hidden_states,
+            sample_hidden_states,
+            aux_hidden_states,
+            kv_connector_output,
+        ) = self.execute_model_state
+        # Clear ephemeral state.
+        self.execute_model_state = None
+
+        # Apply structured output bitmasks if present.
+        if grammar_output is not None:
+            logits = self.apply_grammar_bitmask(
+                scheduler_output, grammar_output, self.input_batch, logits
+            )
 
         with ProfileExecuteDuration().capture_async("Sample"):
             # Sample the next token and get logprobs if needed.
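
The model-runner half of this change splits post-forward work in two: execute_model() stashes the logits and related tensors in ExecuteModelState and returns None, and sample_tokens() later pops that state and, when a GrammarOutput is present, masks the logits via apply_grammar_bitmask() before sampling. As a rough illustration of what masking logits with a grammar bitmask means, here is a tiny standalone sketch; the toy tensors and the masked_fill/-inf approach are illustrative assumptions, not the actual apply_grammar_bitmask() implementation.

import torch

# Toy example: 2 requests over a vocabulary of 6 tokens.
logits = torch.randn(2, 6)

# A grammar bitmask marks which tokens are currently legal per request;
# here request 0 may only emit tokens {0, 1}, request 1 only {3, 4, 5}.
allowed = torch.tensor([
    [True, True, False, False, False, False],
    [False, False, False, True, True, True],
])

# Disallowed tokens are pushed to -inf so softmax/greedy sampling can
# never pick them; the grammar is therefore always respected.
masked = logits.masked_fill(~allowed, float("-inf"))
print(torch.argmax(masked, dim=-1))  # stays within each allowed set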

vllm_ascend/worker/worker_v1.py

Lines changed: 10 additions & 3 deletions
@@ -19,6 +19,7 @@
 
 import copy
 from typing import Optional, Union
+from types import NoneType
 
 import torch
 import torch.nn as nn
@@ -35,7 +36,7 @@
 from vllm.lora.request import LoRARequest
 from vllm.sequence import IntermediateTensors
 from vllm.tasks import SupportedTask
-from vllm.v1.core.sched.output import SchedulerOutput
+from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
 from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
 from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
                              DraftTokenIds, ModelRunnerOutput)
@@ -274,7 +275,7 @@ def determine_available_memory(self) -> int:
     def execute_model(
         self,
         scheduler_output: "SchedulerOutput",
-    ) -> Optional[Union[ModelRunnerOutput, AsyncModelRunnerOutput]]:
+    ) -> ModelRunnerOutput | None:
         # enable msMonitor to monitor the performance of vllm-ascend
         if envs_ascend.MSMONITOR_USE_DAEMON:
             dp.step()
@@ -288,7 +289,7 @@ def execute_model(
 
         output = self.model_runner.execute_model(scheduler_output,
                                                  intermediate_tensors)
-        if isinstance(output, (ModelRunnerOutput, AsyncModelRunnerOutput)):
+        if isinstance(output, (ModelRunnerOutput, NoneType)):
             return output
 
         assert isinstance(output, IntermediateTensors)
@@ -312,6 +313,12 @@ def execute_model(
             output.kv_connector_output = kv_connector_output
         return output
 
+    @torch.inference_mode()
+    def sample_tokens(
+        self, grammar_output: "GrammarOutput"
+    ) -> ModelRunnerOutput | AsyncModelRunnerOutput:
+        return self.model_runner.sample_tokens(grammar_output)
+
     def load_model(self) -> None:
         if self.vllm_config.model_config.enable_sleep_mode:
             allocator = CaMemAllocator.get_instance()
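
From the worker's point of view, execute_model() may now return None (hence the isinstance check against NoneType), and the finished ModelRunnerOutput only appears once sample_tokens() is called with the prepared GrammarOutput. Below is a minimal sketch of that two-step driving loop; StubWorker and step() are hypothetical stand-ins that only mimic the control flow, not the real vLLM executor code.

from typing import Any, Optional


class StubWorker:
    """Toy stand-in for the two-phase worker protocol."""

    def __init__(self) -> None:
        self._pending: Optional[list[int]] = None

    def execute_model(self, scheduler_output: list[str]) -> None:
        # Phase 1: forward pass only; stash per-request state, defer sampling.
        self._pending = [len(req) for req in scheduler_output]
        return None

    def sample_tokens(self, grammar_output: Any) -> dict[str, Any]:
        # Phase 2: consume the stashed state once the grammar output exists.
        pending, self._pending = self._pending, None
        return {"sampled": pending, "grammar": grammar_output}


def step(worker: StubWorker, scheduler_output: list[str], grammar_output: Any):
    output = worker.execute_model(scheduler_output)
    if output is not None:
        # Nothing was deferred (e.g. an empty batch), return it directly.
        return output
    return worker.sample_tokens(grammar_output)


print(step(StubWorker(), ["ab", "cde"], grammar_output=None))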
