diff --git a/src/mcore_bridge/model/modules/gated_delta_net.py b/src/mcore_bridge/model/modules/gated_delta_net.py index d2ee0dd..84dd0fd 100644 --- a/src/mcore_bridge/model/modules/gated_delta_net.py +++ b/src/mcore_bridge/model/modules/gated_delta_net.py @@ -184,7 +184,13 @@ def forward( # TODO: support inference raise NotImplementedError('GDN does not support inference for now.') - cu_seqlens = None if packed_seq_params is None else packed_seq_params.cu_seqlens_q + # Only use packed sequences when qkv_format is 'thd' and cu_seqlens_q is present + cu_seqlens = None + if (packed_seq_params is not None and + hasattr(packed_seq_params, 'qkv_format') and + packed_seq_params.qkv_format == 'thd' and + packed_seq_params.cu_seqlens_q is not None): + cu_seqlens = packed_seq_params.cu_seqlens_q # Input projection num_key_heads_per_device = self.num_key_heads // self.tp_size // cp_size nvtx_range_push(suffix='in_proj')