[Feat]: Support Qwen3.5 with MTP #4437
Conversation
Pull request overview
Adds PyTorch speculative decoding support for Qwen3.5 via an MTP-based draft model/proposer, including routing-expert recording and long-context chunk handling.
Changes:
- Introduces a `qwen3_5_mtp` proposer + `Qwen3_5MTPModel` and wires it through config/model maps/CLI & benchmarks.
- Refactors the speculative decoding flow to run sampling + rejection sampling inside `SpecModelAgent`, with expanded/sliced `SamplingInputs` and logprobs plumbing.
- Extends long-context chunking + MROPE/state-cache handling to support spec decoding and multimodal chunk cases; adds new unit tests for the spec agent + rejection sampler.
Reviewed changes
Copilot reviewed 45 out of 45 changed files in this pull request and generated 6 comments.
| File | Description |
|---|---|
| tests/pytorch/spec_decode/test_spec_agent.py | New tests for spec-agent sampling/logprobs + SamplingInputs expand/slice helpers |
| tests/pytorch/spec_decode/test_reject_sample.py | New tests for rejection sampling + Triton kernels |
| lmdeploy/utils.py | Allows is_bf16_supported('auto') to take CUDA path |
| lmdeploy/pytorch/strategies/dllm/model_agent.py | Updates to functional ModelInputs.step() return value |
| lmdeploy/pytorch/strategies/ar/model_agent.py | Updates to functional ModelInputs.step() return value |
| lmdeploy/pytorch/strategies/ar_spec/sequence.py | Records routed experts alongside token updates; handles per-token expert splits |
| lmdeploy/pytorch/strategies/ar_spec/sampling.py | Adds num_spec_tokens to ARSpec sampling strategy |
| lmdeploy/pytorch/strategies/ar_spec/model_inputs.py | Spec-decoding dummy inputs tweaks; MROPE pos-id reshaping during input updates |
| lmdeploy/pytorch/strategies/ar_spec/model_agent.py | Extends ARSpec extra inputs (embeds/logprobs), cloning/merge/update logic; prefill/decoding adjustments |
| lmdeploy/pytorch/strategies/ar_spec/engine.py | Adds get_num_required_tokens() for scheduling in spec decode |
| lmdeploy/pytorch/strategies/ar_spec/init.py | Passes num_spec_tokens into ARSpec sampling strategy |
| lmdeploy/pytorch/spec_decode/spec_agent.py | Major refactor: sampling + rejection sampling inside spec agent; chunk carry-over; input-embed support |
| lmdeploy/pytorch/spec_decode/reject_sampler.py | Adds Triton greedy/random rejection sampling kernels; supports mixed greedy/random batches |
| lmdeploy/pytorch/spec_decode/proposers/qwen3_5_mtp.py | New proposer registering qwen3_5_mtp (shares target embeddings) |
| lmdeploy/pytorch/spec_decode/proposers/base.py | Makes decoding input update functional; adds embed_input_ids helper |
| lmdeploy/pytorch/spec_decode/proposers/init.py | Exports Qwen3.5 MTP proposer |
| lmdeploy/pytorch/spec_decode/base.py | Base spec agent now stores SpecDecodeConfig + num_spec_tokens |
| lmdeploy/pytorch/spec_decode/init.py | Passes misc_config into spec-agent builder; initializes base agent with config |
| lmdeploy/pytorch/paging/scheduler.py | Renames scheduling arg to num_required_tokens |
| lmdeploy/pytorch/nn/gated_delta.py | Adds spec-decoding state/conv offset handling + cache seqlens plumbing |
| lmdeploy/pytorch/models/utils/cudagraph.py | Adds block_size to graph meta; updates FA3 metadata building and MROPE requirement |
| lmdeploy/pytorch/models/qwen3_5.py | Adds optional input-embed return for spec/multimodal chunking; attention head gating + TP toggles |
| lmdeploy/pytorch/models/qwen3_5_mtp.py | New Qwen3.5 MTP draft model implementation + weight loader |
| lmdeploy/pytorch/models/qwen3_5_moe.py | Adds is_tp parameter plumbing; tracks spec-decoding build context |
| lmdeploy/pytorch/models/module_map.py | Registers Qwen3_5MTPModel in module map |
| lmdeploy/pytorch/models/deepseek_mtp.py | Removes position-0 embedding masking |
| lmdeploy/pytorch/model_inputs.py | Adds target_inputs_embeds, chunk flags, clone(), and makes step() functional (non-mutating) |
| lmdeploy/pytorch/kernels/cuda/pagedattention.py | Casts block offsets to int64 in kernels |
| lmdeploy/pytorch/kernels/cuda/flatten_kv_cache.py | Minor cleanup (removes stray whitespace) |
| lmdeploy/pytorch/engine/model_agent/agent.py | Integrates spec-agent into sampling path; passes misc_config; chunk-output lifecycle tweaks; shields async postprocess |
| lmdeploy/pytorch/engine/inputs_maker.py | Tracks multimodal presence for chunking; schedules with num_required_tokens; sets chunk flags |
| lmdeploy/pytorch/engine/executor/base.py | Changes default num_state_caches sizing |
| lmdeploy/pytorch/engine/executor/init.py | Passes spec_method/num_spec_tokens/block_size into model config building |
| lmdeploy/pytorch/engine/engine_loop.py | Adjusts logprobs aggregation to support multi-token steps (spec decode) |
| lmdeploy/pytorch/configurations/qwen3_5.py | Adds spec/draft model config handling; adjusts state shapes for spec decoding |
| lmdeploy/pytorch/config.py | Plumbs num_spec_tokens + block_size into model config construction |
| lmdeploy/pytorch/backends/gated_delta_rule.py | Extends gated-delta interface to accept spec_state_offsets |
| lmdeploy/pytorch/backends/cuda/op_backend.py | Uses model_config.block_size for flash-attn metadata |
| lmdeploy/pytorch/backends/cuda/graph_runner.py | Stores block_size in CUDA graph meta |
| lmdeploy/pytorch/backends/cuda/gated_delta_rule.py | Adds Triton select/scatter for spec-state offsets; plumbs cache seqlens to recurrent path |
| lmdeploy/pytorch/backends/cuda/causal_conv1d.py | Extends conv update to accept cache_seqlens |
| lmdeploy/pytorch/backends/causal_conv1d.py | Extends conv update interface to accept cache_seqlens |
| lmdeploy/cli/utils.py | Adds qwen3_5_mtp to --speculative-algorithm choices |
| benchmark/profile_throughput.py | Adds speculative decode CLI parsing and passes config to engine |
| benchmark/profile_pipeline_api.py | Adds speculative decode CLI parsing and passes config into pipeline |
```python
target_probs = target_logits.softmax(dim=-1, dtype=torch.float32)

# 3. Uniform random [batch, num_spec] (float64 to avoid exact 0.0)
uniform_probs = torch.rand(
    (batch_size, num_spec_tokens),
    dtype=torch.float64,
    device=device,
)

# 4. Recovered tokens via Gumbel-max trick
q = torch.empty(
    (batch_size, vocab_size),
    dtype=torch.float32,
    device=device,
)
q.exponential_()
inv_q = q.reciprocal()
```
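The exponential-noise form of the Gumbel-max trick used above can be checked in isolation: dividing a probability vector by i.i.d. Exp(1) noise and taking the argmax draws index `i` with probability `p_i`. A minimal NumPy sketch (not the PR's Triton kernel, and ignoring the residual-distribution adjustment used for recovered tokens):

```python
import numpy as np

# argmax(p / E), with E ~ Exp(1) i.i.d., samples index i with probability p_i.
rng = np.random.default_rng(0)
probs = np.array([0.1, 0.2, 0.7])

q = rng.exponential(size=(100_000, probs.size))  # analogue of q.exponential_()
samples = np.argmax(probs / q, axis=-1)          # analogue of argmax(probs * inv_q)

# Empirical frequencies should be close to probs.
freq = np.bincount(samples, minlength=probs.size) / len(samples)
```

With 100k draws the empirical frequencies match `probs` to within a couple of percentage points, which is why the kernel can sample categorically with only elementwise ops plus an argmax.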
The random-rejection path uses torch.rand(...) and q.exponential_() without using sampling_inputs.random_seeds/random_offsets, so results will depend on the global RNG state and won’t be reproducible per-request (unlike the normal sampling path which uses seeded multinomial_sampling). Please wire in sampling_inputs RNG (seeds/offsets) for both uniform_probs and the Gumbel/exponential noise so speculative decoding stays deterministic under the same sampling inputs.
May improve in another PR.
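For reference, one way to make the noise reproducible per request is to drive it from an explicit `torch.Generator` seeded per row, rather than the global RNG (names below are illustrative; the actual `sampling_inputs.random_seeds`/`random_offsets` wiring may differ):

```python
import torch


def seeded_uniform(batch_seeds, num_cols, device='cpu'):
    """Per-row uniform noise driven by explicit seeds instead of global RNG state."""
    rows = []
    for seed in batch_seeds:
        gen = torch.Generator(device=device)
        gen.manual_seed(int(seed))
        rows.append(torch.rand(num_cols, dtype=torch.float64, device=device, generator=gen))
    return torch.stack(rows)


# Same seeds -> identical noise, regardless of what else has used the global RNG.
a = seeded_uniform([1234, 5678], 4)
b = seeded_uniform([1234, 5678], 4)
```

The same pattern applies to the exponential Gumbel noise (`torch.empty(...).exponential_(generator=gen)`), at the cost of a per-row generator loop unless it is fused into the kernel.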
```python
device = 'cuda' if torch.cuda.is_available() else 'cpu'


def _run_async(coro):
    """Helper to run async function in sync test."""
    loop = asyncio.new_event_loop()
    try:
        return loop.run_until_complete(coro)
    finally:
```
This test file falls back to device='cpu' when CUDA is unavailable, but the exercised code path (FusedLogitsProcessor via async_sampling_logits) unconditionally uses CUDA stream APIs (torch.cuda.current_stream()), so it will error on non-CUDA runners. Add a pytest.mark.skipif(not torch.cuda.is_available(), ...) (or otherwise guard) so CPU-only CI doesn’t fail.
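A guard along the suggested lines could look like this (sketch; the test name is illustrative):

```python
import pytest
import torch

# Skip CUDA-stream-dependent tests on CPU-only runners instead of
# silently falling back to device='cpu' and then failing at runtime.
requires_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason='requires CUDA')


@requires_cuda
def test_spec_agent_sampling():  # illustrative test name
    ...
```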
```python
device = 'cuda' if torch.cuda.is_available() else 'cpu'


def _make_peaked_logits(token_ids_2d, vocab_size):
    """Build logits where argmax(dim=-1) == token_ids_2d.

    token_ids_2d: list[list[int]] or Tensor [batch, num_spec]
    """
```
These tests fall back to CPU when CUDA isn’t available, but rejection_sample and the direct kernel invocations rely on Triton CUDA kernels. Without a skip guard, CPU-only CI will fail. Add pytest.mark.skipif(not torch.cuda.is_available(), ...) (and/or a Triton availability check) around these tests/classes.
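For the Triton-backed tests, a module-level gate covering both CUDA and Triton availability could look like this (sketch; the helper and reason string are illustrative):

```python
import importlib.util

import pytest


def _cuda_and_triton_available():
    """True only when both the triton package and a CUDA device are usable."""
    if importlib.util.find_spec('triton') is None:
        return False
    try:
        import torch
        return torch.cuda.is_available()
    except ImportError:
        return False


# Applies the skip to every test collected from this module.
pytestmark = pytest.mark.skipif(not _cuda_and_triton_available(),
                                reason='rejection-sampling kernels need CUDA + Triton')
```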
```diff
 # add more caches for eviction
 # TODO: Share memory between state cache and pageable cache
-num_state_caches = int(cache_config.max_batches + 8)
+num_state_caches = int(cache_config.max_batches + 1)
```
Could you comment on the "+1" here?
We just allocate one more state to be used for padding.
Motivation
Support Qwen3.5 MTP.
api_server
```shell
lmdeploy serve api_server \
    Qwen/Qwen3.5-35B-A3B \
    --backend pytorch \
    --tp 2 \
    --speculative-algorithm 'qwen3_5_mtp' \
    --speculative-num-draft-tokens 3 \
    --max-batch-size 128 \
    --session-len 65536
```
pipeline
Modification
BC-breaking (Optional)
Does the modification introduce changes that break the backward-compatibility of the downstream repositories?
If so, please describe how it breaks the compatibility and how the downstream projects should modify their code to keep compatibility with this PR.
Use cases (Optional)
If this PR introduces a new feature, it is better to list some use cases here, and update the documentation.
Checklist