4 changes: 4 additions & 0 deletions fla/__init__.py
@@ -26,6 +26,7 @@
    RodimusAttention,
    RWKV6Attention,
    RWKV7Attention,
    StickBreakingAttention,
)
from fla.models import (
    ABCForCausalLM,
@@ -74,6 +75,8 @@
    RWKV6Model,
    RWKV7ForCausalLM,
    RWKV7Model,
    StickBreakingAttentionForCausalLM,
    StickBreakingAttentionModel,
    TransformerForCausalLM,
    TransformerModel,
)
@@ -105,6 +108,7 @@
    'RodimusAttention', 'RodimusForCausalLM', 'RodimusModel',
    'RWKV6Attention', 'RWKV6ForCausalLM', 'RWKV6Model',
    'RWKV7Attention', 'RWKV7ForCausalLM', 'RWKV7Model',
    'StickBreakingAttention', 'StickBreakingAttentionForCausalLM', 'StickBreakingAttentionModel',
]

__version__ = '0.4.0'
2 changes: 2 additions & 0 deletions fla/layers/__init__.py
@@ -30,6 +30,7 @@
from .rodimus import RodimusAttention, SlidingWindowSharedKeyAttention
from .rwkv6 import RWKV6Attention
from .rwkv7 import RWKV7Attention
from .stickbreaking_attn import StickBreakingAttention

__all__ = [
    'ABCAttention',
@@ -61,6 +62,7 @@
    'RodimusAttention',
    'RWKV6Attention',
    'RWKV7Attention',
    'StickBreakingAttention',
    'SlidingWindowSharedKeyAttention',
    'DeltaFormerAttention',
]
108 changes: 108 additions & 0 deletions fla/layers/stickbreaking_attn.py
@@ -0,0 +1,108 @@
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

from __future__ import annotations

import warnings
from typing import TYPE_CHECKING

import torch
import torch.nn as nn
from einops import rearrange
from transformers.utils import logging

from fla.modules import RMSNorm
from fla.ops.stickbreaking_attn import parallel_stickbreaking_attn

if TYPE_CHECKING:
from fla.models.utils import Cache


logger = logging.get_logger(__name__)


class StickBreakingAttention(nn.Module):

    def __init__(
        self,
        hidden_size: int = 2048,
        num_heads: int = 32,
        num_kv_heads: int | None = None,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        window_size: int | None = None,
        max_position_embeddings: int | None = None,
        layer_idx: int | None = None,
    ):
        super().__init__()

        if parallel_stickbreaking_attn is None:
            raise ImportError(
                "StickBreakingAttention kernels are not available. Ensure Triton is installed and ops are importable.",
            )

        self.hidden_size = hidden_size
        self.num_heads = num_heads
        if num_kv_heads is None:
            self.num_kv_heads = self.num_heads
        else:
            self.num_kv_heads = num_kv_heads
        self.num_kv_groups = self.num_heads // self.num_kv_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.kv_dim = self.num_kv_heads * self.head_dim
        self.qkv_bias = qkv_bias
Comment on lines +49 to +52
⚠️ Potential issue

Validate head divisibility and GQA/MQA support.

Avoid silent mis-sharding. Either implement KV grouping or gate it for now.

Apply this diff:

-        self.num_kv_groups = self.num_heads // self.num_kv_heads
-        self.head_dim = self.hidden_size // self.num_heads
+        if self.hidden_size % self.num_heads != 0:
+            raise ValueError(f"hidden_size ({self.hidden_size}) must be divisible by num_heads ({self.num_heads}).")
+        if self.num_heads % self.num_kv_heads != 0:
+            raise ValueError(f"num_heads ({self.num_heads}) must be divisible by num_kv_heads ({self.num_kv_heads}).")
+        self.num_kv_groups = self.num_heads // self.num_kv_heads
+        self.head_dim = self.hidden_size // self.num_heads

If GQA/MQA is not yet supported by the kernel, also add:

+        if self.num_kv_heads != self.num_heads:
+            raise NotImplementedError("GQA/MQA (num_kv_heads != num_heads) is not supported yet for StickBreakingAttention.")

        self.qk_norm = qk_norm

        self.window_size = window_size
        self.max_position_embeddings = max_position_embeddings
        self.layer_idx = layer_idx

        self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=self.qkv_bias)
        self.k_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=self.qkv_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=self.qkv_bias)
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)

        if qk_norm:
            self.q_norm = RMSNorm(self.head_dim)
            self.k_norm = RMSNorm(self.head_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
        if attention_mask is not None:
            assert len(attention_mask.shape) == 2, (
                "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
                "for padding purposes (0 indicating padding). "
                "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
            )

        if use_cache:
            warnings.warn(
                "StickBreakingAttention does not support KV cache yet; falling back to use_cache=False.")
Comment on lines +77 to +86
⚠️ Potential issue | 🟠 Major

Do not silently ignore attention_mask

We assert that attention_mask is 2-D, but we never use it afterward. When padding is present in the batch, the kernel still allocates stick mass to the padded positions, so the outputs (and gradients) depend on garbage tokens. That violates the Hugging Face contract for attention layers and breaks training/inference with padded batches.

Until the kernel can consume the mask (e.g. by translating it to cu_seqlens or masking logits), we should not silently continue when zeros are present. Please either wire up proper masking or explicitly fail to avoid corrupt results. For example:

         if attention_mask is not None:
             assert len(attention_mask.shape) == 2, (
                 "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
                 "for padding purposes (0 indicating padding). "
                 "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
             )
+            valid = attention_mask.to(dtype=torch.bool)
+            if not torch.all(valid):
+                raise NotImplementedError(
+                    "Padding masks are not yet supported. Please supply packed sequences via `cu_seqlens` "
+                    "or implement masking before calling StickBreakingAttention."
+                )
🤖 Prompt for AI Agents
In fla/layers/stickbreaking_attn.py around lines 79-88, the code asserts
attention_mask is 2-D but never uses it so padded positions still receive stick
mass; either wire the mask into the kernel (e.g., convert [batch, seq_len] mask
into cu_seqlens or apply it to attention logits before the stick-breaking
kernel) or explicitly fail when any padding is present. Implement one of two
fixes: (1) translate the 0/1 attention_mask to the kernel-friendly format and
ensure masked positions are excluded from stick mass allocation, or (2) check
for any zeros in attention_mask and raise a clear ValueError (or
NotImplementedError) stating that StickBreakingAttention does not support padded
masks yet. Ensure the failure case includes guidance to use a mask-free batch or
a compatible attention implementation.

            use_cache = False

        batch_size, q_len, _ = hidden_states.size()

        q = rearrange(self.q_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim)
        k = rearrange(self.k_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim)
        v = rearrange(self.v_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim)

        if self.qk_norm:
            q, k = self.q_norm(q), self.k_norm(k)

        cu_seqlens = kwargs.get('cu_seqlens')
        o, _rem = parallel_stickbreaking_attn(
            q=q,
            k=k,
            v=v,
            cu_seqlens=cu_seqlens,
        )
        o = o.reshape(batch_size, q_len, -1)
        o = self.o_proj(o)

        return o, None, past_key_values
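
A minimal usage sketch for the layer above (not part of the PR). The shapes, dtype, and the CUDA/Triton requirement are illustrative assumptions, and the packed-sequence call assumes `cu_seqlens` follows the FlashAttention varlen convention:

import torch
import torch.nn.functional as F

from fla.layers.stickbreaking_attn import StickBreakingAttention

# Hypothetical smoke test; requires a CUDA device with Triton, since the layer
# raises ImportError when the stick-breaking kernels cannot be imported.
layer = StickBreakingAttention(hidden_size=2048, num_heads=32, layer_idx=0)
layer = layer.cuda().to(torch.bfloat16)

x = torch.randn(2, 128, 2048, device='cuda', dtype=torch.bfloat16)  # [batch, seq_len, hidden]
o, attn, cache = layer(x)  # attention weights and cache are always None for now
assert o.shape == x.shape

# Packed (varlen) call: two sequences of lengths 128 and 96 flattened into one
# batch row, with cu_seqlens marking the sequence boundaries ([0, 128, 224]).
lengths = torch.tensor([128, 96], dtype=torch.int32, device='cuda')
cu_seqlens = F.pad(lengths.cumsum(0, dtype=torch.int32), (1, 0))
packed = torch.randn(1, int(lengths.sum()), 2048, device='cuda', dtype=torch.bfloat16)
o_packed, _, _ = layer(packed, cu_seqlens=cu_seqlens)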
6 changes: 6 additions & 0 deletions fla/models/__init__.py
@@ -31,6 +31,11 @@
from fla.models.rwkv6 import RWKV6Config, RWKV6ForCausalLM, RWKV6Model
from fla.models.rwkv7 import RWKV7Config, RWKV7ForCausalLM, RWKV7Model
from fla.models.samba import SambaConfig, SambaForCausalLM, SambaModel
from fla.models.stickbreaking_attn import (
    StickBreakingAttentionConfig,
    StickBreakingAttentionForCausalLM,
    StickBreakingAttentionModel,
)
from fla.models.transformer import TransformerConfig, TransformerForCausalLM, TransformerModel

__all__ = [
@@ -63,4 +68,5 @@
    'RWKV7Config', 'RWKV7ForCausalLM', 'RWKV7Model',
    'SambaConfig', 'SambaForCausalLM', 'SambaModel',
    'TransformerConfig', 'TransformerForCausalLM', 'TransformerModel',
    'StickBreakingAttentionConfig', 'StickBreakingAttentionForCausalLM', 'StickBreakingAttentionModel',
]
15 changes: 15 additions & 0 deletions fla/models/stickbreaking_attn/__init__.py
@@ -0,0 +1,15 @@

from transformers import AutoConfig, AutoModel, AutoModelForCausalLM

from fla.models.stickbreaking_attn.configuration_stickbreaking_attn import StickBreakingAttentionConfig
from fla.models.stickbreaking_attn.modeling_stickbreaking_attn import (
    StickBreakingAttentionForCausalLM,
    StickBreakingAttentionModel,
)

AutoConfig.register(StickBreakingAttentionConfig.model_type, StickBreakingAttentionConfig, exist_ok=True)
AutoModel.register(StickBreakingAttentionConfig, StickBreakingAttentionModel, exist_ok=True)
AutoModelForCausalLM.register(StickBreakingAttentionConfig, StickBreakingAttentionForCausalLM, exist_ok=True)


__all__ = ['StickBreakingAttentionConfig', 'StickBreakingAttentionForCausalLM', 'StickBreakingAttentionModel']
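
A short sketch of what the registration above enables (not part of the PR; the config values are illustrative, and building the model still requires the Triton kernels to be importable, per the layer's constructor check):

from transformers import AutoModelForCausalLM

from fla.models import StickBreakingAttentionConfig  # importing fla.models runs the Auto* registration above

config = StickBreakingAttentionConfig(
    hidden_size=256,       # illustrative small values
    num_hidden_layers=2,
    num_heads=4,
    vocab_size=1000,
)
model = AutoModelForCausalLM.from_config(config)
print(type(model).__name__)  # StickBreakingAttentionForCausalLM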
82 changes: 82 additions & 0 deletions fla/models/stickbreaking_attn/configuration_stickbreaking_attn.py
@@ -0,0 +1,82 @@
import warnings

from transformers.configuration_utils import PretrainedConfig


class StickBreakingAttentionConfig(PretrainedConfig):

    model_type = 'stickbreaking_attn'
    keys_to_ignore_at_inference = ['past_key_values']

    def __init__(
        self,
        hidden_size: int = 2048,
        num_hidden_layers: int = 24,
        num_heads: int = 32,
        num_kv_heads: int | None = None,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        window_size: int | None = None,
        max_position_embeddings: int = 2048,
        hidden_ratio: int | None = 4,
        intermediate_size: int | None = None,
        hidden_act: str = "swish",
        initializer_range: float = 0.02,
        elementwise_affine: bool | None = True,
        norm_eps: float = 1e-6,
        use_cache: bool = True,
        pad_token_id: int | None = None,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        tie_word_embeddings: bool = False,
        fuse_norm: bool = True,
        fuse_swiglu: bool = True,
        fuse_cross_entropy: bool = True,
        fuse_linear_cross_entropy: bool = False,
        use_l2warp: bool = False,
        vocab_size: int = 32000,
        **kwargs,
    ):
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.qkv_bias = qkv_bias
        self.qk_norm = qk_norm
        self.window_size = window_size
        self.max_position_embeddings = max_position_embeddings

        self.hidden_ratio = hidden_ratio
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act

        self.initializer_range = initializer_range
        self.elementwise_affine = elementwise_affine
        self.norm_eps = norm_eps
        self.use_cache = use_cache

        self.fuse_norm = fuse_norm
        self.fuse_swiglu = fuse_swiglu
        self.fuse_cross_entropy = fuse_cross_entropy
        self.fuse_linear_cross_entropy = fuse_linear_cross_entropy
        self.use_l2warp = use_l2warp
        self.vocab_size = vocab_size

        if fuse_cross_entropy and fuse_linear_cross_entropy:
            raise ValueError(
                "`fuse_cross_entropy` and `fuse_linear_cross_entropy` cannot be True at the same time.",
            )
        if fuse_linear_cross_entropy:
            warnings.warn(
                "`fuse_linear_cross_entropy` is enabled, which can improve memory efficiency "
                "at the potential cost of reduced precision. "
                "If you observe issues like loss divergence, consider disabling this setting.",
            )

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
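
A quick sketch of the mutual-exclusion check above (illustrative, not part of the PR):

from fla.models.stickbreaking_attn import StickBreakingAttentionConfig

try:
    # Hypothetical misuse: enabling both fused losses is rejected by the config.
    StickBreakingAttentionConfig(fuse_cross_entropy=True, fuse_linear_cross_entropy=True)
except ValueError as exc:
    print(exc)  # `fuse_cross_entropy` and `fuse_linear_cross_entropy` cannot be True at the same time.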