@@ -29,7 +29,10 @@ def permute_v2_partial(x: torch.Tensor, size_per_head: int, rotary_dim: int):
2929 layout.
3030 """
3131 assert x .size (- 1 ) > 1
32+ assert rotary_dim % 2 == 0 , f'rotary_dim must be even, got { rotary_dim } '
33+ assert rotary_dim <= size_per_head , f'rotary_dim ({ rotary_dim } ) must be <= size_per_head ({ size_per_head } )'
3234 output_dims = x .size (- 1 )
35+ assert output_dims % size_per_head == 0 , f'output_dims ({ output_dims } ) must be divisible by size_per_head ({ size_per_head } )'
3336 head_num = output_dims // size_per_head
3437 orig_shape = x .shape
3538 if x .dim () == 1 :
@@ -483,6 +486,37 @@ class LinearAttn(Module):
def __init__(self, model: BaseOutputModel):
    """Bind the output model and cache the fused Q/K and V widths.

    Args:
        model: Output model; this reads ``attn_tp_size`` and
            ``model_config`` from it.

    Attributes set:
        model: The ``model`` argument, stored as-is.
        tp: ``model.attn_tp_size``.
        key_dim: ``linear_num_key_heads * linear_key_head_dim`` from the
            model config.
        value_dim: ``linear_num_value_heads * linear_value_head_dim`` from
            the model config.
    """
    self.model = model
    self.tp = model.attn_tp_size
    config = model.model_config
    num_key_heads = config.linear_num_key_heads
    key_head_dim = config.linear_key_head_dim
    self.key_dim = num_key_heads * key_head_dim
    num_value_heads = config.linear_num_value_heads
    value_head_dim = config.linear_value_head_dim
    self.value_dim = num_value_heads * value_head_dim
492+
def _tp_interleave_qkv(self, tensor, dim):
    """Regroup a fused ``[Q | K | V]`` axis so it splits cleanly per TP rank.

    Along ``dim`` the input is laid out as
    ``Q(key_dim) | K(key_dim) | V(value_dim)``. Cutting that axis into
    ``tp`` equal chunks does not respect the component boundaries when
    ``key_dim`` and ``value_dim`` differ. This method cuts each component
    into ``tp`` pieces and rebuilds the axis as
    ``Q_0 | K_0 | V_0 | Q_1 | K_1 | V_1 | ...`` so that the r-th equal
    chunk along ``dim`` is exactly ``Q_r | K_r | V_r``.

    Args:
        tensor: Tensor whose size along ``dim`` is
            ``2 * key_dim + value_dim``.
        dim: Axis holding the fused components; negative values count
            from the last axis.

    Returns:
        A tensor with the same shape as ``tensor`` and the elements along
        ``dim`` regrouped per TP rank.

    Raises:
        ValueError: If ``key_dim`` or ``value_dim`` is not divisible by
            ``tp``.
    """
    if dim < 0:
        dim += tensor.dim()

    tp = self.tp
    # Fail with a readable message up front; otherwise the floor division
    # below only surfaces as an opaque shape error.
    for label, size in (('key_dim', self.key_dim),
                        ('value_dim', self.value_dim)):
        if size % tp != 0:
            raise ValueError(
                f'{label} ({size}) must be divisible by tp ({tp})')

    q, k, v = torch.split(
        tensor, [self.key_dim, self.key_dim, self.value_dim], dim=dim)

    def per_rank(part):
        # Expose the TP axis: ``dim`` becomes (tp, size // tp).
        dims = list(part.shape)
        dims[dim:dim + 1] = [tp, dims[dim] // tp]
        return part.reshape(dims)

    # Concatenating on the per-rank axis places Q_r, K_r, V_r side by side
    # for every rank r.
    grouped = torch.cat([per_rank(q), per_rank(k), per_rank(v)],
                        dim=dim + 1)
    # Merging (tp, per_rank_total) back into one axis restores the input
    # shape, because the regrouping only permutes elements along ``dim``.
    return grouped.reshape(tensor.shape)
486520
487521 def apply (self , i : int , r : BaseReader ):
488522 layer_types = getattr (self .model .model_config , 'layer_types' , [])
@@ -499,6 +533,10 @@ def apply(self, i: int, r: BaseReader):
499533 if tensor is None :
500534 continue
501535 if name == 'conv1d' :
536+ # conv1d shape: (conv_dim, 1, d_conv) where
537+ # conv_dim = key_dim*2 + value_dim. Interleave Q/K/V
538+ # portions along dim 0 before splitting for TP.
539+ tensor = self ._tp_interleave_qkv (tensor , dim = 0 )
502540 self .model .save_split (tensor ,
503541 self ._linear_attn .format (i , name , kind ),
504542 split_dim = 0 ,
@@ -515,6 +553,17 @@ def apply(self, i: int, r: BaseReader):
515553 self ._linear_attn .format (i , name , kind ),
516554 split_dim = 0 ,
517555 split_num = self .tp )
556+ elif name == 'in_proj_qkv' :
557+ # in_proj_qkv: (conv_dim, hidden) where conv_dim =
558+ # key_dim*2 + value_dim. After transpose the QKV
559+ # components are along dim -1. Interleave for TP so
560+ # each shard gets the correct Q/K/V slice.
561+ t = transpose (tensor )
562+ t = self ._tp_interleave_qkv (t , dim = - 1 )
563+ self .model .save_split (t ,
564+ self ._linear_attn .format (i , name , kind ),
565+ split_dim = - 1 ,
566+ split_num = self .tp )
518567 else :
519568 self .model .save_split (transpose (tensor ),
520569 self ._linear_attn .format (i , name , kind ),
0 commit comments