Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion src/mcore_bridge/model/modules/gated_delta_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,13 @@ def forward(
# TODO: support inference
raise NotImplementedError('GDN does not support inference for now.')

cu_seqlens = None if packed_seq_params is None else packed_seq_params.cu_seqlens_q
# Only use packed sequences when qkv_format is 'thd' and cu_seqlens_q is present
cu_seqlens = None
if (packed_seq_params is not None and
hasattr(packed_seq_params, 'qkv_format') and
packed_seq_params.qkv_format == 'thd' and
packed_seq_params.cu_seqlens_q is not None):
cu_seqlens = packed_seq_params.cu_seqlens_q
Comment on lines +187 to +193

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

使用 getattr 可以更安全、更简洁地获取属性。这样不仅能避免多次调用 hasattr 和直接属性访问,还能防止在 packed_seq_params 缺少 cu_seqlens_q 属性时触发 AttributeError 异常。

        # Only use packed sequences when qkv_format is 'thd' and cu_seqlens_q is present
        cu_seqlens = None
        if packed_seq_params is not None and getattr(packed_seq_params, 'qkv_format', None) == 'thd':
            cu_seqlens = getattr(packed_seq_params, 'cu_seqlens_q', None)

# Input projection
num_key_heads_per_device = self.num_key_heads // self.tp_size // cp_size
nvtx_range_push(suffix='in_proj')
Expand Down