
[Feat]: Support qwen35 with mtp#4437

Merged
lvhan028 merged 63 commits into InternLM:main from RunningLeon:qwen35-mtp-dev on Apr 3, 2026

Conversation

RunningLeon (Collaborator) commented Mar 20, 2026

Motivation

Support Qwen3.5 MTP (multi-token prediction) speculative decoding.

api_server

lmdeploy serve api_server \
Qwen/Qwen3.5-35B-A3B \
--backend pytorch \
--tp 2 \
--speculative-algorithm 'qwen3_5_mtp' \
--speculative-num-draft-tokens 3 \
--max-batch-size 128 \
--session-len 65536
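
Once the server is up, a quick client-side check can look like the sketch below. It assumes lmdeploy's OpenAI-compatible /v1 endpoint on the default port 23333 and that the model path doubles as the served model name; adjust both to your deployment.

# Minimal client-side check for the server launched above (a sketch; the
# served model name and port may differ in your deployment).
from openai import OpenAI

client = OpenAI(base_url='http://0.0.0.0:23333/v1', api_key='none')
resp = client.chat.completions.create(
    model='Qwen/Qwen3.5-35B-A3B',  # assumed served model name
    messages=[{'role': 'user', 'content': 'Hi, pls intro yourself'}],
    max_tokens=128,
)
print(resp.choices[0].message.content)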

pipeline

from lmdeploy import pipeline, PytorchEngineConfig
from lmdeploy.messages import SpeculativeConfig

if __name__ == '__main__':
    prompts = ['Hi, pls intro yourself', 'Shanghai is']
    model_path = 'Qwen/Qwen3.5-35B-A3B'
    spec_cfg = SpeculativeConfig(method='qwen3_5_mtp',
                                 num_speculative_tokens=3,
                                 model=model_path)
    pipe = pipeline(model_path, 
                    backend_config=PytorchEngineConfig(max_batch_size=128),
                    speculative_config=spec_cfg)
    response = pipe(prompts)
    print(response)

Modification

BC-breaking (Optional)

Does the modification introduce changes that break the backward-compatibility of the downstream repositories?
If so, please describe how it breaks the compatibility and how the downstream projects should modify their code to keep compatibility with this PR.

Use cases (Optional)

If this PR introduces a new feature, it is better to list some use cases here, and update the documentation.

Checklist

  1. Pre-commit or other linting tools are used to fix the potential lint issues.
  2. The modification is covered by complete unit tests. If not, please add more unit tests to ensure the correctness.
  3. If the modification has a dependency on downstream projects of a newer version, this PR should be tested with all supported versions of downstream projects.
  4. The documentation has been modified accordingly, like docstring or example tutorials.

RunningLeon removed the WIP label on Apr 3, 2026
RunningLeon changed the title from "[WIP]: qwen35 mtp" to "[Feat]: Support qwen35 with mtp" on Apr 3, 2026
RunningLeon marked this pull request as ready for review on April 3, 2026 04:08
Copilot AI review requested due to automatic review settings April 3, 2026 04:08
Copilot AI (Contributor) left a comment


Pull request overview

Adds PyTorch speculative decoding support for Qwen3.5 via an MTP-based draft model/proposer, including routing-expert recording and long-context chunk handling.

Changes:

  • Introduces qwen3_5_mtp proposer + Qwen3_5MTPModel and wires it through config/model maps/CLI & benchmarks.
  • Refactors speculative decoding flow to run sampling + rejection sampling inside SpecModelAgent, with expanded/sliced SamplingInputs and logprobs plumbing.
  • Extends long-context chunking + MROPE/state-cache handling to support spec decoding and multimodal chunk cases; adds new unit tests for spec agent + rejection sampler.

Reviewed changes

Copilot reviewed 45 out of 45 changed files in this pull request and generated 6 comments.

| File | Description |
| --- | --- |
| tests/pytorch/spec_decode/test_spec_agent.py | New tests for spec-agent sampling/logprobs + SamplingInputs expand/slice helpers |
| tests/pytorch/spec_decode/test_reject_sample.py | New tests for rejection sampling + Triton kernels |
| lmdeploy/utils.py | Allows is_bf16_supported('auto') to take CUDA path |
| lmdeploy/pytorch/strategies/dllm/model_agent.py | Updates to functional ModelInputs.step() return value |
| lmdeploy/pytorch/strategies/ar/model_agent.py | Updates to functional ModelInputs.step() return value |
| lmdeploy/pytorch/strategies/ar_spec/sequence.py | Records routed experts alongside token updates; handles per-token expert splits |
| lmdeploy/pytorch/strategies/ar_spec/sampling.py | Adds num_spec_tokens to ARSpec sampling strategy |
| lmdeploy/pytorch/strategies/ar_spec/model_inputs.py | Spec-decoding dummy inputs tweaks; MROPE pos-id reshaping during input updates |
| lmdeploy/pytorch/strategies/ar_spec/model_agent.py | Extends ARSpec extra inputs (embeds/logprobs), cloning/merge/update logic; prefill/decoding adjustments |
| lmdeploy/pytorch/strategies/ar_spec/engine.py | Adds get_num_required_tokens() for scheduling in spec decode |
| lmdeploy/pytorch/strategies/ar_spec/__init__.py | Passes num_spec_tokens into ARSpec sampling strategy |
| lmdeploy/pytorch/spec_decode/spec_agent.py | Major refactor: sampling + rejection sampling inside spec agent; chunk carry-over; input-embed support |
| lmdeploy/pytorch/spec_decode/reject_sampler.py | Adds Triton greedy/random rejection sampling kernels; supports mixed greedy/random batches |
| lmdeploy/pytorch/spec_decode/proposers/qwen3_5_mtp.py | New proposer registering qwen3_5_mtp (shares target embeddings) |
| lmdeploy/pytorch/spec_decode/proposers/base.py | Makes decoding input update functional; adds embed_input_ids helper |
| lmdeploy/pytorch/spec_decode/proposers/__init__.py | Exports Qwen3.5 MTP proposer |
| lmdeploy/pytorch/spec_decode/base.py | Base spec agent now stores SpecDecodeConfig + num_spec_tokens |
| lmdeploy/pytorch/spec_decode/__init__.py | Passes misc_config into spec-agent builder; initializes base agent with config |
| lmdeploy/pytorch/paging/scheduler.py | Renames scheduling arg to num_required_tokens |
| lmdeploy/pytorch/nn/gated_delta.py | Adds spec-decoding state/conv offset handling + cache seqlens plumbing |
| lmdeploy/pytorch/models/utils/cudagraph.py | Adds block_size to graph meta; updates FA3 metadata building and MROPE requirement |
| lmdeploy/pytorch/models/qwen3_5.py | Adds optional input-embed return for spec/multimodal chunking; attention head gating + TP toggles |
| lmdeploy/pytorch/models/qwen3_5_mtp.py | New Qwen3.5 MTP draft model implementation + weight loader |
| lmdeploy/pytorch/models/qwen3_5_moe.py | Adds is_tp parameter plumbing; tracks spec-decoding build context |
| lmdeploy/pytorch/models/module_map.py | Registers Qwen3_5MTPModel in module map |
| lmdeploy/pytorch/models/deepseek_mtp.py | Removes position-0 embedding masking |
| lmdeploy/pytorch/model_inputs.py | Adds target_inputs_embeds, chunk flags, clone(), and makes step() functional (non-mutating) |
| lmdeploy/pytorch/kernels/cuda/pagedattention.py | Casts block offsets to int64 in kernels |
| lmdeploy/pytorch/kernels/cuda/flatten_kv_cache.py | Minor cleanup (removes stray whitespace) |
| lmdeploy/pytorch/engine/model_agent/agent.py | Integrates spec-agent into sampling path; passes misc_config; chunk-output lifecycle tweaks; shields async postprocess |
| lmdeploy/pytorch/engine/inputs_maker.py | Tracks multimodal presence for chunking; schedules with num_required_tokens; sets chunk flags |
| lmdeploy/pytorch/engine/executor/base.py | Changes default num_state_caches sizing |
| lmdeploy/pytorch/engine/executor/__init__.py | Passes spec_method/num_spec_tokens/block_size into model config building |
| lmdeploy/pytorch/engine/engine_loop.py | Adjusts logprobs aggregation to support multi-token steps (spec decode) |
| lmdeploy/pytorch/configurations/qwen3_5.py | Adds spec/draft model config handling; adjusts state shapes for spec decoding |
| lmdeploy/pytorch/config.py | Plumbs num_spec_tokens + block_size into model config construction |
| lmdeploy/pytorch/backends/gated_delta_rule.py | Extends gated-delta interface to accept spec_state_offsets |
| lmdeploy/pytorch/backends/cuda/op_backend.py | Uses model_config.block_size for flash-attn metadata |
| lmdeploy/pytorch/backends/cuda/graph_runner.py | Stores block_size in CUDA graph meta |
| lmdeploy/pytorch/backends/cuda/gated_delta_rule.py | Adds Triton select/scatter for spec-state offsets; plumbs cache seqlens to recurrent path |
| lmdeploy/pytorch/backends/cuda/causal_conv1d.py | Extends conv update to accept cache_seqlens |
| lmdeploy/pytorch/backends/causal_conv1d.py | Extends conv update interface to accept cache_seqlens |
| lmdeploy/cli/utils.py | Adds qwen3_5_mtp to --speculative-algorithm choices |
| benchmark/profile_throughput.py | Adds speculative decode CLI parsing and passes config to engine |
| benchmark/profile_pipeline_api.py | Adds speculative decode CLI parsing and passes config into pipeline |


Comment on lines +179 to +196
target_probs = target_logits.softmax(dim=-1, dtype=torch.float32)

# 3. Uniform random [batch, num_spec] (float64 to avoid exact 0.0)
uniform_probs = torch.rand(
    (batch_size, num_spec_tokens),
    dtype=torch.float64,
    device=device,
)

# 4. Recovered tokens via Gumbel-max trick
q = torch.empty(
    (batch_size, vocab_size),
    dtype=torch.float32,
    device=device,
)
q.exponential_()
inv_q = q.reciprocal()
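
For context, the excerpt stops just before the acceptance step. Below is a minimal sketch of how standard speculative rejection sampling typically continues from these tensors; draft_probs, draft_token_ids, and the helper itself are assumptions for illustration, not this PR's implementation.

import torch

# Sketch only -- standard rejection-sampling acceptance plus Gumbel-max
# recovery, not the PR's kernel. Shapes assumed:
#   target_probs, draft_probs: [batch, num_spec, vocab]
#   draft_token_ids:           [batch, num_spec]
#   uniform_probs:             [batch, num_spec]
#   inv_q:                     [batch, vocab]  (reciprocal of Exp(1) noise)
def sketch_reject_sample(target_probs, draft_probs, draft_token_ids,
                         uniform_probs, inv_q):
    # Probability of each drafted token under the target and draft models.
    p_tgt = target_probs.gather(-1, draft_token_ids.unsqueeze(-1)).squeeze(-1)
    p_drf = draft_probs.gather(-1, draft_token_ids.unsqueeze(-1)).squeeze(-1)
    # Accept a drafted token when u < min(1, p_target / p_draft).
    accept = uniform_probs < torch.clamp(p_tgt / p_drf.clamp_min(1e-10), max=1.0)
    # On rejection, resample from the residual max(0, p_target - p_draft);
    # argmax(residual * inv_q) draws from the normalized residual
    # (Gumbel-max / exponential-race trick).
    residual = (target_probs - draft_probs).clamp_min(0.0)
    recovered = (residual * inv_q.unsqueeze(1)).argmax(dim=-1)
    # Downstream logic typically keeps draft tokens up to the first rejection
    # and appends the recovered token at that position.
    return accept, recovered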


Copilot AI Apr 3, 2026


The random-rejection path uses torch.rand(...) and q.exponential_() without using sampling_inputs.random_seeds/random_offsets, so results will depend on the global RNG state and won’t be reproducible per-request (unlike the normal sampling path which uses seeded multinomial_sampling). Please wire in sampling_inputs RNG (seeds/offsets) for both uniform_probs and the Gumbel/exponential noise so speculative decoding stays deterministic under the same sampling inputs.
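
For illustration of that suggestion (a possible follow-up, not code in this PR): per-request seeds could drive both noise tensors through torch.Generator; how seed and offset are combined below is an assumption.

import torch

# Hypothetical sketch -- not this PR's code. Per-request generators make the
# uniform and exponential noise reproducible for a given SamplingInputs.
def seeded_noise(random_seeds, random_offsets, num_spec_tokens, vocab_size, device):
    uniforms, exps = [], []
    for seed, offset in zip(random_seeds, random_offsets):
        gen = torch.Generator(device=device)
        gen.manual_seed(int(seed) + int(offset))  # seed/offset mixing is assumed
        uniforms.append(torch.rand(num_spec_tokens, dtype=torch.float64,
                                   device=device, generator=gen))
        exps.append(torch.empty(vocab_size, dtype=torch.float32,
                                device=device).exponential_(generator=gen))
    # uniform noise: [batch, num_spec]; Exp(1) noise: [batch, vocab]
    return torch.stack(uniforms), torch.stack(exps)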

Collaborator Author


May improve in another PR.

Comment on lines +8 to +16
device = 'cuda' if torch.cuda.is_available() else 'cpu'


def _run_async(coro):
    """Helper to run async function in sync test."""
    loop = asyncio.new_event_loop()
    try:
        return loop.run_until_complete(coro)
    finally:

Copilot AI Apr 3, 2026


This test file falls back to device='cpu' when CUDA is unavailable, but the exercised code path (FusedLogitsProcessor via async_sampling_logits) unconditionally uses CUDA stream APIs (torch.cuda.current_stream()), so it will error on non-CUDA runners. Add a pytest.mark.skipif(not torch.cuda.is_available(), ...) (or otherwise guard) so CPU-only CI doesn’t fail.
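
A minimal sketch of the suggested guard (hypothetical, not from this PR):

import pytest
import torch

# Skip the whole module on CPU-only runners, since the exercised path calls
# torch.cuda.current_stream().
pytestmark = pytest.mark.skipif(not torch.cuda.is_available(),
                                reason='FusedLogitsProcessor path requires CUDA')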

Comment on lines +13 to +20
device = 'cuda' if torch.cuda.is_available() else 'cpu'


def _make_peaked_logits(token_ids_2d, vocab_size):
    """Build logits where argmax(dim=-1) == token_ids_2d.

    token_ids_2d: list[list[int]] or Tensor [batch, num_spec]
    """

Copilot AI Apr 3, 2026


These tests fall back to CPU when CUDA isn’t available, but rejection_sample and the direct kernel invocations rely on Triton CUDA kernels. Without a skip guard, CPU-only CI will fail. Add pytest.mark.skipif(not torch.cuda.is_available(), ...) (and/or a Triton availability check) around these tests/classes.
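
One way to guard these tests as well (again hypothetical, not from this PR), skipping unless both CUDA and Triton are present:

import importlib.util

import pytest
import torch

# Module-level guard: the rejection-sampling kernels need CUDA plus Triton.
pytestmark = pytest.mark.skipif(
    not torch.cuda.is_available() or importlib.util.find_spec('triton') is None,
    reason='rejection-sampling kernels require CUDA and Triton')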

 # add more caches for eviction
 # TODO: Share memory between state cache and pageable cache
-num_state_caches = int(cache_config.max_batches + 8)
+num_state_caches = int(cache_config.max_batches + 1)
Collaborator


Could you comment on the "+1" here?

Collaborator Author


Just allocates one more state to be used for padding.

lvhan028 merged commit 12c877c into InternLM:main on Apr 3, 2026
4 of 5 checks passed