|
180 | 180 |
|
181 | 181 | if TYPE_CHECKING: |
182 | 182 | import xgrammar as xgr # type: ignore[import-untyped] |
183 | | - from vllm.v1.core.sched.output import SchedulerOutput |
| 183 | + from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput |
184 | 184 | else: |
185 | 185 | xgr = LazyLoader("xgr", globals(), "xgrammar") |
186 | 186 |
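For context, the hunk above only adds GrammarOutput to the TYPE_CHECKING import so it can be referenced in string annotations without a runtime import cost. A minimal sketch of that pattern follows; the LazyModule class here is a simplified stand-in, not vLLM's LazyLoader:

import importlib
from typing import TYPE_CHECKING, Any

class LazyModule:
    """Simplified stand-in: defers the real import until first attribute access."""

    def __init__(self, name: str) -> None:
        self._name = name
        self._module: Any = None

    def __getattr__(self, attr: str) -> Any:
        if self._module is None:
            self._module = importlib.import_module(self._name)
        return getattr(self._module, attr)

if TYPE_CHECKING:
    import json  # the type checker resolves names against the real module
else:
    json = LazyModule("json")  # at runtime the import is deferred until first use

print(json.dumps({"lazy": True}))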
|
@@ -279,6 +279,17 @@ def get_output(self) -> ModelRunnerOutput: |
279 | 279 | output.sampled_token_ids = valid_sampled_token_ids |
280 | 280 | return output |
281 | 281 |
|
| 282 | +class ExecuteModelState(NamedTuple): |
| 283 | + """Ephemeral cached state transferred between execute_model() and |
| 284 | + sample_tokens(), after execute_model() returns None.""" |
| 285 | + |
| 286 | + scheduler_output: "SchedulerOutput" |
| 287 | + logits: torch.Tensor |
| 288 | + spec_decode_metadata: SpecDecodeMetadata | None |
| 289 | + hidden_states: torch.Tensor |
| 290 | + sample_hidden_states: torch.Tensor |
| 291 | + aux_hidden_states: list[torch.Tensor] | None |
| 292 | + kv_connector_output: KVConnectorOutput | None |
282 | 293 |
|
283 | 294 | class NPUModelRunner(LoRAModelRunnerMixin): |
284 | 295 |
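The new NamedTuple above carries every intermediate that the deferred sampling step needs. A minimal, self-contained sketch of this set-once/consume-once handoff pattern (class and field names are illustrative, not the runner's actual API):

from typing import NamedTuple

class _PendingState(NamedTuple):
    request_id: str
    logits: list[float]

class _TwoPhaseRunner:
    def __init__(self) -> None:
        # Ephemeral state: set by forward(), consumed exactly once by sample().
        self._pending: _PendingState | None = None

    def forward(self, request_id: str) -> None:
        if self._pending is not None:
            raise RuntimeError("sample() must be called before the next forward()")
        self._pending = _PendingState(request_id, [0.1, 0.7, 0.2])

    def sample(self) -> str:
        if self._pending is None:
            raise RuntimeError("forward() must be called first")
        pending, self._pending = self._pending, None  # clear before use
        best = max(range(len(pending.logits)), key=pending.logits.__getitem__)
        return f"{pending.request_id} -> token {best}"

runner = _TwoPhaseRunner()
runner.forward("req-0")
print(runner.sample())  # req-0 -> token 1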
|
@@ -610,6 +621,9 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): |
610 | 621 | # TODO: EVS Support (Video tokens pruning) (see vllm#22980) |
611 | 622 | self.is_multimodal_pruning_enabled = False |
612 | 623 |
|
| 624 | + # Ephemeral state transferred between execute_model() and sample_tokens(). |
| 625 | + self.execute_model_state: ExecuteModelState | None = None |
| 626 | + |
613 | 627 | def _set_up_drafter(self): |
614 | 628 | # Set up speculative decoding. |
615 | 629 | self.spec_attn_mask = None |
@@ -2444,7 +2458,13 @@ def execute_model( |
2444 | 2458 | self, |
2445 | 2459 | scheduler_output: "SchedulerOutput", |
2446 | 2460 | intermediate_tensors: Optional[IntermediateTensors] = None, |
2447 | | - ) -> Union[ModelRunnerOutput, AsyncModelRunnerOutput, IntermediateTensors]: |
| 2461 | + ) -> Union[ModelRunnerOutput, IntermediateTensors] | None: |
| 2462 | + if self.execute_model_state is not None: |
| 2463 | + raise RuntimeError( |
| 2464 | + "State error: sample_tokens() must be called " |
| 2465 | + "after execute_model() returns None." |
| 2466 | + ) |
| 2467 | + |
2448 | 2468 | with ProfileExecuteDuration().capture_async("prepare input"): |
2449 | 2469 | self._update_states(scheduler_output) |
2450 | 2470 | if not scheduler_output.total_num_scheduled_tokens: |
@@ -2558,9 +2578,43 @@ def execute_model( |
2558 | 2578 | logits = self.apply_grammar_bitmask( |
2559 | 2579 | scheduler_output, logits) |
2560 | 2580 | else: |
2561 | | - if scheduler_output.structured_output_request_ids: |
2562 | | - logits = self.apply_grammar_bitmask( |
2563 | | - scheduler_output, logits) |
| 2581 | + self.execute_model_state = ExecuteModelState( |
| 2582 | + scheduler_output, |
| 2583 | + logits, |
| 2584 | + spec_decode_metadata, |
| 2585 | + hidden_states, |
| 2586 | + sample_hidden_states, |
| 2587 | + aux_hidden_states, |
| 2588 | + kv_connector_output, |
| 2589 | + ) |
| 2590 | + return None |
| 2591 | + |
| 2592 | + @torch.inference_mode |
| 2593 | + def sample_tokens( |
| 2594 | + self, grammar_output: "GrammarOutput | None" |
| 2595 | + ) -> ModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors: |
| 2596 | + if self.execute_model_state is None: |
| 2597 | + # Nothing to do (PP non-final rank case), output isn't used. |
| 2598 | + return None # noqa |
| 2599 | + |
| 2600 | + # Unpack ephemeral state. |
| 2601 | + ( |
| 2602 | + scheduler_output, |
| 2603 | + logits, |
| 2604 | + spec_decode_metadata, |
| 2605 | + hidden_states, |
| 2606 | + sample_hidden_states, |
| 2607 | + aux_hidden_states, |
| 2608 | + kv_connector_output, |
| 2609 | + ) = self.execute_model_state |
| 2610 | + # Clear ephemeral state. |
| 2611 | + self.execute_model_state = None |
| 2612 | + |
| 2613 | + # Apply structured output bitmasks if present. |
| 2614 | + if grammar_output is not None: |
| 2615 | + logits = self.apply_grammar_bitmask( |
| 2616 | + scheduler_output, grammar_output, self.input_batch, logits |
| 2617 | + ) |
2564 | 2618 |
|
2565 | 2619 | with ProfileExecuteDuration().capture_async("Sample"): |
2566 | 2620 | # Sample the next token and get logprobs if needed. |
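Taken together, the hunks above split one scheduler step into a forward pass that returns None and a follow-up sample_tokens(grammar_output) call that applies any grammar bitmask before sampling. A hedged caller-side sketch of that protocol, using simple stand-ins rather than vLLM's scheduler and runner types:

from typing import Any

class _TinyRunner:
    """Stand-in mimicking the execute_model() / sample_tokens() split."""

    def __init__(self) -> None:
        self._pending: dict[str, Any] | None = None

    def execute_model(self, scheduler_output: dict[str, Any]) -> None:
        # The forward pass would run here; sampling is deferred.
        self._pending = {"logits": [1.0, 2.0, 0.5], "out": scheduler_output}
        return None  # signals the caller to invoke sample_tokens() next

    def sample_tokens(self, grammar_output: dict[str, Any] | None) -> int:
        assert self._pending is not None, "execute_model() must run first"
        pending, self._pending = self._pending, None
        logits = list(pending["logits"])
        if grammar_output is not None:
            # Analogue of applying the grammar bitmask: mask disallowed tokens.
            for token_id in grammar_output["banned_token_ids"]:
                logits[token_id] = float("-inf")
        return max(range(len(logits)), key=logits.__getitem__)

runner = _TinyRunner()
assert runner.execute_model({"num_scheduled_tokens": 8}) is None
print(runner.sample_tokens({"banned_token_ids": [1]}))  # bans id 1, so id 0 wins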
|