[Feature] Add TurboMind support for Qwen3.5 models (dense + MoE) #4389
lvhan028 merged 16 commits into InternLM:main from
Conversation
Pull request overview
Adds TurboMind backend support for Qwen3.5 dense + MoE models by introducing Gated DeltaNet linear attention, head_dim=256 kernel support, and converter/export updates for mixed attention + mixed quantization.
Changes:
- Implemented Gated DeltaNet linear-attention layer + CUDA kernels and integrated it into the UnifiedDecoder execution path.
- Extended attention/decoding and RMSNorm kernels to support head_dim=256 and added optional attention output gating.
- Updated Python converter/reader/export modules to register Qwen3.5 architectures, export linear-attn weights, and handle mixed AWQ quantization and per-layer unquantized overrides.
Reviewed changes
Copilot reviewed 53 out of 53 changed files in this pull request and generated 2 comments.
Show a summary per file
| File | Description |
|---|---|
| src/turbomind/turbomind.cc | Parses new Qwen3.5 layer_types and linear-attention config fields into ModelParam |
| src/turbomind/models/llama/unified_decoder.h | Adds GatedDeltaNetLayer member to decoder |
| src/turbomind/models/llama/unified_decoder.cc | Dispatches per-layer attention vs linear-attention and forwards correct bias to residual path |
| src/turbomind/models/llama/unified_attention_layer.cc | Adds optional output gating and adjusts QKV stride accounting |
| src/turbomind/models/llama/moe_ffn_layer.cc | Routes sigmoid scoring func to compatible MoE gate kernel path |
| src/turbomind/models/llama/llama_params.h | Adds ModelParam fields for layer_types, linear-attn params, and unquantized_expert_layers |
| src/turbomind/models/llama/llama_kernels.h | Declares sigmoid gate multiply kernel entry point |
| src/turbomind/models/llama/llama_kernels.cu | Implements sigmoid gate multiply CUDA kernel |
| src/turbomind/models/llama/gated_delta_net_kernels.h | Declares CUDA kernels for Gated DeltaNet conv/recurrence/norm utilities |
| src/turbomind/models/llama/gated_delta_net_kernels.cu | Implements CUDA kernels for Gated DeltaNet operations |
| src/turbomind/models/llama/LlamaDenseWeight.h | Extends attention weight ctor to accept attn_output_gate flag |
| src/turbomind/models/llama/LlamaDenseWeight.cc | Adjusts QKV output dim when attn_output_gate is enabled |
| src/turbomind/models/llama/LlamaDecoderLayerWeight.h | Adds GatedDeltaNetWeight pointer to layer weights |
| src/turbomind/models/llama/LlamaDecoderLayerWeight.cc | Instantiates linear-attn weights conditionally and applies mixed-quantization type routing |
| src/turbomind/models/llama/GatedDeltaNetWeight.h | Introduces weight container for Gated DeltaNet tensors |
| src/turbomind/models/llama/GatedDeltaNetWeight.cc | Implements parameter registration and prepare() for linear-attn weights |
| src/turbomind/models/llama/GatedDeltaNetLayer.h | Declares GatedDeltaNet execution layer and request state handling |
| src/turbomind/models/llama/GatedDeltaNetLayer.cc | Implements linear-attn forward pass with per-request persistent state |
| src/turbomind/models/llama/CMakeLists.txt | Adds new GatedDeltaNet sources to Llama static lib |
| src/turbomind/models/CMakeLists.txt | Adds new GatedDeltaNet sources to models static lib |
| src/turbomind/kernels/norm/rms_norm.cu | Extends QK RMSNorm launcher to max_dim=256 |
| src/turbomind/kernels/core/thread_map.h | Clamps WarpThreadC default to WARP_SIZE for larger DimC |
| src/turbomind/kernels/attention/reduce.cu | Adds ReduceV3 instantiations for head_dim=256 |
| src/turbomind/kernels/attention/kv_cache_utils_v2.cu | Adds KV cache processing/flattening dispatch for head_dim=256 |
| src/turbomind/kernels/attention/decoding.cu | Adds decoding dispatch for size_per_head=256 |
| src/turbomind/kernels/attention/codegen/decoding_sm80_256_f16_u8.cu | Adds SM80 decoding codegen instantiations (f16/u8, head_dim=256) |
| src/turbomind/kernels/attention/codegen/decoding_sm80_256_f16_u4.cu | Adds SM80 decoding codegen instantiations (f16/u4, head_dim=256) |
| src/turbomind/kernels/attention/codegen/decoding_sm80_256_f16_f16.cu | Adds SM80 decoding codegen instantiations (f16/f16, head_dim=256) |
| src/turbomind/kernels/attention/codegen/decoding_sm80_256_bf16_u8.cu | Adds SM80 decoding codegen instantiations (bf16/u8, head_dim=256) |
| src/turbomind/kernels/attention/codegen/decoding_sm80_256_bf16_u4.cu | Adds SM80 decoding codegen instantiations (bf16/u4, head_dim=256) |
| src/turbomind/kernels/attention/codegen/decoding_sm80_256_bf16_bf16.cu | Adds SM80 decoding codegen instantiations (bf16/bf16, head_dim=256) |
| src/turbomind/kernels/attention/codegen/decoding_sm75_256_f16_u8.cu | Adds SM75 decoding codegen instantiations (f16/u8, head_dim=256) |
| src/turbomind/kernels/attention/codegen/decoding_sm75_256_f16_u4.cu | Adds SM75 decoding codegen instantiations (f16/u4, head_dim=256) |
| src/turbomind/kernels/attention/codegen/decoding_sm75_256_f16_f16.cu | Adds SM75 decoding codegen instantiations (f16/f16, head_dim=256) |
| src/turbomind/kernels/attention/codegen/decoding_sm70_256_f16_u8.cu | Adds SM70 decoding codegen instantiations (f16/u8, head_dim=256) |
| src/turbomind/kernels/attention/codegen/decoding_sm70_256_f16_u4.cu | Adds SM70 decoding codegen instantiations (f16/u4, head_dim=256) |
| src/turbomind/kernels/attention/codegen/decoding_sm70_256_f16_f16.cu | Adds SM70 decoding codegen instantiations (f16/f16, head_dim=256) |
| src/turbomind/kernels/attention/codegen/attention_sm80_256_f16.cu | Adds SM80 attention codegen instantiations (f16, head_dim=256) |
| src/turbomind/kernels/attention/codegen/attention_sm80_256_bf16.cu | Adds SM80 attention codegen instantiations (bf16, head_dim=256) |
| src/turbomind/kernels/attention/codegen/attention_sm75_256_f16.cu | Adds SM75 attention codegen instantiations (f16, head_dim=256) |
| src/turbomind/kernels/attention/codegen/attention_sm70_256_f16.cu | Adds SM70 attention codegen instantiations (f16, head_dim=256) |
| src/turbomind/kernels/attention/attention_config.h | Adds head_dim=256 attention config specializations for SM70/SM75 |
| src/turbomind/kernels/attention/attention.cu | Adds attention dispatch for size_per_head=256 |
| src/turbomind/kernels/attention/CMakeLists.txt | Adds new 256 codegen compilation units |
| src/turbomind/engine/request.h | Adds persistent per-request Gated DeltaNet state tensors |
| src/turbomind/core/module.h | Adds include guards and minor formatting adjustments |
| lmdeploy/turbomind/supported_models.py | Registers Qwen3.5 architectures for TurboMind |
| lmdeploy/turbomind/deploy/source_model/qwen.py | Adds Qwen3.5 reader/model info including linear-attn export and mixed-AWQ handling |
| lmdeploy/turbomind/deploy/source_model/llama.py | Makes intermediate_size optional in config parsing |
| lmdeploy/turbomind/deploy/module.py | Adds partial RoPE permute, Q+gate splitting, QKVG merge, and LinearAttn export module |
| lmdeploy/turbomind/deploy/converter.py | Adds per-layer unquantized overrides and avoids overwriting TP sizes when engine config fields are None |
| lmdeploy/turbomind/deploy/config.py | Adds Qwen3.5 config fields (layer_types, linear-attn params, gating, unquantized_expert_layers) |
| lmdeploy/archs.py | Treats Qwen3.5 architectures as non-VL LLMs in VL detection |
Comments suppressed due to low confidence (7)
src/turbomind/models/llama/gated_delta_net_kernels.cu:1
- The causal_conv1d_* kernels are launched with a 2D grid (grid.x=blocks, grid.y=batch_size) but the kernel indexing ignores blockIdx.y, causing each y-slice to redundantly process the full batch and race on out/conv_states. Fix by incorporating blockIdx.y into the global index (or use blockIdx.y as the batch index and only iterate channels in x), and set total to batch_size * conv_dim (and * seq_len for prefill) without relying on gridDim.y.
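To make the suggested fix concrete, here is a host-side Python simulation (block/thread counts are hypothetical) contrasting the current indexing, which ignores `blockIdx.y`, with a scheme that folds the batch index into the flat index:

```python
def covered_indices(blocks_x, batch_size, threads, conv_dim):
    """Simulate which output slots each (block_y, block_x, thread) triple touches."""
    buggy, fixed = [], []
    for by in range(batch_size):        # gridDim.y = batch_size
        for bx in range(blocks_x):      # gridDim.x = blocks
            for t in range(threads):
                i = bx * threads + t    # current kernels: blockIdx.y is ignored
                if i < conv_dim:
                    buggy.append(i)               # every y-slice hits the same slots -> races
                    fixed.append(by * conv_dim + i)  # fold batch index in -> unique slot
    return buggy, fixed
```

With the fix, the covered indices partition `batch_size * conv_dim` exactly once; the buggy mapping writes each slot `batch_size` times.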
src/turbomind/models/llama/gated_delta_net_kernels.cu:1
- `batch_size` is currently unused in invokeFusedConv1dSiLU, which can trigger unused-parameter warnings (and potentially break builds that treat warnings as errors). Either remove it from the API, explicitly mark it unused (e.g., `(void)batch_size`), or implement the intended state-offsetting logic using `batch_size`.
src/turbomind/models/llama/GatedDeltaNetLayer.cc:1
- GatedDeltaNetLayer::Forward performs multiple GPU tensor allocations inside the per-request loop (conv_out, q/k/v_contig, optional expanded buffers). This will add significant allocator overhead at runtime (especially for decode, where this runs every step). Consider switching these to a reusable workspace/scratch allocation (e.g., per-phase buffers sized to max tokens/seq_len) or using an existing workspace allocator to avoid repeated cudaMalloc/free patterns.
src/turbomind/models/llama/GatedDeltaNetWeight.cc:1
- These dimensions use integer division by tp_size without validating divisibility. If num_k_heads/num_v_heads aren't divisible by tp_size, this will silently truncate and mis-shape weights/state. Add explicit checks (e.g., TM_CHECK_EQ(num_k_heads % tp_size, 0) and TM_CHECK_EQ(num_v_heads % tp_size, 0)) to fail fast with a clear error.
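The per-request allocation issue flagged above is commonly solved with a grow-only scratch buffer: allocate once at the high-water mark and hand out views. A minimal host-side sketch of the pattern (NumPy standing in for device memory; names are hypothetical, not the engine's actual allocator API):

```python
import numpy as np

class Workspace:
    """Grow-only scratch buffer: one backing allocation, reused every step."""
    def __init__(self):
        self._buf = np.empty(0, dtype=np.uint8)
        self.reallocs = 0

    def take(self, shape, dtype=np.float32):
        nbytes = int(np.prod(shape)) * np.dtype(dtype).itemsize
        if self._buf.nbytes < nbytes:
            self._buf = np.empty(nbytes, dtype=np.uint8)  # rare: only when growing
            self.reallocs += 1
        # return a view into the persistent buffer; no new allocation
        return self._buf[:nbytes].view(dtype).reshape(shape)

ws = Workspace()
a = ws.take((4, 256))   # first request: allocates
b = ws.take((2, 256))   # smaller request: reuses the same backing store
```

After warm-up, steady-state decode steps would hit the cached buffer and never touch the allocator.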
- Fix tensor parallelism in module.py with proper assertions and validation:
  - Ensure rotary_dim is even for proper reshaping
  - Add bounds checking for rotary_dim vs size_per_head
  - Add divisibility check for output dimensions
- Fix attribute name typo in qwen.py: correct 'attn_layer_patten' to 'attn_layer_pattern' in Qwen3_5ReaderMixin
- Improve MSVC compiler compatibility in rms_norm.cu: use std::decay_t for proper template type deduction across compilers
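A sketch of what a partial rotary permutation with these validations might look like (NumPy; the exact layout handled by the PR's `permute_v2_partial` is an assumption here: only the first `rotary_dim` channels of each head are reordered, the rest pass through):

```python
import numpy as np

def permute_partial(w, size_per_head, rotary_dim):
    """Reorder only the rotary slice of each head from interleaved
    (x0, y0, x1, y1, ...) to half-split (x0, x1, ..., y0, y1, ...) layout;
    non-rotary channels are left untouched."""
    assert rotary_dim % 2 == 0, "rotary_dim must be even for pair reshaping"
    assert rotary_dim <= size_per_head, "rotary_dim exceeds head size"
    in_dim, out_dim = w.shape
    assert out_dim % size_per_head == 0, "output dim must divide into heads"
    n_heads = out_dim // size_per_head
    w = w.reshape(in_dim, n_heads, size_per_head).copy()
    rot = w[..., :rotary_dim].reshape(in_dim, n_heads, rotary_dim // 2, 2)
    w[..., :rotary_dim] = rot.transpose(0, 1, 3, 2).reshape(in_dim, n_heads, rotary_dim)
    return w.reshape(in_dim, out_dim)
```

With `size_per_head=8` and `rotary_dim=4`, channels `(0, 1, 2, 3)` become `(0, 2, 1, 3)` while channels 4-7 are untouched.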
Tested the latest commit with both TP=1 and TP>=1.

check_env details:

```
sys.platform: linux
Python: 3.10.12 (main, Jan 26 2026, 14:55:28) [GCC 11.4.0]
CUDA available: True
MUSA available: False
numpy_random_seed: 2147483648
GPU 0,1,2,3: Tesla V100-SXM2-32GB
CUDA_HOME: /usr/local/cuda
NVCC: Cuda compilation tools, release 12.8, V12.8.93
GCC: x86_64-linux-gnu-gcc (Ubuntu 11.4.0-1ubuntu1~22.04.3) 11.4.0
PyTorch: 2.10.0+cu128
PyTorch compiling details: PyTorch built with:
  - GCC 13.3
  - C++ Version: 201703
  - Intel(R) oneAPI Math Kernel Library Version 2024.2-Product Build 20240605 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v3.7.1 (Git Hash 8d263e693366ef8db40acc569cc7d8edf644556d)
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - LAPACK is enabled (usually provided by MKL)
  - NNPACK is enabled
  - CPU capability usage: AVX512
  - CUDA Runtime 12.8
  - NVCC architecture flags: -gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86;-gencode;arch=compute_90,code=sm_90;-gencode;arch=compute_100,code=sm_100;-gencode;arch=compute_120,code=sm_120
  - CuDNN 91.0.2 (built against CUDA 12.9)
  - Magma 2.6.1
  - Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, COMMIT_SHA=449b1768410104d3ed79d3bcfe4ba1d65c7f22c0, CUDA_VERSION=12.8, CUDNN_VERSION=9.10.2, CXX_COMPILER=/opt/rh/gcc-toolset-13/root/usr/bin/c++, CXX_FLAGS= -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOROCTRACER -DLIBKINETO_NOXPUPTI=ON -DUSE_FBGEMM -DUSE_FBGEMM_GENAI -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -DC10_NODEPRECATED -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Werror=range-loop-construct -Werror=bool-operation -Wnarrowing -Wno-missing-field-initializers -Wno-unknown-pragmas -Wno-unused-parameter -Wno-strict-overflow -Wno-strict-aliasing -Wno-stringop-overflow -Wsuggest-override -Wno-psabi -Wno-error=old-style-cast -faligned-new -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-dangling-reference -Wno-error=dangling-reference -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, TORCH_VERSION=2.10.0, USE_CUDA=ON, USE_CUDNN=ON, USE_CUSPARSELT=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_GLOO=ON, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=1, USE_NNPACK=ON, USE_OPENMP=ON, USE_ROCM=OFF, USE_ROCM_KERNEL_ASSERT=OFF, USE_XCCL=OFF, USE_XPU=OFF
TorchVision: 0.25.0+cu128
```
I'll take a better look in a few hours after work. Can you share the command you used to test? I am testing with, e.g.:
All default, just add
I haven’t tested the full Qwen models, only QuantTrio AWQ quants (refer to the end of the PR description for the model names). Please use QuantTrio for now if you want to test this PR.
Confirmed working with QuantTrio/Qwen3.5-122B-A10B-AWQ and 4x V100s.
What’s the performance for you? Both prompt processing and generation?
From log:
Hi, @lapy
Could you just update from main to fix the clang-format issue? |
Add TurboMind backend support for Qwen3.5 dense and MoE architectures, featuring Gated DeltaNet linear attention with mixed full/linear layers.

Key changes:
- C++ Gated DeltaNet layer implementation with fused CUDA kernels for short convolution, decay gate, and delta rule recurrence
- Head dimension 256 support for attention and decoding kernels
- Mixed attention layer types (linear/full) via per-layer configuration
- Linear attention weight export (in_proj_a, in_proj_b, conv1d, etc.)
- AWQ mixed quantization support: dequantize attention O-proj to fp16 when QKV are already fp16 (modules_to_not_convert handling)
- Separate weight_type / ffn_weight_type / expert_weight_type tracking in converter and C++ weight loading for mixed-quantization models
- Model registration for Qwen3_5ForConditionalGeneration and Qwen3_5MoeForConditionalGeneration architectures

Tested with Qwen3.5-35B-A3B-AWQ (MoE) and Qwen3.5-27B-AWQ (dense).
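The mixed-quantization path dequantizes AWQ int4 weights back to floating point at conversion time. A minimal NumPy sketch of generic AWQ dequantization (the PR itself reuses `dequantize_gemm` from lmdeploy; the within-int32 nibble order below is the standard AWQ packing, stated here as an assumption):

```python
import numpy as np

AWQ_ORDER = [0, 2, 4, 6, 1, 3, 5, 7]              # AWQ nibble order inside each int32
SHIFTS = np.array([4 * o for o in AWQ_ORDER], dtype=np.uint32)

def awq_dequantize(qweight, qzeros, scales, group_size):
    """qweight: (in, out/8) int32, qzeros: (in/gs, out/8) int32, scales: (in/gs, out)."""
    def unpack(x):  # int32 -> eight unsigned 4-bit values, in AWQ order
        u = (x[..., None].astype(np.uint32) >> SHIFTS) & 0xF
        return u.reshape(*x.shape[:-1], -1)
    w = unpack(qweight).astype(np.float32)                          # (in, out)
    z = np.repeat(unpack(qzeros).astype(np.float32), group_size, axis=0)
    s = np.repeat(scales.astype(np.float32), group_size, axis=0)
    return (w - z) * s                                              # fp weights, (in, out)
```

With every quantized nibble equal to 3, every zero-point nibble equal to 1, and scales of 2, each dequantized weight is (3 - 1) * 2 = 4.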
Address Windreamer's review: swap dispatch nesting order in invokeQkRMSNorm and invokeRMSNormQK to resolve dtype first (outer) and then launch for different head dims (inner).
This commit introduces several algorithmic and micro-architectural optimizations for the Gated DeltaNet implementation, significantly reducing kernel launch overhead and improving memory bandwidth utilization, particularly on SM70 (Volta) hardware.

Key optimizations:

1. Algorithmic changes:
   - Single-launch prefill: replaced the serial host-side loop with a parallel chunked-scan prefill kernel (invokeGatedDeltaRulePrefill).
   - In-kernel GQA & L2Norm: grouped-query attention handling and block-level L2 normalization are now fused directly into the delta rule kernels, eliminating redundant allocations and round-trips.
   - Fused input projections: fused 4 independent GEMMs into a single 'in_proj_all' projection matrix, slicing the output columns directly to avoid extra memory reads.
2. SM70-specific enhancements:
   - Vectorization: added half2/nv_bfloat162 vectorized memory access to the delta rule, compute_beta_g, and silu kernels, doubling data throughput in memory-bound operations.
   - Warp-synchronous reductions: optimized block_l2_inv_norm to bypass shared memory and rely purely on __shfl_xor_sync when block size <= 32.
   - ILP loop unrolling: unrolled d_conv loops in causal Conv1D kernels to improve instruction-level parallelism.

Bug fixes:
- Added dynamic striding (in_stride, gate_stride) to ensure kernels correctly access non-contiguous columns output by the fused GEMM.
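The fused input-projection idea is easy to verify numerically: concatenating the four weight matrices once at load time turns four GEMMs into one, and slicing the output columns recovers identical results. A NumPy sketch with hypothetical shapes:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((5, 64))                  # (tokens, hidden)
Wq, Wk, Wv, Wg = (rng.standard_normal((64, d)) for d in (32, 16, 16, 32))

# load time: concatenate the columns once into a single in_proj_all matrix
W_all = np.concatenate([Wq, Wk, Wv, Wg], axis=1)

# runtime: one GEMM, then slice the output columns (views, no copies)
y = x @ W_all
q, k, v, g = y[:, :32], y[:, 32:48], y[:, 48:64], y[:, 64:]
```

The sliced views are non-contiguous in memory, which is why the commit also adds dynamic striding (in_stride, gate_stride) to the downstream kernels.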
Done! :)
Hi @lapy, just a quick follow-up regarding the merge process. On a separate note, we noticed that there's no email address listed on your GitHub profile, so I wanted to ask here: could you share your email with us? We'd love to learn a bit more about your experience, if you don't mind sharing.
Thank you! Just sent you an email :) |
Getting this error with V100.


[Feature] Add TurboMind support for Qwen3.5 models (dense + MoE)
Add TurboMind backend support for Qwen3.5 dense and MoE architectures, featuring Gated DeltaNet linear attention with mixed full/linear layers.
Key changes:
Tested with Qwen3.5-35B-A3B-AWQ (MoE) and Qwen3.5-27B-AWQ (dense).
Motivation
Qwen3.5 introduces a new hybrid architecture that alternates between Gated DeltaNet linear attention and standard full attention layers. This is a fundamentally different attention mechanism from what TurboMind currently supports — it replaces softmax attention with a linear recurrence (delta rule) combined with a short convolution and gating, enabling sub-quadratic sequence processing while retaining strong performance.
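For readers unfamiliar with the delta rule, the sequential form of the recurrence can be sketched in a few lines. This is a NumPy reference for a single head, using a common formulation of Gated DeltaNet (a readable approximation of what the fused kernels compute, not the engine's exact numerics):

```python
import numpy as np

def gated_delta_rule(q, k, v, g, beta):
    """q, k: (T, dk); v: (T, dv); g: (T,) log-decay; beta: (T,) write strength.
    Maintains a fixed-size (dk, dv) state matrix S instead of a growing KV cache."""
    T, dk = k.shape
    dv = v.shape[1]
    S = np.zeros((dk, dv))
    out = np.zeros((T, dv))
    for t in range(T):
        S = np.exp(g[t]) * S                            # gated exponential decay
        pred = S.T @ k[t]                               # memory's current answer for k_t
        S = S + beta[t] * np.outer(k[t], v[t] - pred)   # delta-rule correction
        out[t] = S.T @ q[t]                             # linear-attention readout
    return out
```

With g = 0 and beta = 1 this reduces to the classic delta rule: after writing (k, v) with a unit-norm key, reading back with q = k returns v exactly, and per-step cost is independent of sequence length.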
Qwen3.5 ships in two variants: dense and MoE.
Both variants use head dimension 256 (vs the typical 128) and feature AWQ quantization with mixed precision: QKV projections are kept in fp16 (`modules_to_not_convert`) while the O projection and FFN/MoE layers are quantized to int4. This PR adds full TurboMind backend support for both variants.

Modification
C++ / CUDA (Engine)
- Linear attention layer (`GatedDeltaNetLayer.cc/h`): new layer type implementing the linear attention forward pass — short 1D convolution → SiLU gate → delta rule recurrence with exponential decay, producing output via a gated projection. Manages its own CUDA workspace allocations for recurrence state and intermediates.
- CUDA kernels (`gated_delta_net_kernels.cu/h`): GPU kernels for the fused causal convolution (`conv1d + silu`), the decay gate (numerically stable `log(1 - sigmoid(x))` path), and the delta rule recurrence/norm utilities.
- Head dimension 256: `sm70`/`sm75`/`sm80` attention and decoding kernels at `HeadDim=256`, for all KV quantization variants (f16, u4, u8, bf16).
- Mixed layer dispatch (`unified_decoder.cc/h`, `llama_params.h`): per-layer `layer_type` array (0 = full attention, 1 = linear attention) read from the model config, dispatching to either `UnifiedAttentionLayer` or `GatedDeltaNetLayer` at each decoder step.
- Weight loading (`LlamaDecoderLayerWeight.cc/h`, `LlamaDenseWeight.cc/h`): linear attention weights (`in_proj_a`, `in_proj_b`, `conv1d`, `out_proj`, `gate`) loaded alongside traditional attention weights. Added `ffn_weight_type`/`expert_weight_type` separation so that mixed-quantization models correctly use int4 for FFN layers even when the attention `weight_type` is fp16.

Python (Converter / Deployment)
- Model registration (`supported_models.py`, `archs.py`): registered `Qwen3_5ForConditionalGeneration` and `Qwen3_5MoeForConditionalGeneration` architectures.
- Config (`config.py`): added `layer_types`, `d_conv` (convolution width), `expand_k`/`expand_v` (KV expansion ratios), and `attn_output_gate` fields to `ModelConfig`.
- Reader (`qwen.py`): `Qwen3_5ReaderMixin` with:
  - an `_attn()` override that detects mixed-quantization attention (QKV fp16 + O int4 AWQ) and dequantizes O-proj weights to fp16 at conversion time via `_awq_dequant()`;
  - a `linear_attn()` method exporting GDN-specific weights (`in_proj_a`, `in_proj_b`, `conv1d`, `out_proj`, `gate`), also with on-the-fly AWQ dequantization for linear attention modules that appear in `modules_to_not_convert`;
  - `model_info()` providing layer type arrays, convolution config, and correct `inter_size` handling (dense vs MoE).
- Export modules (`module.py`): new `LinearAttn` export class for linear attention weight packing; `permute_v2_partial()` for partial rotary embedding with head_dim 256; `merge_qkvg_v2()` for Q/K/V/Gate merging; `Attn._split_q_gate()` for splitting interleaved Q+Gate projections.
- Converter (`converter.py`): `mixed_awq` detection from `quantize_config.modules_to_not_convert`; separate `weight_type` (attention, fp16 for mixed), `ffn_weight_type` (FFN, int4), and `expert_weight_type` (MoE experts, int4); per-layer `unquantized_expert_layers` for layers excluded from quantization (e.g., `model.layers.0`).

BC-breaking (Optional)
No breaking changes. All modifications are additive — new model architectures, new layer types, and new config fields with defaults that preserve existing behavior. Existing TurboMind models are unaffected.
Use cases (Optional)
Checklist
- Uses `dequantize_gemm` from `lmdeploy.pytorch.backends.default.awq_modules`.
- Documentation under `docs/` in a follow-up.

Test Results
- `QuantTrio/Qwen3.5-35B-A3B-AWQ`
- `QuantTrio/Qwen3.5-27B-AWQ`