Skip to content

Commit 34460c5

Browse files
committed
Fix code review suggestions from PR InternLM#4389
- Fix tensor parallelism in module.py with proper assertions and validation - Add assertions for rotation dimension validation - Ensure rotary_dim is even for proper reshaping - Add bounds checking for rotary_dim vs size_per_head - Add divisibility check for output dimensions - Fix attribute name typo in qwen.py - Correct 'attn_layer_patten' to 'attn_layer_pattern' in Qwen3_5ReaderMixin - Improve MSVC compiler compatibility in rms_norm.cu - Use std::decay_t for proper template type deduction across compilers
1 parent 847e04c commit 34460c5

3 files changed

Lines changed: 52 additions & 3 deletions

File tree

lmdeploy/turbomind/deploy/module.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,10 @@ def permute_v2_partial(x: torch.Tensor, size_per_head: int, rotary_dim: int):
2929
layout.
3030
"""
3131
assert x.size(-1) > 1
32+
assert rotary_dim % 2 == 0, f'rotary_dim must be even, got {rotary_dim}'
33+
assert rotary_dim <= size_per_head, f'rotary_dim ({rotary_dim}) must be <= size_per_head ({size_per_head})'
3234
output_dims = x.size(-1)
35+
assert output_dims % size_per_head == 0, f'output_dims ({output_dims}) must be divisible by size_per_head ({size_per_head})'
3336
head_num = output_dims // size_per_head
3437
orig_shape = x.shape
3538
if x.dim() == 1:
@@ -483,6 +486,37 @@ class LinearAttn(Module):
483486
def __init__(self, model: BaseOutputModel):
484487
self.model = model
485488
self.tp = model.attn_tp_size
489+
cfg = model.model_config
490+
self.key_dim = cfg.linear_num_key_heads * cfg.linear_key_head_dim
491+
self.value_dim = cfg.linear_num_value_heads * cfg.linear_value_head_dim
492+
493+
def _tp_interleave_qkv(self, tensor, dim):
494+
"""Split a concatenated [Q, K, V] tensor into components, reshape each
495+
for TP interleaving, and re-concatenate.
496+
497+
in_proj_qkv layout along ``dim``: Q(key_dim) | K(key_dim) | V(value_dim).
498+
A naive split doesn't respect component boundaries when key_dim and
499+
value_dim differ. This method splits Q/K/V, reshapes each to
500+
``(tp, -1)`` along ``dim``, concatenates per-TP-shard, then flattens
501+
so that a subsequent ``save_split(split_dim=dim)`` gives each rank the
502+
correct portion.
503+
"""
504+
if dim < 0:
505+
dim = tensor.dim() + dim
506+
q, k, v = torch.split(tensor, [self.key_dim, self.key_dim, self.value_dim], dim=dim)
507+
508+
def reshape(x):
509+
# Move TP axis to a new dimension right after ``dim``
510+
shape = list(x.shape)
511+
d = shape[dim]
512+
new_shape = shape[:dim] + [self.tp, d // self.tp] + shape[dim + 1:]
513+
return x.view(new_shape)
514+
515+
parts = torch.cat([reshape(q), reshape(k), reshape(v)], dim=dim + 1)
516+
# Collapse tp and per-shard dims back
517+
shape = list(parts.shape)
518+
final_shape = shape[:dim] + [shape[dim] * shape[dim + 1]] + shape[dim + 2:]
519+
return parts.reshape(final_shape)
486520

487521
def apply(self, i: int, r: BaseReader):
488522
layer_types = getattr(self.model.model_config, 'layer_types', [])
@@ -499,6 +533,10 @@ def apply(self, i: int, r: BaseReader):
499533
if tensor is None:
500534
continue
501535
if name == 'conv1d':
536+
# conv1d shape: (conv_dim, 1, d_conv) where
537+
# conv_dim = key_dim*2 + value_dim. Interleave Q/K/V
538+
# portions along dim 0 before splitting for TP.
539+
tensor = self._tp_interleave_qkv(tensor, dim=0)
502540
self.model.save_split(tensor,
503541
self._linear_attn.format(i, name, kind),
504542
split_dim=0,
@@ -515,6 +553,17 @@ def apply(self, i: int, r: BaseReader):
515553
self._linear_attn.format(i, name, kind),
516554
split_dim=0,
517555
split_num=self.tp)
556+
elif name == 'in_proj_qkv':
557+
# in_proj_qkv: (conv_dim, hidden) where conv_dim =
558+
# key_dim*2 + value_dim. After transpose the QKV
559+
# components are along dim -1. Interleave for TP so
560+
# each shard gets the correct Q/K/V slice.
561+
t = transpose(tensor)
562+
t = self._tp_interleave_qkv(t, dim=-1)
563+
self.model.save_split(t,
564+
self._linear_attn.format(i, name, kind),
565+
split_dim=-1,
566+
split_num=self.tp)
518567
else:
519568
self.model.save_split(transpose(tensor),
520569
self._linear_attn.format(i, name, kind),

lmdeploy/turbomind/deploy/source_model/qwen.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ class Qwen3_5ReaderMixin:
228228
(``Qwen3_5MoeRMSNormGated``) uses standard weight and is NOT affected.
229229
"""
230230

231-
attn_layer_patten = r'(?:model\.language_model\.|model\.)layers\.([0-9]+)\.'
231+
attn_layer_pattern = r'(?:model\.language_model\.|model\.)layers\.([0-9]+)\.'
232232

233233
_LINEAR_ATTN_KEYS = ['conv1d', 'in_proj_qkv', 'in_proj_z', 'in_proj_b', 'in_proj_a', 'out_proj', 'A_log', 'dt_bias']
234234

src/turbomind/kernels/norm/rms_norm.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ void invokeQkRMSNorm(void* data,
193193
{
194194

195195
auto launch = [&](auto max_dim_c) {
196-
constexpr int kMaxDim = decltype(max_dim_c)::value;
196+
constexpr int kMaxDim = std::decay_t<decltype(max_dim_c)>::value;
197197
TM_CHECK_LE(head_dim, kMaxDim);
198198

199199
auto invoke = [&](auto t) {
@@ -236,7 +236,7 @@ void invokeRMSNormQK(Tensor& x, const Tensor& w, float eps, cudaStream_t st)
236236
auto stride = x.stride(0);
237237

238238
auto launch = [&](auto max_dim_c) {
239-
constexpr int kMaxDim = decltype(max_dim_c)::value;
239+
constexpr int kMaxDim = std::decay_t<decltype(max_dim_c)>::value;
240240
TM_CHECK_LE(head_dim, kMaxDim);
241241

242242
auto invoke = [&](auto t) {

0 commit comments

Comments
 (0)