-
Notifications
You must be signed in to change notification settings - Fork 307
[Stick-Breaking Attention] Add Model #599
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 9 commits
db7411a
f34262d
56ffe33
cb99d40
9b2554c
b666a57
1f6b464
b15eaf6
93c98b3
18d3fb2
1bf8ba2
3d037ff
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,108 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from __future__ import annotations | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import warnings | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from typing import TYPE_CHECKING | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import torch.nn as nn | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from einops import rearrange | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from transformers.utils import logging | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from fla.modules import RMSNorm | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from fla.ops.stickbreaking_attn import parallel_stickbreaking_attn | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if TYPE_CHECKING: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from fla.models.utils import Cache | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| logger = logging.get_logger(__name__) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| class StickBreakingAttention(nn.Module): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def __init__( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| hidden_size: int = 2048, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| num_heads: int = 32, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| num_kv_heads: int | None = None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| qkv_bias: bool = False, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| qk_norm: bool = False, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| window_size: int | None = None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| max_position_embeddings: int | None = None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| layer_idx: int | None = None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| super().__init__() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if parallel_stickbreaking_attn is None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| raise ImportError( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "StickBreakingAttention kernels are not available. Ensure Triton is installed and ops are importable.", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.hidden_size = hidden_size | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.num_heads = num_heads | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if num_kv_heads is None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.num_kv_heads = self.num_heads | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.num_kv_heads = num_kv_heads | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.num_kv_groups = self.num_heads // self.num_kv_heads | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.head_dim = self.hidden_size // self.num_heads | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.kv_dim = self.num_kv_heads * self.head_dim | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.qkv_bias = qkv_bias | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.qk_norm = qk_norm | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.window_size = window_size | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.max_position_embeddings = max_position_embeddings | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.layer_idx = layer_idx | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=self.qkv_bias) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.k_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=self.qkv_bias) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.v_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=self.qkv_bias) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if qk_norm: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.q_norm = RMSNorm(self.head_dim) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.k_norm = RMSNorm(self.head_dim) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def forward( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| hidden_states: torch.Tensor, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| attention_mask: torch.LongTensor | None = None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| past_key_values: Cache | None = None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| output_attentions: bool = False, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| use_cache: bool = False, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| **kwargs, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if attention_mask is not None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert len(attention_mask.shape) == 2, ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] " | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "for padding purposes (0 indicating padding). " | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed." | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if use_cache: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| warnings.warn( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "StickBreakingAttention does not support KV cache yet; falling back to use_cache=False.") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+77
to
+86
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do not silently ignore We assert that an Until the kernel can consume the mask (e.g. by translating it to if attention_mask is not None:
assert len(attention_mask.shape) == 2, (
"Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
"for padding purposes (0 indicating padding). "
"Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
)
+ valid = attention_mask.to(dtype=torch.bool)
+ if not torch.all(valid):
+ raise NotImplementedError(
+ "Padding masks are not yet supported. Please supply packed sequences via `cu_seqlens` "
+ "or implement masking before calling StickBreakingAttention."
+ )📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| use_cache = False | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| batch_size, q_len, _ = hidden_states.size() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| q = rearrange(self.q_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| k = rearrange(self.k_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| v = rearrange(self.v_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if self.qk_norm: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| q, k = self.q_norm(q), self.k_norm(k) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| cu_seqlens = kwargs.get('cu_seqlens') | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| o, _rem = parallel_stickbreaking_attn( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| q=q, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| k=k, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| v=v, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| cu_seqlens=cu_seqlens, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| o = o.reshape(batch_size, q_len, -1) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| o = self.o_proj(o) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return o, None, past_key_values | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,15 @@ | ||
|
|
||
| from transformers import AutoConfig, AutoModel, AutoModelForCausalLM | ||
|
|
||
| from fla.models.stickbreaking_attn.configuration_stickbreaking_attn import StickBreakingAttentionConfig | ||
| from fla.models.stickbreaking_attn.modeling_stickbreaking_attn import ( | ||
| StickBreakingAttentionForCausalLM, | ||
| StickBreakingAttentionModel, | ||
| ) | ||
|
|
||
| AutoConfig.register(StickBreakingAttentionConfig.model_type, StickBreakingAttentionConfig, exist_ok=True) | ||
| AutoModel.register(StickBreakingAttentionConfig, StickBreakingAttentionModel, exist_ok=True) | ||
| AutoModelForCausalLM.register(StickBreakingAttentionConfig, StickBreakingAttentionForCausalLM, exist_ok=True) | ||
|
|
||
|
|
||
| __all__ = ['StickBreakingAttentionConfig', 'StickBreakingAttentionForCausalLM', 'StickBreakingAttentionModel'] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,82 @@ | ||
| import warnings | ||
|
|
||
| from transformers.configuration_utils import PretrainedConfig | ||
|
|
||
|
|
||
| class StickBreakingAttentionConfig(PretrainedConfig): | ||
|
|
||
| model_type = 'stickbreaking_attn' | ||
| keys_to_ignore_at_inference = ['past_key_values'] | ||
|
|
||
| def __init__( | ||
| self, | ||
| hidden_size: int = 2048, | ||
| num_hidden_layers: int = 24, | ||
| num_heads: int = 32, | ||
| num_kv_heads: int | None = None, | ||
| qkv_bias: bool = False, | ||
| qk_norm: bool = False, | ||
| window_size: int | None = None, | ||
| max_position_embeddings: int = 2048, | ||
| hidden_ratio: int | None = 4, | ||
| intermediate_size: int | None = None, | ||
| hidden_act: str = "swish", | ||
| initializer_range: float = 0.02, | ||
| elementwise_affine: bool | None = True, | ||
| norm_eps: float = 1e-6, | ||
| use_cache: bool = True, | ||
| pad_token_id: int | None = None, | ||
| bos_token_id: int = 1, | ||
| eos_token_id: int = 2, | ||
| tie_word_embeddings: bool = False, | ||
| fuse_norm: bool = True, | ||
| fuse_swiglu: bool = True, | ||
| fuse_cross_entropy: bool = True, | ||
| fuse_linear_cross_entropy: bool = False, | ||
| use_l2warp: bool = False, | ||
| vocab_size: int = 32000, | ||
| **kwargs, | ||
| ): | ||
| self.hidden_size = hidden_size | ||
| self.num_hidden_layers = num_hidden_layers | ||
| self.num_heads = num_heads | ||
| self.num_kv_heads = num_kv_heads | ||
| self.qkv_bias = qkv_bias | ||
| self.qk_norm = qk_norm | ||
| self.window_size = window_size | ||
| self.max_position_embeddings = max_position_embeddings | ||
|
|
||
| self.hidden_ratio = hidden_ratio | ||
| self.intermediate_size = intermediate_size | ||
| self.hidden_act = hidden_act | ||
|
|
||
| self.initializer_range = initializer_range | ||
| self.elementwise_affine = elementwise_affine | ||
| self.norm_eps = norm_eps | ||
| self.use_cache = use_cache | ||
|
|
||
| self.fuse_norm = fuse_norm | ||
| self.fuse_swiglu = fuse_swiglu | ||
| self.fuse_cross_entropy = fuse_cross_entropy | ||
| self.fuse_linear_cross_entropy = fuse_linear_cross_entropy | ||
| self.use_l2warp = use_l2warp | ||
| self.vocab_size = vocab_size | ||
|
|
||
| if fuse_cross_entropy and fuse_linear_cross_entropy: | ||
| raise ValueError( | ||
| "`fuse_cross_entropy` and `fuse_linear_cross_entropy` cannot be True at the same time.", | ||
| ) | ||
| if fuse_linear_cross_entropy: | ||
| warnings.warn( | ||
| "`fuse_linear_cross_entropy` is enabled, which can improves memory efficiency " | ||
| "at the potential cost of reduced precision. " | ||
| "If you observe issues like loss divergence, consider disabling this setting.", | ||
| ) | ||
|
|
||
| super().__init__( | ||
| pad_token_id=pad_token_id, | ||
| bos_token_id=bos_token_id, | ||
| eos_token_id=eos_token_id, | ||
| tie_word_embeddings=tie_word_embeddings, | ||
| **kwargs, | ||
| ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Validate head divisibility and GQA/MQA support.
Avoid silent mis-sharding. Either implement KV grouping or gate it for now.
Apply this diff:
If GQA/MQA is not yet supported by the kernel, also add:
📝 Committable suggestion