 
 if TYPE_CHECKING:
     import xgrammar as xgr  # type: ignore[import-untyped]
-    from vllm.v1.core.sched.output import SchedulerOutput
+    from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
 else:
     xgr = LazyLoader("xgr", globals(), "xgrammar")
 
@@ -279,6 +279,17 @@ def get_output(self) -> ModelRunnerOutput:
         output.sampled_token_ids = valid_sampled_token_ids
         return output
 
+class ExecuteModelState(NamedTuple):
+    """Ephemeral cached state transferred between execute_model() and
+    sample_tokens(), after execute_model() returns None."""
+
+    scheduler_output: "SchedulerOutput"
+    logits: torch.Tensor
+    spec_decode_metadata: SpecDecodeMetadata | None
+    hidden_states: torch.Tensor
+    sample_hidden_states: torch.Tensor
+    aux_hidden_states: list[torch.Tensor] | None
+    kv_connector_output: KVConnectorOutput | None
 
 class NPUModelRunner(LoRAModelRunnerMixin):
@@ -600,6 +611,9 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         # TODO: EVS Support (Video tokens pruning) (see vllm#22980)
         self.is_multimodal_pruning_enabled = False
 
+        # Ephemeral state transferred between execute_model() and sample_tokens().
+        self.execute_model_state: ExecuteModelState | None = None
+
     def _set_up_drafter(self):
         # Set up speculative decoding.
         self.spec_attn_mask = None
@@ -2423,7 +2437,13 @@ def execute_model(
         self,
         scheduler_output: "SchedulerOutput",
         intermediate_tensors: Optional[IntermediateTensors] = None,
-    ) -> Union[ModelRunnerOutput, AsyncModelRunnerOutput, IntermediateTensors]:
+    ) -> Union[ModelRunnerOutput, IntermediateTensors] | None:
+        if self.execute_model_state is not None:
+            raise RuntimeError(
+                "State error: sample_tokens() must be called "
+                "after execute_model() returns None."
+            )
+
         with ProfileExecuteDuration().capture_async("prepare input"):
             self._update_states(scheduler_output)
             if not scheduler_output.total_num_scheduled_tokens:
@@ -2537,9 +2557,43 @@ def execute_model(
                 logits = self.apply_grammar_bitmask(
                     scheduler_output, logits)
         else:
-            if scheduler_output.structured_output_request_ids:
-                logits = self.apply_grammar_bitmask(
-                    scheduler_output, logits)
+            self.execute_model_state = ExecuteModelState(
+                scheduler_output,
+                logits,
+                spec_decode_metadata,
+                hidden_states,
+                sample_hidden_states,
+                aux_hidden_states,
+                kv_connector_output,
+            )
+            return None
+
+    @torch.inference_mode
+    def sample_tokens(
+        self, grammar_output: "GrammarOutput | None"
+    ) -> ModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors:
+        if self.execute_model_state is None:
+            # Nothing to do (PP non-final rank case), output isn't used.
+            return None  # noqa
+
+        # Unpack ephemeral state.
+        (
+            scheduler_output,
+            logits,
+            spec_decode_metadata,
+            hidden_states,
+            sample_hidden_states,
+            aux_hidden_states,
+            kv_connector_output,
+        ) = self.execute_model_state
+        # Clear ephemeral state.
+        self.execute_model_state = None
+
+        # Apply structured output bitmasks if present.
+        if grammar_output is not None:
+            logits = self.apply_grammar_bitmask(
+                scheduler_output, grammar_output, self.input_batch, logits
+            )
 
         with ProfileExecuteDuration().capture_async("Sample"):
             # Sample the next token and get logprobs if needed.
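
For orientation, here is a minimal, illustrative sketch (not part of the patch) of how a caller might drive the split API above. It assumes a runner object with the execute_model()/sample_tokens() semantics shown in this diff; the names run_one_step, runner, scheduler_output, and grammar_output are placeholders introduced only for this example.

from typing import Any, Optional


def run_one_step(runner: Any,
                 scheduler_output: Any,
                 grammar_output: Optional[Any] = None) -> Any:
    """Hypothetical driver helper showing the intended call order.

    Per the diff above: execute_model() either returns IntermediateTensors
    (non-final pipeline-parallel rank) or caches ExecuteModelState and
    returns None, in which case sample_tokens() must be called next; it
    consumes the cached state and applies the grammar bitmask if one is
    provided.
    """
    result = runner.execute_model(scheduler_output)
    if result is not None:
        # Non-final PP rank: forward the intermediate tensors, no sampling here.
        return result
    # Final rank: sampling (and optional grammar masking) happens in phase two.
    return runner.sample_tokens(grammar_output)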