
Conversation


@Nathancgy Nathancgy commented Sep 23, 2025

  • Added Stick-Breaking Attention model
  • Supports varlen
  • 27-28k throughput on python -m benchmarks.benchmark_training_throughput --name stickbreaking_attn --batch_size 8 --seq_len 4096
  • Referred to this implementation
  • FLA style: specific naming, (B, T, H, D) layout, varlen support through cu_seqlens
  • Passes both model and ops tests

Summary by CodeRabbit

  • New Features

    • Introduced stick-breaking attention for improved modeling and generation.
    • Added causal language model support with HuggingFace Transformers integration.
    • Enabled fused normalization and fused cross-entropy optimizations for performance.
    • Added gradient checkpointing, KV caching support, and variable-length sequence handling.
  • Tests

    • Added comprehensive tests validating attention ops, model forward/backward, and generation.


coderabbitai bot commented Sep 23, 2025

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Adds Stick-Breaking Attention support: Triton and naive kernel implementations, a new attention layer, HF Transformers config and model wrappers (encoder and causal LM), ops exports, and tests for kernels and modeling behavior.
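
For readers new to the mechanism, the weights being computed are A_{t,s} = beta_{t,s} * prod_{s<j<t} (1 - beta_{t,j}) with beta_{t,s} = sigmoid(inv_temp * q_t . k_s), i.e. each query breaks a unit stick over its past keys from nearest to farthest. Below is a minimal PyTorch sketch of that weighting, pieced together from the naive reference discussed in the review comments further down; the function and variable names, the -1e5 mask fill, and the remainder layout are illustrative assumptions, not the exact fla API.

import torch
import torch.nn.functional as F


def stickbreaking_attn_reference(q, k, v, inv_temp: float, attend_current: bool = False):
    # q, k, v: [B, T, H, D] in the FLA layout
    T = q.shape[1]
    logits = torch.einsum('bthd,bshd->bhts', q, k).float() * inv_temp          # [B, H, T, T]
    # Positions a query may NOT attend to: the strict upper triangle,
    # plus the diagonal unless attend_current=True.
    invalid = torch.triu(torch.ones(T, T, dtype=torch.bool, device=q.device),
                         diagonal=1 if attend_current else 0)
    log_beta = F.logsigmoid(logits).masked_fill(invalid, -1e5)                 # log beta_{t,s}
    log_one_minus_beta = F.logsigmoid(-logits).masked_fill(invalid, 0.0)       # log(1 - beta_{t,s})
    # A_{t,s} = beta_{t,s} * prod_{s < j < t} (1 - beta_{t,j}), computed in log space:
    # acc[j, s] = 1 iff j > s, so the einsum accumulates log(1 - beta) over keys after s.
    acc = torch.tril(torch.ones(T, T, device=q.device), diagonal=-1)
    att = (log_beta + torch.einsum('bhtj,js->bhts', log_one_minus_beta, acc)).exp()
    o = torch.einsum('bhts,bshd->bthd', att, v.float())
    rem = 1.0 - att.sum(-1)   # leftover stick mass per query, [B, H, T]; exact layout in the PR may differ
    return o.to(q.dtype), rem.to(q.dtype)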

Changes

  • Public API exports (fla/__init__.py, fla/layers/__init__.py, fla/models/__init__.py, fla/ops/__init__.py, fla/ops/stickbreaking_attn/__init__.py): Export new public symbols StickBreakingAttention, StickBreakingAttentionForCausalLM, StickBreakingAttentionModel, StickBreakingAttentionConfig, and rename/re-export kernel functions as parallel_stickbreaking_attn and naive_stickbreaking_attn via updated __all__ and import chains.
  • Core attention layer (fla/layers/stickbreaking_attn.py): New StickBreakingAttention layer with Q/K/V projections, optional q/k RMSNorm, and a forward that calls the stick-breaking kernels (requires cu_seqlens), warns if KV caching is requested, and returns (output, None, past_key_values).
  • Kernel implementations (fla/ops/stickbreaking_attn/parallel.py, fla/ops/stickbreaking_attn/naive.py, fla/ops/stickbreaking_attn/softplus.py): Triton-based parallel implementation with forward/backward kernels and autograd (parallel_stickbreaking_attn and helper wrappers), a reference Python/torch naive_stickbreaking_attn, and a Triton inline-asm softplus kernel.
  • Models / HF integration (fla/models/stickbreaking_attn/..., fla/models/stickbreaking_attn/__init__.py, fla/models/__init__.py): New StickBreakingAttentionConfig (PretrainedConfig), StickBreakingAttentionBlock, StickBreakingAttentionPreTrainedModel, StickBreakingAttentionModel (encoder), and StickBreakingAttentionForCausalLM (causal LM wrapper); registers the config and models with HF AutoConfig/AutoModel/AutoModelForCausalLM.
  • Configuration (fla/models/stickbreaking_attn/configuration_stickbreaking_attn.py): New config class with many hyperparameters and validation (mutual exclusivity of fuse_cross_entropy and fuse_linear_cross_entropy, warning on linear fuse).
  • Tests (tests/ops/test_stickbreaking_attn.py, tests/models/test_modeling_stickbreaking_attn.py, tests/models/test_modeling_utils.py): Kernel-level tests comparing naive vs parallel (forward/backward), modeling forward/backward and generation tests, and marking StickBreakingAttentionConfig as generation-unsupported in test utilities.
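
A rough end-to-end sketch of how these pieces fit together through the HF wrappers (class and argument names follow the exports listed here; the hyperparameter values, device, and dtype are placeholders, and a CUDA device with Triton is assumed for the fused kernels):

import torch
from fla.models import StickBreakingAttentionConfig, StickBreakingAttentionForCausalLM

# Tiny placeholder config; see StickBreakingAttentionConfig for the real defaults.
config = StickBreakingAttentionConfig(
    hidden_size=512,
    num_hidden_layers=2,
    num_heads=8,
    vocab_size=32000,
)
model = StickBreakingAttentionForCausalLM(config).cuda().to(torch.bfloat16)

input_ids = torch.randint(0, config.vocab_size, (2, 128), device='cuda')
# Passing labels exercises the (optionally fused) cross-entropy path in the modeling code.
out = model(input_ids=input_ids, labels=input_ids)
out.loss.backward()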

Sequence Diagram(s)

sequenceDiagram
    participant Caller
    participant Layer as StickBreakingAttention
    participant Kernels as parallel_stickbreaking_attn / naive
    participant ModelOut

    Caller->>Layer: forward(hidden_states, attention_mask?, cu_seqlens, ...)
    activate Layer

    Layer->>Layer: project Q, K, V
    Layer->>Layer: optional q/k RMSNorm
    alt use_cache requested
        Layer-->>Caller: emit warning (KV caching unsupported)
    end

    Layer->>Kernels: call parallel_stickbreaking_attn(q,k,v,scale,cu_seqlens)
    activate Kernels

    Kernels->>Kernels: compute logits (log-space), apply triangular mask
    Kernels->>Kernels: compute log_z/log_beta → attention weights
    Kernels-->>Layer: return (o, rem)
    deactivate Kernels

    Layer->>Layer: reshape & output projection
    Layer-->>ModelOut: return (output, None, past_key_values)
    deactivate Layer
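
In code, the flow above corresponds roughly to the following layer-level call (a sketch only; the constructor arguments come from the signature quoted in the review comments below, and a CUDA device with Triton is assumed):

import torch
from fla.layers import StickBreakingAttention

attn = StickBreakingAttention(hidden_size=512, num_heads=8).cuda().to(torch.bfloat16)
x = torch.randn(2, 128, 512, device='cuda', dtype=torch.bfloat16)

# The layer returns (output, attentions, past_key_values); attentions is always None, and
# a warning is emitted (rather than an error) if use_cache=True is requested.
# For variable-length batches, cu_seqlens can be passed through kwargs.
o, attn_weights, past_key_values = attn(x)
assert o.shape == x.shape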

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

  • Review focus:
    • Triton kernels and autograd wrappers in fla/ops/stickbreaking_attn/parallel.py (numerical stability, grid dispatch, cu_seqlens handling).
    • Gradient correctness and dtype handling across naive vs parallel paths.
    • HF integration and model registration in fla/models/stickbreaking_attn/__init__.py.
    • Loss fusion logic and config validation in modeling code and StickBreakingAttentionConfig.

Possibly related PRs

  • [Testing] Enhance generation testing #468 — Updates test generation utilities used by the new modeling/generation tests (run_test_generation signature and helpers); directly related to the new tests' invocation patterns.

Poem

🐰 Hop-hop, kernels tuned in line,

Triton threads and logits refine,
Configs and models snug in place,
Gradients dance, softplus keeps pace,
A carrot for builds — attention's prime! 🥕

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: Docstring coverage is 3.23%, which is below the required threshold of 80.00%. You can run @coderabbitai generate docstrings to improve coverage.

✅ Passed checks (2 passed)
  • Description Check ✅ Passed: Check skipped; CodeRabbit's high-level summary is enabled.
  • Title Check ✅ Passed: The title '[Stick-Breaking Attention] Add Model' directly relates to the main changeset, which adds a complete Stick-Breaking Attention model implementation including layers, models, configuration, and tests.


@gemini-code-assist

Summary of Changes

Hello @Nathancgy, I'm Gemini Code Assist [1]! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces the novel Stick-Breaking Attention model, providing an efficient and scalable attention mechanism. The implementation adheres to the project's style guidelines, supports variable-length sequences, and has been benchmarked for competitive throughput. It includes both a naive reference implementation and an optimized Triton kernel, alongside comprehensive tests to ensure correctness and stability across various scenarios.

Highlights

  • New Model Integration: The 'Stick-Breaking Attention' model, based on a recent research paper, has been added to the framework.
  • Varlen Support: The newly integrated attention mechanism fully supports variable-length sequences, enhancing its flexibility and applicability.
  • Performance Benchmarking: The model achieves a throughput of 27-28k on specified training benchmarks, indicating an efficient implementation.
  • Comprehensive Testing: The new model and its underlying operations have successfully passed both model-level and operations-level tests, ensuring correctness.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces the Stick-Breaking Attention model, including the attention layer, model definition, configuration, and optimized Triton kernels. The implementation is comprehensive and includes both a naive version for reference and a parallelized version for performance, along with corresponding tests.

My review focuses on improving code clarity and correctness. I've identified a critical issue with a type hint in a constructor, and several medium-severity issues related to unused parameters, dead code, and confusing code constructs. These changes will make the new model's code cleaner and more maintainable.

    def __init__(
        self,
        config: StickBreakingAttentionConfig
    ) -> StickBreakingAttentionModel:


critical

The __init__ method in Python should have a return type hint of None. Please remove the -> StickBreakingAttentionModel annotation.

Suggested change
-    ) -> StickBreakingAttentionModel:
+    ):

Comment on lines 34 to 61
        window_size: Optional[int] = None,
        rope_theta: Optional[float] = None,  # sba doesn't use RoPE
        max_position_embeddings: Optional[int] = None,
        layer_idx: int | None = None,
    ):
        super().__init__()

        if sb_attn is None:
            raise ImportError(
                "StickBreakingAttention kernels are not available. Ensure Triton is installed and ops are importable."
            )

        self.hidden_size = hidden_size
        self.num_heads = num_heads
        if num_kv_heads is None:
            self.num_kv_heads = self.num_heads
        else:
            self.num_kv_heads = num_kv_heads
        self.num_kv_groups = self.num_heads // self.num_kv_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.kv_dim = self.num_kv_heads * self.head_dim
        self.qkv_bias = qkv_bias
        self.qk_norm = qk_norm

        self.window_size = window_size
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
        self.layer_idx = layer_idx


medium

The parameters window_size, rope_theta, max_position_embeddings, and layer_idx are accepted in the constructor (lines 34-37) and stored as attributes (lines 58-61), but they are never used within the StickBreakingAttention class. This can be misleading and adds unnecessary clutter to the API. It's best to remove these unused parameters and their assignments to keep the code clean and maintainable.

Comment on lines 22 to 49
        window_size: Optional[int] = None,
        max_position_embeddings: int = 2048,
        hidden_ratio: Optional[int] = 4,
        intermediate_size: Optional[int] = None,
        hidden_act: str = "swish",
        initializer_range: float = 0.02,
        elementwise_affine: Optional[bool] = True,
        norm_eps: float = 1e-6,
        use_cache: bool = True,
        pad_token_id: Optional[int] = None,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        tie_word_embeddings: bool = False,
        fuse_norm: bool = True,
        fuse_swiglu: bool = True,
        fuse_cross_entropy: bool = True,
        fuse_linear_cross_entropy: bool = False,
        use_l2warp: bool = False,
        vocab_size: int = 32000,
        **kwargs,
    ):
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.qkv_bias = qkv_bias
        self.qk_norm = qk_norm
        self.window_size = window_size


medium

The window_size parameter is defined in the configuration (line 22) and assigned (line 49), but it is not used by the StickBreakingAttention layer. It should be removed to avoid confusion.

Comment on lines 51 to 54
window_size=config.window_size,
rope_theta=None,
max_position_embeddings=config.max_position_embeddings,
layer_idx=layer_idx


medium

The parameters window_size, rope_theta, max_position_embeddings, and layer_idx are not used by StickBreakingAttention. They should be removed from this instantiation to align with the refactoring of the StickBreakingAttention layer.

Comment on lines 9 to 13
def _tril_mask(T: int, strict: bool = True, device=None) -> torch.Tensor:
    i = torch.arange(T, device=device).view(1, 1, T, 1)
    j = torch.arange(T, device=device).view(1, 1, 1, T)
    return (j < i) if strict else (j <= i)


medium

The function _tril_mask is defined but never used in this file. It appears to be dead code and should be removed.

batch_id = 0 if IS_VARLEN else tl.program_id(0)
head_pid = tl.program_id(1)
prog_id = tl.program_id(2)
tl.num_programs(2)


medium

This call to tl.num_programs(2) has no effect as its return value is not used. It seems like a leftover from debugging and should be removed to improve code clarity.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🧹 Nitpick comments (25)
fla/models/stickbreaking_attn/configuration_stickbreaking_attn.py (3)

11-12: Annotate class attributes with ClassVar to satisfy linters.

Use typing.ClassVar for model_type and keys_to_ignore_at_inference (RUF012).

Apply this diff:

-class StickBreakingAttentionConfig(PretrainedConfig):
+class StickBreakingAttentionConfig(PretrainedConfig):
@@
-    model_type = 'stickbreaking_attn'
-    keys_to_ignore_at_inference = ['past_key_values']
+    model_type: ClassVar[str] = 'stickbreaking_attn'
+    keys_to_ignore_at_inference: ClassVar[List[str]] = ['past_key_values']

4-4: Add missing imports for ClassVar/List.

Needed for the ClassVar annotations above.

Apply this diff:

-from typing import Optional
+from typing import Optional, ClassVar, List

73-77: Add stacklevel to warnings and fix grammar.

Include stacklevel=2 (B028) and fix “can improves” → “can improve”.

Apply this diff:

-            warnings.warn(
-                "`fuse_linear_cross_entropy` is enabled, which can improves memory efficiency "
-                "at the potential cost of reduced precision. "
-                "If you observe issues like loss divergence, consider disabling this setting."
-            )
+            warnings.warn(
+                "`fuse_linear_cross_entropy` is enabled, which can improve memory efficiency "
+                "at the potential cost of reduced precision. "
+                "If you observe issues like loss divergence, consider disabling this setting.",
+                stacklevel=2,
+            )
fla/ops/__init__.py (1)

29-31: Gracefully handle missing Triton/compiled ops at import time.

Package import will fail if parallel backend isn’t available. Make the export optional to allow CPU-only environments to import higher-level modules.

Apply this diff:

-from .stickbreaking_attn.naive import sb_attn_naive
-from .stickbreaking_attn.parallel import sb_attn
+try:
+    from .stickbreaking_attn.parallel import sb_attn
+except Exception:
+    sb_attn = None
+from .stickbreaking_attn.naive import sb_attn_naive
tests/ops/test_stickbreaking_attn.py (3)

52-55: Also verify attend_current=True path.

Add a second forward/backward block with attend_current=True to cover diagonal-included semantics.

Apply this diff:

-    # Triton fused
-    tri_o, tri_rem = sb_attn(q, k, v, inv_temp=inv_temp, attend_current=False)
+    # Triton fused
+    tri_o, tri_rem = sb_attn(q, k, v, inv_temp=inv_temp, attend_current=False)
     (tri_o * do).sum().backward(retain_graph=True)
     (tri_rem * dr).sum().backward()
     tri_dq, q.grad = q.grad.clone(), None
     tri_dk, k.grad = k.grad.clone(), None
     tri_dv, v.grad = v.grad.clone(), None
+
+    # Triton fused (attend_current=True)
+    q.grad = k.grad = v.grad = None
+    tri_o_diag, tri_rem_diag = sb_attn(q, k, v, inv_temp=inv_temp, attend_current=True)
+    (tri_o_diag * do).sum().backward(retain_graph=True)
+    (tri_rem_diag * dr).sum().backward()
+    tri_dq_diag, q.grad = q.grad.clone(), None
+    tri_dk_diag, k.grad = k.grad.clone(), None
+    tri_dv_diag, v.grad = v.grad.clone(), None

60-65: Add checks for attend_current=True parity.

Compare naïve vs fused also for attend_current=True to lock in both behaviors.

Apply this diff:

-    assert_close(" o", ref_o, tri_o, 0.008)
-    assert_close("rem", ref_rem, tri_rem, 0.02)
-    assert_close("dq", ref_dq, tri_dq, 0.02)
-    assert_close("dk", ref_dk, tri_dk, 0.02)
-    assert_close("dv", ref_dv, tri_dv, 0.02)
+    assert_close(" o", ref_o, tri_o, 0.008)
+    assert_close("rem", ref_rem, tri_rem, 0.02)
+    assert_close("dq", ref_dq, tri_dq, 0.02)
+    assert_close("dk", ref_dk, tri_dk, 0.02)
+    assert_close("dv", ref_dv, tri_dv, 0.02)
+
+    # Reference (naive, attend_current=True)
+    q.grad = k.grad = v.grad = None
+    ref_o_diag, ref_rem_diag = sb_attn_naive(q, k, v, inv_temp, attend_current=True)
+    (ref_o_diag * do).sum().backward(retain_graph=True)
+    (ref_rem_diag * dr).sum().backward()
+    ref_dq_diag, q.grad = q.grad.clone(), None
+    ref_dk_diag, k.grad = k.grad.clone(), None
+    ref_dv_diag, v.grad = v.grad.clone(), None
+
+    assert_close(" o(diag)", ref_o_diag, tri_o_diag, 0.008)
+    assert_close("rem(diag)", ref_rem_diag, tri_rem_diag, 0.02)
+    assert_close("dq(diag)", ref_dq_diag, tri_dq_diag, 0.02)
+    assert_close("dk(diag)", ref_dk_diag, tri_dk_diag, 0.02)
+    assert_close("dv(diag)", ref_dv_diag, tri_dv_diag, 0.02)

12-21: Consider adding a varlen test using cu_seqlens.

A small param set with padded sequences and cu_seqlens would catch IS_VARLEN paths.

Would you like me to add a minimal varlen test that builds cu_seqlens from an attention_mask and compares fused vs naive on concatenated sequences?
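
For reference, such a check could be structured roughly as below (a sketch only: it assumes the packed (1, total_T, H, D) layout with int32 cu_seqlens boundaries used by other FLA varlen kernels, the sb_attn/sb_attn_naive signatures exercised in this test file, and loose tolerances):

import torch
from fla.ops.stickbreaking_attn import sb_attn, sb_attn_naive

torch.manual_seed(0)
H, D = 4, 64
lengths = [100, 156]
inv_temp = 1.0 / D ** 0.5

# Pack two sequences of different lengths into a single batch row.
cu_seqlens = torch.tensor([0, lengths[0], lengths[0] + lengths[1]], dtype=torch.int32, device='cuda')
q, k, v = (torch.randn(1, sum(lengths), H, D, device='cuda', dtype=torch.bfloat16) for _ in range(3))

tri_o, tri_rem = sb_attn(q, k, v, inv_temp=inv_temp, attend_current=False, cu_seqlens=cu_seqlens)

# Each packed sequence should match an independent naive run on its own slice
# (the remainder output could be compared the same way once its layout is pinned down).
for beg, end in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist()):
    ref_o, _ref_rem = sb_attn_naive(q[:, beg:end], k[:, beg:end], v[:, beg:end], inv_temp, attend_current=False)
    torch.testing.assert_close(tri_o[:, beg:end].float(), ref_o.float(), rtol=0, atol=2e-2)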

fla/ops/stickbreaking_attn/naive.py (4)

9-13: Remove or use _tril_mask; currently dead code.

Apply this diff (remove it if unused):

-def _tril_mask(T: int, strict: bool = True, device=None) -> torch.Tensor:
-    i = torch.arange(T, device=device).view(1, 1, T, 1)
-    j = torch.arange(T, device=device).view(1, 1, 1, T)
-    return (j < i) if strict else (j <= i)
+

34-36: Silence Ruff: unused unpacked variables.

Prefix unused with underscores.

Apply this diff:

-    B, T, H, D = q.shape
+    _B, T, _H, _D = q.shape

40-45: Fix misleading comments and simplify mask construction.

Comments invert semantics; also avoid creating an all-ones tensor first.

Apply this diff:

-    if attend_current:
-        mask = torch.ones(T, T, device=q.device).triu(1).bool()  # exclude diagonal
-    else:
-        mask = torch.ones(T, T, device=q.device).triu(0).bool()  # include diagonal
+    # attend_current=True: allow diagonal (strictly causal, incl diag)
+    # attend_current=False: exclude diagonal (strictly causal, strictly lower)
+    mask = torch.triu(torch.ones(T, T, device=q.device, dtype=torch.bool), diagonal=1 if attend_current else 0)
     mask = mask.unsqueeze(0).unsqueeze(0)  # [1, 1, T, T]

37-55: Keep math in float32; convert at outputs only.

Avoids bf16 underflow/precision loss in reference path; improves parity with fused kernel.

Apply this diff:

-    logits = torch.einsum('bthd,bshd->bhts', q, k) * inv_temp
-    logits = logits.float()
+    logits = torch.einsum('bthd,bshd->bhts', q, k).to(torch.float32) * inv_temp
@@
-    log_z = torch.nn.functional.logsigmoid(logits).masked_fill(mask, -1e5).to(orig_dtype)
-    log_beta = torch.nn.functional.logsigmoid(-logits).masked_fill(mask, 0).to(orig_dtype)
+    log_z = torch.nn.functional.logsigmoid(logits).masked_fill(mask, -1e5)
+    log_beta = torch.nn.functional.logsigmoid(-logits).masked_fill(mask, 0.0)
@@
-    cum_weight = torch.ones(T, T, device=q.device).tril(-1)
+    cum_weight = torch.tril(torch.ones(T, T, device=q.device, dtype=log_beta.dtype), diagonal=-1)
@@
-    re_cum_log_beta = torch.einsum("bhij,jk->bhik", log_beta, cum_weight.to(log_beta))
+    re_cum_log_beta = torch.einsum("bhij,jk->bhik", log_beta, cum_weight)
@@
-    return o.to(orig_dtype), rem.to(orig_dtype)
+    return o.to(orig_dtype), rem.to(orig_dtype)
fla/ops/stickbreaking_attn/parallel.py (2)

573-612: Add lightweight input validation and ensure contiguity before launching kernels.

Prevents subtle stride bugs and early catches shape mismatches.

Apply this diff:

-    batch_size, token_size, num_heads, dim_size = q.size()
+    assert q.shape == k.shape == v.shape, "q, k, v must have the same shape [B, T, H, D]"
+    assert q.dim() == 4, "q, k, v must be rank-4 tensors [B, T, H, D]"
+    # Ensure expected memory layout for pointer arithmetic
+    if not q.is_contiguous(): q = q.contiguous()
+    if not k.is_contiguous(): k = k.contiguous()
+    if not v.is_contiguous(): v = v.contiguous()
+    batch_size, token_size, num_heads, dim_size = q.size()

399-401: Drop unused is_compiling arg (or use it).

Currently unused; remove to silence linters.

Apply this diff:

-@triton.jit
-def softplus(x, is_compiling: tl.constexpr = False):
+@triton.jit
+def softplus(x):
     return tl.where(x < 15.0, tl.math.log2(1 + tl.math.exp2(x)), x)

Note: also remove corresponding parameter at call sites if present.

fla/layers/stickbreaking_attn.py (4)

41-45: Remove ineffective kernel-availability check.

Import would already fail if sb_attn is missing; sb_attn is never None here.

Apply this diff:

-        if sb_attn is None:
-            raise ImportError(
-                "StickBreakingAttention kernels are not available. Ensure Triton is installed and ops are importable."
-            )
+        # Triton kernels are imported at module import-time; failures surface earlier.

89-93: Add stacklevel to warning for accurate caller attribution.

Apply this diff:

-        if use_cache:
-            warnings.warn(
-                "StickBreakingAttention does not support KV cache yet; falling back to use_cache=False.")
+        if use_cache:
+            warnings.warn(
+                "StickBreakingAttention does not support KV cache yet; falling back to use_cache=False.",
+                stacklevel=2
+            )

96-101: q/k normalization is fine. Consider folding norms for perf (optional).

If RMSNorm is cheap enough, this is OK. If perf matters, consider fusing q_proj+k_proj with norm upstream.


103-107: Respect attention_mask by deriving cu_seqlens when not provided.

Currently attention_mask is validated but unused. Provide a fallback to varlen automatically.

Apply this diff:

-        cu_seqlens = kwargs.get('cu_seqlens', None)
+        cu_seqlens = kwargs.get('cu_seqlens', None)
+        if cu_seqlens is None and attention_mask is not None:
+            # Build cu_seqlens from 0/1 padding mask: shape [B,T] with 1=valid
+            lengths = attention_mask.to(torch.int32).sum(dim=1)
+            cu_seqlens = torch.zeros((lengths.numel() + 1,), dtype=torch.int32, device=lengths.device)
+            cu_seqlens[1:] = torch.cumsum(lengths, dim=0)
         o, _rem = sb_attn(q, k, v, inv_temp=inv_temp, attend_current=attend_current, cu_seqlens=cu_seqlens)

If this is intentional (mask not supported), consider throwing if attention_mask is provided without cu_seqlens.

fla/models/stickbreaking_attn/__init__.py (1)

11-13: Guard HF Auto registration for older Transformers (<4.33.0)

exist_ok was added in Transformers v4.33.0 (PR #25779, Aug 29 2023). Add a try/except TypeError fallback to call register() without exist_ok for older versions.

File: fla/models/stickbreaking_attn/__init__.py (lines 11-13)

-AutoConfig.register(StickBreakingAttentionConfig.model_type, StickBreakingAttentionConfig, exist_ok=True)
-AutoModel.register(StickBreakingAttentionConfig, StickBreakingAttentionModel, exist_ok=True)
-AutoModelForCausalLM.register(StickBreakingAttentionConfig, StickBreakingAttentionForCausalLM, exist_ok=True)
+try:
+    AutoConfig.register(StickBreakingAttentionConfig.model_type, StickBreakingAttentionConfig, exist_ok=True)
+    AutoModel.register(StickBreakingAttentionConfig, StickBreakingAttentionModel, exist_ok=True)
+    AutoModelForCausalLM.register(StickBreakingAttentionConfig, StickBreakingAttentionForCausalLM, exist_ok=True)
+except TypeError:
+    # Fallback for older transformers without `exist_ok`
+    AutoConfig.register(StickBreakingAttentionConfig.model_type, StickBreakingAttentionConfig)
+    AutoModel.register(StickBreakingAttentionConfig, StickBreakingAttentionModel)
+    AutoModelForCausalLM.register(StickBreakingAttentionConfig, StickBreakingAttentionForCausalLM)
fla/models/stickbreaking_attn/modeling_stickbreaking_attn.py (7)

7-11: Import Unpack from typing (and add ClassVar) instead of transformers.processing_utils.

Unpack is a typing construct; importing it from transformers.processing_utils is incorrect and can break type checking. Also add ClassVar here for class attribute annotations below.

-from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, ClassVar, List, Optional, Tuple, Union, Unpack
-
-if TYPE_CHECKING:
-    from transformers.processing_utils import Unpack

Also applies to: 24-26


108-113: Annotate mutable class attributes with ClassVar to satisfy linters and intent.

This addresses RUF012 and clarifies that these are class-level constants.

 class StickBreakingAttentionPreTrainedModel(PreTrainedModel):
 
     config_class = StickBreakingAttentionConfig
-    base_model_prefix = 'model'
-    supports_gradient_checkpointing = True
-    _no_split_modules = ['StickBreakingAttentionBlock']
-    _supports_cache_class = True
+    base_model_prefix: ClassVar[str] = 'model'
+    supports_gradient_checkpointing: ClassVar[bool] = True
+    _no_split_modules: ClassVar[List[str]] = ['StickBreakingAttentionBlock']
+    _supports_cache_class: ClassVar[bool] = True
@@
 class StickBreakingAttentionForCausalLM(StickBreakingAttentionPreTrainedModel, FLAGenerationMixin):
 
-    _tied_weights_keys = ["lm_head.weight"]
+    _tied_weights_keys: ClassVar[List[str]] = ["lm_head.weight"]

Also applies to: 249-251


44-45: Respect config.elementwise_affine in RMSNorm instantiation.

The config exposes elementwise_affine but it’s not passed, changing intended behavior for users toggling that flag.

-        self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
+        self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(
+            config.hidden_size, elementwise_affine=config.elementwise_affine, eps=config.norm_eps
+        )
@@
-        self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
+        self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(
+            config.hidden_size, elementwise_affine=config.elementwise_affine, eps=config.norm_eps
+        )

Also applies to: 57-58


184-189: Add stacklevel to warnings.warn for accurate caller context.

Improves debuggability and addresses B028.

-            warnings.warn(
-                "`sba` does not support output attention weights now, so `output_attentions` is set to `False`."
-            )
+            warnings.warn(
+                "`sba` does not support output attention weights now, so `output_attentions` is set to `False`.",
+                stacklevel=2,
+            )

182-182: Correct return type annotation for base model forward.

The base model returns BaseModelOutputWithPast, not CausalLMOutputWithPast.

-    ) -> Union[Tuple, CausalLMOutputWithPast]:
+    ) -> Union[Tuple, BaseModelOutputWithPast]:

335-336: Use tuple unpacking instead of tuple concatenation.

Concise and matches Ruff suggestion (RUF005).

-        if not return_dict:
-            output = (logits,) + outputs[1:]
-            return (loss,) + output if loss is not None else output
+        if not return_dict:
+            output = (logits, *outputs[1:])
+            return (loss, *output) if loss is not None else output

66-75: Tighten block forward return type annotation.

The block can return 1–3 elements depending on flags. Reflect this for clarity.

-    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+    ) -> Tuple[
+        torch.FloatTensor,
+        Optional[torch.FloatTensor],
+        Optional[Tuple[torch.FloatTensor]],
+    ]:
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 195b74d and db7411a.

📒 Files selected for processing (14)
  • fla/__init__.py (3 hunks)
  • fla/layers/__init__.py (2 hunks)
  • fla/layers/stickbreaking_attn.py (1 hunks)
  • fla/models/__init__.py (2 hunks)
  • fla/models/stickbreaking_attn/__init__.py (1 hunks)
  • fla/models/stickbreaking_attn/configuration_stickbreaking_attn.py (1 hunks)
  • fla/models/stickbreaking_attn/modeling_stickbreaking_attn.py (1 hunks)
  • fla/ops/__init__.py (2 hunks)
  • fla/ops/stickbreaking_attn/__init__.py (1 hunks)
  • fla/ops/stickbreaking_attn/naive.py (1 hunks)
  • fla/ops/stickbreaking_attn/parallel.py (1 hunks)
  • tests/models/test_modeling_stickbreaking_attn.py (1 hunks)
  • tests/models/test_modeling_utils.py (1 hunks)
  • tests/ops/test_stickbreaking_attn.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (11)
fla/ops/__init__.py (2)
fla/ops/stickbreaking_attn/naive.py (1)
  • sb_attn_naive (15-57)
fla/ops/stickbreaking_attn/parallel.py (1)
  • sb_attn (705-713)
fla/models/stickbreaking_attn/__init__.py (2)
fla/models/stickbreaking_attn/configuration_stickbreaking_attn.py (1)
  • StickBreakingAttentionConfig (9-85)
fla/models/stickbreaking_attn/modeling_stickbreaking_attn.py (2)
  • StickBreakingAttentionForCausalLM (247-344)
  • StickBreakingAttentionModel (144-244)
fla/layers/stickbreaking_attn.py (3)
fla/modules/layernorm.py (1)
  • RMSNorm (1064-1111)
fla/ops/stickbreaking_attn/parallel.py (2)
  • sb_attn (705-713)
  • forward (682-696)
fla/models/stickbreaking_attn/modeling_stickbreaking_attn.py (3)
  • forward (66-103)
  • forward (171-244)
  • forward (279-344)
tests/ops/test_stickbreaking_attn.py (3)
fla/ops/stickbreaking_attn/parallel.py (2)
  • sb_attn (705-713)
  • backward (699-702)
fla/ops/stickbreaking_attn/naive.py (1)
  • sb_attn_naive (15-57)
fla/utils.py (1)
  • assert_close (82-93)
tests/models/test_modeling_stickbreaking_attn.py (2)
fla/models/stickbreaking_attn/configuration_stickbreaking_attn.py (1)
  • StickBreakingAttentionConfig (9-85)
tests/models/test_modeling_base.py (2)
  • run_test_generation (67-126)
  • run_test_model_forward_backward (27-61)
fla/ops/stickbreaking_attn/__init__.py (2)
fla/ops/stickbreaking_attn/naive.py (1)
  • sb_attn_naive (15-57)
fla/ops/stickbreaking_attn/parallel.py (1)
  • sb_attn (705-713)
fla/__init__.py (2)
fla/layers/stickbreaking_attn.py (1)
  • StickBreakingAttention (25-112)
fla/models/stickbreaking_attn/modeling_stickbreaking_attn.py (2)
  • StickBreakingAttentionForCausalLM (247-344)
  • StickBreakingAttentionModel (144-244)
fla/layers/__init__.py (1)
fla/layers/stickbreaking_attn.py (1)
  • StickBreakingAttention (25-112)
fla/models/stickbreaking_attn/modeling_stickbreaking_attn.py (8)
fla/layers/stickbreaking_attn.py (2)
  • StickBreakingAttention (25-112)
  • forward (72-112)
fla/models/stickbreaking_attn/configuration_stickbreaking_attn.py (1)
  • StickBreakingAttentionConfig (9-85)
fla/models/utils.py (1)
  • FLAGenerationMixin (385-462)
fla/modules/fused_cross_entropy.py (1)
  • FusedCrossEntropyLoss (344-419)
fla/modules/fused_linear_cross_entropy.py (1)
  • FusedLinearCrossEntropyLoss (493-567)
fla/modules/mlp.py (1)
  • GatedMLP (26-69)
fla/modules/layernorm.py (1)
  • RMSNorm (1064-1111)
fla/models/modeling_layers.py (1)
  • GradientCheckpointingLayer (11-71)
fla/ops/stickbreaking_attn/parallel.py (2)
fla/ops/utils/index.py (1)
  • prepare_chunk_indices (116-121)
fla/layers/stickbreaking_attn.py (1)
  • forward (72-112)
fla/models/__init__.py (2)
fla/models/stickbreaking_attn/configuration_stickbreaking_attn.py (1)
  • StickBreakingAttentionConfig (9-85)
fla/models/stickbreaking_attn/modeling_stickbreaking_attn.py (2)
  • StickBreakingAttentionForCausalLM (247-344)
  • StickBreakingAttentionModel (144-244)
🪛 Ruff (0.13.1)
fla/layers/stickbreaking_attn.py

42-44: Avoid specifying long messages outside the exception class

(TRY003)


77-77: Unused method argument: output_attentions

(ARG002)


90-90: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)

fla/ops/stickbreaking_attn/naive.py

34-34: Unpacked variable B is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


34-34: Unpacked variable H is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


34-34: Unpacked variable D is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

fla/models/stickbreaking_attn/configuration_stickbreaking_attn.py

12-12: Mutable class attributes should be annotated with typing.ClassVar

(RUF012)


69-71: Avoid specifying long messages outside the exception class

(TRY003)


73-73: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)

fla/models/stickbreaking_attn/modeling_stickbreaking_attn.py

111-111: Mutable class attributes should be annotated with typing.ClassVar

(RUF012)


184-184: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


194-194: Avoid specifying long messages outside the exception class

(TRY003)


196-196: Avoid specifying long messages outside the exception class

(TRY003)


249-249: Mutable class attributes should be annotated with typing.ClassVar

(RUF012)


335-335: Consider (logits, *outputs[1:]) instead of concatenation

Replace with (logits, *outputs[1:])

(RUF005)


336-336: Consider (loss, *output) instead of concatenation

Replace with (loss, *output)

(RUF005)

fla/ops/stickbreaking_attn/parallel.py

34-34: Unused function argument: batch_size

(ARG001)


136-136: Unused function argument: NO_N_MASK

(ARG001)


140-140: Unused function argument: no_grad

(ARG001)


142-142: Unused function argument: return_attention

(ARG001)


399-399: Unused function argument: is_compiling

(ARG001)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: Test H100 (PyTorch 2.7) / test-models
  • GitHub Check: check-pt-python-compatibility
🔇 Additional comments (15)
fla/models/stickbreaking_attn/configuration_stickbreaking_attn.py (1)

68-71: Good defensive check for mutually exclusive flags.

The exclusivity guard between fuse_cross_entropy and fuse_linear_cross_entropy is correct.

fla/ops/__init__.py (1)

55-55: Exports look consistent.

'sb_attn', 'sb_attn_naive' align with new modules.

tests/models/test_modeling_utils.py (1)

27-27: Correctly excluded from generation.

Adding StickBreakingAttentionConfig to GENERATION_UNSUPPORTED matches the current lack of KV cache support.

fla/layers/__init__.py (2)

33-33: Expose StickBreakingAttention in layers API.

Import is correct and aligns with public surface.


__all__ updated appropriately.

Exporting StickBreakingAttention is consistent.

fla/ops/stickbreaking_attn/__init__.py (1)

1-9: Clean re-exports for ops.

Simple and correct; matches package structure.

fla/models/__init__.py (2)

34-38: Models API exposure looks good.

Imports for StickBreakingAttention family match modeling package.


__all__ includes new StickBreakingAttention symbols.

Consistent with imports.

fla/__init__.py (3)

29-31: Layers import list updated correctly.

Adding StickBreakingAttention to the high-level API is consistent with other attention types.


79-81: Models import list updated correctly.

Matches new modeling classes.


__all__ now exports the StickBreakingAttention family.

Looks good.

tests/models/test_modeling_stickbreaking_attn.py (2)

14-24: Parametrization covers key cases.

BF16 variants and D=128 (skipped on non-Hopper) provide reasonable coverage.


40-57: Generation test will be skipped by design.

Since StickBreakingAttentionConfig is in GENERATION_UNSUPPORTED, this test will skip. Confirm that’s intended until KV cache is implemented.

fla/models/stickbreaking_attn/__init__.py (1)

16-16: LGTM: public exports are consistent with registrations.

fla/ops/stickbreaking_attn/parallel.py (1)

49-66: Varlen path: confirm correctness with multi-sample cu_seqlens.

Batch is forced to 0 in IS_VARLEN. Verify that prepare_chunk_indices and CU_ptr indexing assume [B,T,...] layout and not flattened [N,...]. Add a dedicated varlen test.

I can add a small test that builds cu_seqlens for a batch with unequal lengths and validates parity against a padded dense run. Do you want me to push that?

Comment on lines +52 to +55
self.num_kv_groups = self.num_heads // self.num_kv_heads
self.head_dim = self.hidden_size // self.num_heads
self.kv_dim = self.num_kv_heads * self.head_dim
self.qkv_bias = qkv_bias


⚠️ Potential issue

Validate head divisibility and GQA/MQA support.

Avoid silent mis-sharding. Either implement KV grouping or gate it for now.

Apply this diff:

-        self.num_kv_groups = self.num_heads // self.num_kv_heads
-        self.head_dim = self.hidden_size // self.num_heads
+        if self.hidden_size % self.num_heads != 0:
+            raise ValueError(f"hidden_size ({self.hidden_size}) must be divisible by num_heads ({self.num_heads}).")
+        if self.num_heads % self.num_kv_heads != 0:
+            raise ValueError(f"num_heads ({self.num_heads}) must be divisible by num_kv_heads ({self.num_kv_heads}).")
+        self.num_kv_groups = self.num_heads // self.num_kv_heads
+        self.head_dim = self.hidden_size // self.num_heads

If GQA/MQA is not yet supported by the kernel, also add:

+        if self.num_kv_heads != self.num_heads:
+            raise NotImplementedError("GQA/MQA (num_kv_heads != num_heads) is not supported yet for StickBreakingAttention.")
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-        self.num_kv_groups = self.num_heads // self.num_kv_heads
-        self.head_dim = self.hidden_size // self.num_heads
-        self.kv_dim = self.num_kv_heads * self.head_dim
-        self.qkv_bias = qkv_bias
+        if self.hidden_size % self.num_heads != 0:
+            raise ValueError(f"hidden_size ({self.hidden_size}) must be divisible by num_heads ({self.num_heads}).")
+        if self.num_heads % self.num_kv_heads != 0:
+            raise ValueError(f"num_heads ({self.num_heads}) must be divisible by num_kv_heads ({self.num_kv_heads}).")
+        self.num_kv_groups = self.num_heads // self.num_kv_heads
+        self.head_dim = self.hidden_size // self.num_heads
+        if self.num_kv_heads != self.num_heads:
+            raise NotImplementedError("GQA/MQA (num_kv_heads != num_heads) is not supported yet for StickBreakingAttention.")
+        self.kv_dim = self.num_kv_heads * self.head_dim
+        self.qkv_bias = qkv_bias

Comment on lines +313 to +333
        logits = None if self.config.fuse_linear_cross_entropy else self.lm_head(hidden_states[:, -logits_to_keep:])

        loss = None
        if labels is not None:
            if getattr(self, 'criterion', None) is None:
                if self.config.fuse_linear_cross_entropy:
                    criterion = FusedLinearCrossEntropyLoss(use_l2warp=self.config.use_l2warp)
                elif self.config.fuse_cross_entropy:
                    criterion = FusedCrossEntropyLoss(inplace_backward=True)
                else:
                    criterion = nn.CrossEntropyLoss()
            else:
                criterion = self.criterion
            labels = labels.to(hidden_states.device)
            labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
            if self.config.fuse_linear_cross_entropy:
                loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
            else:
                loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))
                loss = l2_warp(loss, logits) if self.config.use_l2warp else loss



⚠️ Potential issue

Fix loss/labels alignment when logits_to_keep > 0 (shape and indexing bug).

When labels is provided and logits_to_keep > 0, logits is sliced to the last K tokens but labels remains length T after the shift. Reshaping logits to labels.numel() also silently misaligns data. This will produce incorrect training signals or runtime errors under common usages where only the last K logits are kept.

Align the labels with the kept logits and reshape based on logits’ actual size.

         loss = None
         if labels is not None:
             if getattr(self, 'criterion', None) is None:
                 if self.config.fuse_linear_cross_entropy:
                     criterion = FusedLinearCrossEntropyLoss(use_l2warp=self.config.use_l2warp)
                 elif self.config.fuse_cross_entropy:
                     criterion = FusedCrossEntropyLoss(inplace_backward=True)
                 else:
                     criterion = nn.CrossEntropyLoss()
             else:
                 criterion = self.criterion
-            labels = labels.to(hidden_states.device)
-            labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
+            labels = labels.to(hidden_states.device)
+            # Next-token prediction: shift labels left and pad last position with ignore_index
+            shifted_labels = torch.cat(
+                (labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)),
+                dim=-1,
+            )
             if self.config.fuse_linear_cross_entropy:
-                loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
+                # Fused linear CE expects [B, T, H] and [B, T]
+                loss = criterion(hidden_states, shifted_labels, self.lm_head.weight, self.lm_head.bias)
             else:
-                loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))
+                # Keep labels only where logits are computed
+                if logits_to_keep:
+                    shifted_labels = shifted_labels[:, -logits_to_keep:]
+                # logits: [B, K, V], shifted_labels: [B, K]
+                loss = criterion(logits.reshape(-1, logits.size(-1)), shifted_labels.reshape(-1))
                 loss = l2_warp(loss, logits) if self.config.use_l2warp else loss
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
logits = None if self.config.fuse_linear_cross_entropy else self.lm_head(hidden_states[:, -logits_to_keep:])
loss = None
if labels is not None:
if getattr(self, 'criterion', None) is None:
if self.config.fuse_linear_cross_entropy:
criterion = FusedLinearCrossEntropyLoss(use_l2warp=self.config.use_l2warp)
elif self.config.fuse_cross_entropy:
criterion = FusedCrossEntropyLoss(inplace_backward=True)
else:
criterion = nn.CrossEntropyLoss()
else:
criterion = self.criterion
labels = labels.to(hidden_states.device)
labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
if self.config.fuse_linear_cross_entropy:
loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
else:
loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))
loss = l2_warp(loss, logits) if self.config.use_l2warp else loss
logits = None if self.config.fuse_linear_cross_entropy else self.lm_head(hidden_states[:, -logits_to_keep:])
loss = None
if labels is not None:
if getattr(self, 'criterion', None) is None:
if self.config.fuse_linear_cross_entropy:
criterion = FusedLinearCrossEntropyLoss(use_l2warp=self.config.use_l2warp)
elif self.config.fuse_cross_entropy:
criterion = FusedCrossEntropyLoss(inplace_backward=True)
else:
criterion = nn.CrossEntropyLoss()
else:
criterion = self.criterion
labels = labels.to(hidden_states.device)
# Next-token prediction: shift labels left and pad last position with ignore_index
shifted_labels = torch.cat(
(labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)),
dim=-1,
)
if self.config.fuse_linear_cross_entropy:
# Fused linear CE expects [B, T, H] and [B, T]
loss = criterion(hidden_states, shifted_labels, self.lm_head.weight, self.lm_head.bias)
else:
# Keep labels only where logits are computed
if logits_to_keep:
shifted_labels = shifted_labels[:, -logits_to_keep:]
# logits: [B, K, V], shifted_labels: [B, K]
loss = criterion(logits.reshape(-1, logits.size(-1)), shifted_labels.reshape(-1))
loss = l2_warp(loss, logits) if self.config.use_l2warp else loss
🤖 Prompt for AI Agents
In fla/models/stickbreaking_attn/modeling_stickbreaking_attn.py around lines 313
to 333, the labels are not trimmed to match logits when logits_to_keep > 0
causing misalignment and wrong loss shapes; after shifting labels (dropping
first token and appending ignore_index) slice the shifted labels to keep only
the last K positions matching logits (e.g. labels = labels[...,
-logits.size(1):]) and use logits.view(-1, logits.size(-1)) or
logits.numel()/logits.size(-1) to reshape consistently; ensure labels are on the
same device before slicing and apply the same l2_warp logic only after correct
alignment so loss uses logits.shape and labels.numel() that correspond to each
other.

@yzhangcs yzhangcs linked an issue Sep 23, 2025 that may be closed by this pull request

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

♻️ Duplicate comments (3)
fla/models/stickbreaking_attn/modeling_stickbreaking_attn.py (3)

47-56: Remove unused parameters from StickBreakingAttention instantiation.

The parameters window_size, max_position_embeddings, and layer_idx are passed to StickBreakingAttention but are not used by that layer (they are stored but never referenced in the forward pass). This creates unnecessary API surface and potential confusion.

Based on past review comments, these parameters should be removed from the StickBreakingAttention layer itself, and consequently from this instantiation as well.


150-150: Fix incorrect return type annotation on __init__.

The __init__ method should have a return type of None, not -> StickBreakingAttentionModel.

Apply this diff:

     def __init__(
         self,
         config: StickBreakingAttentionConfig
-    ) -> StickBreakingAttentionModel:
+    ):

314-333: CRITICAL: Fix loss/labels shape mismatch when logits_to_keep > 0.

When logits_to_keep > 0, logits are sliced to [B, K, V] (Line 314) but labels remain at full sequence length [B, T] after shifting (Line 328). Line 332 attempts logits.view(labels.numel(), -1) which creates a shape mismatch since B*K ≠ B*T, causing either runtime errors or silently incorrect loss computation.

The shifted labels must be sliced to match the logits when logits_to_keep > 0:

         loss = None
         if labels is not None:
             if getattr(self, 'criterion', None) is None:
                 if self.config.fuse_linear_cross_entropy:
                     criterion = FusedLinearCrossEntropyLoss(use_l2warp=self.config.use_l2warp)
                 elif self.config.fuse_cross_entropy:
                     criterion = FusedCrossEntropyLoss(inplace_backward=True)
                 else:
                     criterion = nn.CrossEntropyLoss()
             else:
                 criterion = self.criterion
             labels = labels.to(hidden_states.device)
-            labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
+            # Shift labels for next-token prediction
+            shifted_labels = torch.cat(
+                (labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)),
+                dim=1
+            )
             if self.config.fuse_linear_cross_entropy:
-                loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
+                loss = criterion(hidden_states, shifted_labels, self.lm_head.weight, self.lm_head.bias)
             else:
-                loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))
+                # Slice labels to match logits when only keeping last K positions
+                if logits_to_keep > 0:
+                    shifted_labels = shifted_labels[:, -logits_to_keep:]
+                loss = criterion(logits.reshape(-1, logits.size(-1)), shifted_labels.reshape(-1))
                 loss = l2_warp(loss, logits) if self.config.use_l2warp else loss
🧹 Nitpick comments (4)
fla/layers/stickbreaking_attn.py (2)

75-75: Remove unused output_attentions parameter.

The output_attentions parameter is accepted but never used in the forward method. Since the implementation always returns None for attentions (Line 108), this parameter should be removed from the signature to avoid confusion.

Apply this diff:

     def forward(
         self,
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Cache] = None,
-        output_attentions: bool = False,
         use_cache: bool = False,
         attend_current: bool = False,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:

87-90: Add stacklevel=2 to warnings.warn call.

The warning should include stacklevel=2 so the warning message points to the caller's location rather than this module.

Apply this diff:

         if use_cache:
             warnings.warn(
-                "StickBreakingAttention does not support KV cache yet; falling back to use_cache=False.")
+                "StickBreakingAttention does not support KV cache yet; falling back to use_cache=False.",
+                stacklevel=2)
             use_cache = False
fla/models/stickbreaking_attn/modeling_stickbreaking_attn.py (2)

184-188: Add stacklevel=2 to warnings.warn call.

The warning should include stacklevel=2 so the warning message points to the caller's location rather than this module.

Apply this diff:

         if output_attentions:
             warnings.warn(
-                "`sba` does not support output attention weights now, so `output_attentions` is set to `False`."
+                "`sba` does not support output attention weights now, so `output_attentions` is set to `False`.",
+                stacklevel=2
             )
             output_attentions = False

112-112: Annotate mutable class attributes with ClassVar.

The _no_split_modules (Line 112) and _tied_weights_keys (Line 250) are class-level configuration attributes that should be annotated with typing.ClassVar to clarify they are not instance attributes.

Add the import and annotations:

+from typing import ClassVar
+
 class StickBreakingAttentionPreTrainedModel(PreTrainedModel):
 
     config_class = StickBreakingAttentionConfig
     base_model_prefix = 'model'
     supports_gradient_checkpointing = True
-    _no_split_modules = ['StickBreakingAttentionBlock']
+    _no_split_modules: ClassVar[list[str]] = ['StickBreakingAttentionBlock']
     _supports_cache_class = True
 class StickBreakingAttentionForCausalLM(StickBreakingAttentionPreTrainedModel, FLAGenerationMixin):
 
-    _tied_weights_keys = ["lm_head.weight"]
+    _tied_weights_keys: ClassVar[list[str]] = ["lm_head.weight"]

Also applies to: 250-250

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 56ffe33 and cb99d40.

📒 Files selected for processing (3)
  • fla/__init__.py (3 hunks)
  • fla/layers/stickbreaking_attn.py (1 hunks)
  • fla/models/stickbreaking_attn/modeling_stickbreaking_attn.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
fla/__init__.py (2)
fla/layers/stickbreaking_attn.py (1)
  • StickBreakingAttention (25-108)
fla/models/stickbreaking_attn/modeling_stickbreaking_attn.py (2)
  • StickBreakingAttentionForCausalLM (248-345)
  • StickBreakingAttentionModel (145-245)
fla/layers/stickbreaking_attn.py (3)
fla/modules/layernorm.py (1)
  • RMSNorm (1063-1110)
fla/ops/stickbreaking_attn/parallel.py (2)
  • sb_attn (705-713)
  • forward (682-696)
fla/models/stickbreaking_attn/modeling_stickbreaking_attn.py (3)
  • forward (67-104)
  • forward (172-245)
  • forward (280-345)
fla/models/stickbreaking_attn/modeling_stickbreaking_attn.py (8)
fla/layers/stickbreaking_attn.py (2)
  • StickBreakingAttention (25-108)
  • forward (70-108)
fla/models/stickbreaking_attn/configuration_stickbreaking_attn.py (1)
  • StickBreakingAttentionConfig (9-85)
fla/models/utils.py (1)
  • FLAGenerationMixin (384-461)
fla/modules/fused_cross_entropy.py (1)
  • FusedCrossEntropyLoss (343-418)
fla/modules/fused_linear_cross_entropy.py (1)
  • FusedLinearCrossEntropyLoss (491-565)
fla/modules/mlp.py (1)
  • GatedMLP (25-68)
fla/modules/layernorm.py (1)
  • RMSNorm (1063-1110)
fla/models/modeling_layers.py (1)
  • GradientCheckpointingLayer (10-70)
🪛 Ruff (0.14.3)
fla/layers/stickbreaking_attn.py

41-43: Avoid specifying long messages outside the exception class

(TRY003)


75-75: Unused method argument: output_attentions

(ARG002)


88-88: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)

fla/models/stickbreaking_attn/modeling_stickbreaking_attn.py

112-112: Mutable class attributes should be annotated with typing.ClassVar

(RUF012)


185-185: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


195-195: Avoid specifying long messages outside the exception class

(TRY003)


197-197: Avoid specifying long messages outside the exception class

(TRY003)


250-250: Mutable class attributes should be annotated with typing.ClassVar

(RUF012)


336-336: Consider (logits, *outputs[1:]) instead of concatenation

Replace with (logits, *outputs[1:])

(RUF005)


337-337: Consider (loss, *output) instead of concatenation

Replace with (loss, *output)

(RUF005)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Test H100 (PyTorch 2.7) / test-models
🔇 Additional comments (1)
fla/__init__.py (1)

29-29: LGTM! Clean API surface additions.

The new StickBreakingAttention exports follow the established pattern and are correctly integrated into both the imports and __all__ list.

Also applies to: 78-79, 111-111


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

♻️ Duplicate comments (2)
fla/layers/stickbreaking_attn.py (1)

44-53: Validate head divisibility and unsupported kv-head ratios

StickBreakingAttention currently accepts any hidden_size, num_heads, and num_kv_heads, but the implementation relies on exact divisibility and on the kernels supporting only the num_kv_heads == num_heads case. If the user supplies a configuration where hidden_size % num_heads != 0, einops.rearrange will crash with a cryptic runtime error. Likewise, when num_heads % num_kv_heads != 0 or simply num_kv_heads != num_heads, the naive and fused kernels receive tensors with incompatible head dimensions and immediately fail. We should fail fast with a clear error (or explicitly gate unsupported GQA/MQA) instead of letting these misconfigurations escape.

Suggestion:

         if num_kv_heads is None:
             self.num_kv_heads = self.num_heads
         else:
             self.num_kv_heads = num_kv_heads
-        self.num_kv_groups = self.num_heads // self.num_kv_heads
-        self.head_dim = self.hidden_size // self.num_heads
-        self.kv_dim = self.num_kv_heads * self.head_dim
+
+        if self.hidden_size % self.num_heads != 0:
+            raise ValueError(
+                f'hidden_size ({self.hidden_size}) must be divisible by num_heads ({self.num_heads}).'
+            )
+        if self.num_heads % self.num_kv_heads != 0:
+            raise ValueError(
+                f'num_heads ({self.num_heads}) must be divisible by num_kv_heads ({self.num_kv_heads}).'
+            )
+        if self.num_kv_heads != self.num_heads:
+            raise NotImplementedError(
+                'GQA/MQA (num_kv_heads != num_heads) is not yet supported for StickBreakingAttention.'
+            )
+
+        self.num_kv_groups = self.num_heads // self.num_kv_heads
+        self.head_dim = self.hidden_size // self.num_heads
+        self.kv_dim = self.num_kv_heads * self.head_dim
fla/models/stickbreaking_attn/modeling_stickbreaking_attn.py (1)

309-327: Fix logits_to_keep slicing and label alignment

When logits_to_keep > 0, we slice the hidden states before the head but keep the labels at full length. The subsequent logits.view(labels.numel(), -1) either crashes (size mismatch) or, in the fused path, we still compute loss on the full sequence instead of the last K positions. This breaks training for anyone using partial logits.

We need to slice both the hidden states and labels to the same window before calling the criterion:

-        logits = None if self.config.fuse_linear_cross_entropy else self.lm_head(hidden_states[:, -logits_to_keep:])
+        if logits_to_keep:
+            hidden_states_to_score = hidden_states[:, -logits_to_keep:, :]
+        else:
+            hidden_states_to_score = hidden_states
+
+        logits = None if self.config.fuse_linear_cross_entropy else self.lm_head(hidden_states_to_score)
@@
-            labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
-            if self.config.fuse_linear_cross_entropy:
-                loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
+            labels = torch.cat(
+                (labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)),
+                dim=-1,
+            )
+            if logits_to_keep:
+                labels = labels[:, -logits_to_keep:]
+            if self.config.fuse_linear_cross_entropy:
+                loss = criterion(hidden_states_to_score, labels, self.lm_head.weight, self.lm_head.bias)
             else:
-                loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))
+                loss = criterion(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between cb99d40 and 9b2554c.

📒 Files selected for processing (10)
  • fla/layers/stickbreaking_attn.py (1 hunks)
  • fla/models/__init__.py (2 hunks)
  • fla/models/stickbreaking_attn/__init__.py (1 hunks)
  • fla/models/stickbreaking_attn/configuration_stickbreaking_attn.py (1 hunks)
  • fla/models/stickbreaking_attn/modeling_stickbreaking_attn.py (1 hunks)
  • fla/ops/stickbreaking_attn/__init__.py (1 hunks)
  • fla/ops/stickbreaking_attn/naive.py (1 hunks)
  • fla/ops/stickbreaking_attn/parallel.py (1 hunks)
  • tests/models/test_modeling_stickbreaking_attn.py (1 hunks)
  • tests/ops/test_stickbreaking_attn.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (5)
  • fla/ops/stickbreaking_attn/parallel.py
  • fla/models/stickbreaking_attn/configuration_stickbreaking_attn.py
  • fla/models/__init__.py
  • tests/models/test_modeling_stickbreaking_attn.py
  • fla/ops/stickbreaking_attn/__init__.py

Comment on lines +79 to +88
if attention_mask is not None:
assert len(attention_mask.shape) == 2, (
"Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
"for padding purposes (0 indicating padding). "
"Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
)

if use_cache:
warnings.warn(
"StickBreakingAttention does not support KV cache yet; falling back to use_cache=False.")

⚠️ Potential issue | 🟠 Major

Do not silently ignore attention_mask

We assert that an attention_mask is 2-D, but we never use it afterward. When padding is present in the batch, the kernel still allocates stick mass to those padded positions, so the outputs (and gradients) depend on garbage tokens. That violates the huggingface contract for attention layers and breaks training/inference with padded batches.

Until the kernel can consume the mask (e.g. by translating it to cu_seqlens or masking logits), we should not silently continue when zeros are present. Please either wire up proper masking or explicitly fail to avoid corrupt results. For example:

         if attention_mask is not None:
             assert len(attention_mask.shape) == 2, (
                 "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
                 "for padding purposes (0 indicating padding). "
                 "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
             )
+            valid = attention_mask.to(dtype=torch.bool)
+            if not torch.all(valid):
+                raise NotImplementedError(
+                    "Padding masks are not yet supported. Please supply packed sequences via `cu_seqlens` "
+                    "or implement masking before calling StickBreakingAttention."
+                )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if attention_mask is not None:
assert len(attention_mask.shape) == 2, (
"Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
"for padding purposes (0 indicating padding). "
"Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
)
if use_cache:
warnings.warn(
"StickBreakingAttention does not support KV cache yet; falling back to use_cache=False.")
if attention_mask is not None:
assert len(attention_mask.shape) == 2, (
"Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
"for padding purposes (0 indicating padding). "
"Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
)
valid = attention_mask.to(dtype=torch.bool)
if not torch.all(valid):
raise NotImplementedError(
"Padding masks are not yet supported. Please supply packed sequences via `cu_seqlens` "
"or implement masking before calling StickBreakingAttention."
)
if use_cache:
warnings.warn(
"StickBreakingAttention does not support KV cache yet; falling back to use_cache=False.")
🤖 Prompt for AI Agents
In fla/layers/stickbreaking_attn.py around lines 79-88, the code asserts
attention_mask is 2-D but never uses it so padded positions still receive stick
mass; either wire the mask into the kernel (e.g., convert [batch, seq_len] mask
into cu_seqlens or apply it to attention logits before the stick-breaking
kernel) or explicitly fail when any padding is present. Implement one of two
fixes: (1) translate the 0/1 attention_mask to the kernel-friendly format and
ensure masked positions are excluded from stick mass allocation, or (2) check
for any zeros in attention_mask and raise a clear ValueError (or
NotImplementedError) stating that StickBreakingAttention does not support padded
masks yet. Ensure the failure case includes guidance to use a mask-free batch or
a compatible attention implementation.

Comment on lines +186 to +205
use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is None and inputs_embeds is None:
raise ValueError("You have to specify either input_ids or inputs_embeds")

if use_cache and not isinstance(past_key_values, Cache):
past_key_values = Cache.from_legacy_cache(past_key_values)

if inputs_embeds is None:
inputs_embeds = self.embeddings(input_ids)

hidden_states = inputs_embeds

all_hidden_states = () if output_hidden_states else None
all_attns = () if output_attentions else None
next_cache = None


⚠️ Potential issue | 🟠 Major

Disable the KV cache until it is implemented

Higher-level generation will call this model with use_cache=True and only feed the newly generated token plus past_key_values. However, StickBreakingAttention immediately flips use_cache to False and just returns the incoming cache object unchanged. On the next decoding step we therefore attend over an empty history, producing incorrect logits. Either implement cache population or explicitly force use_cache=False here and drop the cache from the outputs so downstream code falls back to full-prefix recomputation.

-        use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
+        use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
+        if use_cache:
+            warnings.warn(
+                "StickBreakingAttention does not support KV cache yet; forcing use_cache=False.",
+                stacklevel=2,
+            )
+            use_cache = False
+            past_key_values = None
🤖 Prompt for AI Agents
In fla/models/stickbreaking_attn/modeling_stickbreaking_attn.py around lines 186
to 205, the KV cache handling is incorrect for generation: the model accepts
use_cache=True but immediately disables it in the attention layer and returns
the incoming cache unchanged, which causes future decoding steps to see an empty
history. Disable the KV cache end-to-end until it is implemented by forcing
use_cache = False here (ignore any incoming past_key_values), ensure
past_key_values is cleared/ignored (set to None or an empty Cache) before any
attention ops, and do not return a next_cache in the outputs (leave next_cache
as None) so downstream generation falls back to full-prefix recomputation.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (2)
fla/layers/stickbreaking_attn.py (2)

49-53: Validate head divisibility and gate unsupported GQA

We’re still doing integer floor division without guarding hidden_size % num_heads or num_heads % num_kv_heads. That yields silent shape corruption (and num_kv_groups == 0 if heads < kv_heads) before the first forward pass. We need explicit ValueErrors for those cases, and if the kernels can’t handle grouped-KV yet, raise NotImplementedError when num_kv_heads != num_heads.


84-105: Don’t ignore padding masks

We only assert the mask is 2-D, then discard it. Any padded batch (zeros present) will feed garbage tokens into the kernel, so outputs and gradients depend on padding noise. Either translate the 0/1 mask into cu_seqlens (or an explicit mask inside the kernel) or raise NotImplementedError until padding is supported. Silent failure here breaks HF interoperability.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 9b2554c and b15eaf6.

📒 Files selected for processing (6)
  • fla/layers/stickbreaking_attn.py (1 hunks)
  • fla/ops/__init__.py (3 hunks)
  • fla/ops/stickbreaking_attn/__init__.py (1 hunks)
  • fla/ops/stickbreaking_attn/naive.py (1 hunks)
  • fla/ops/stickbreaking_attn/parallel.py (1 hunks)
  • tests/ops/test_stickbreaking_attn.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
  • tests/ops/test_stickbreaking_attn.py
  • fla/ops/__init__.py
🧰 Additional context used
🧬 Code graph analysis (4)
fla/ops/stickbreaking_attn/parallel.py (2)
fla/ops/utils/index.py (1)
  • prepare_chunk_indices (114-119)
fla/layers/stickbreaking_attn.py (1)
  • forward (68-110)
fla/layers/stickbreaking_attn.py (3)
fla/modules/layernorm.py (1)
  • RMSNorm (1063-1110)
fla/ops/stickbreaking_attn/parallel.py (2)
  • parallel_stickbreaking_attn (703-713)
  • forward (680-694)
fla/models/stickbreaking_attn/modeling_stickbreaking_attn.py (3)
  • forward (62-99)
  • forward (167-240)
  • forward (275-340)
fla/ops/stickbreaking_attn/__init__.py (2)
fla/ops/stickbreaking_attn/naive.py (1)
  • naive_stickbreaking_attn (7-51)
fla/ops/stickbreaking_attn/parallel.py (1)
  • parallel_stickbreaking_attn (703-713)
fla/ops/stickbreaking_attn/naive.py (1)
fla/modules/activations.py (1)
  • logsigmoid (200-201)
🪛 Ruff (0.14.3)
fla/ops/stickbreaking_attn/parallel.py

32-32: Unused function argument: batch_size

(ARG001)


134-134: Unused function argument: NO_N_MASK

(ARG001)


138-138: Unused function argument: no_grad

(ARG001)


140-140: Unused function argument: return_attention

(ARG001)


397-397: Unused function argument: is_compiling

(ARG001)

fla/layers/stickbreaking_attn.py

39-41: Avoid specifying long messages outside the exception class

(TRY003)


73-73: Unused method argument: output_attentions

(ARG002)


86-86: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)

fla/ops/stickbreaking_attn/__init__.py

4-7: __all__ is not sorted

Apply an isort-style sorting to __all__

(RUF022)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Test H100 (PyTorch 2.7) / test-models

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 13

♻️ Duplicate comments (4)
fla/ops/stickbreaking_attn/parallel.py (2)

52-52: Remove the no-op call to tl.num_programs(2).

The return value is not used, so this call has no effect and should be removed.

Apply this diff:

     batch_id = 0 if IS_VARLEN else tl.program_id(0)
     head_pid = tl.program_id(1)
     prog_id = tl.program_id(2)
-    tl.num_programs(2)
     if IS_VARLEN:

632-633: CRITICAL: Backward buffer allocation causes massive memory usage.

The dk and dv tensors are allocated with shape (M_count, batch_size, token_size, num_heads, dim_size), where M_count = triton.cdiv(token_size, BLOCK_M). For the advertised benchmark config (batch=8, seq=4096, heads=32, dim=64, BLOCK_M=64), this allocates:

  • M_count = 64 blocks
  • Memory per tensor ≈ 64 × 8 × 4096 × 32 × 64 × 2 bytes ≈ 8.6 GB in fp16

This is roughly 64× larger than the actual KV tensors and will OOM before the backward kernel runs, blocking any training at the advertised scale.

Solution: Accumulate gradients directly into [batch_size, token_size, num_heads, dim_size] buffers using atomic adds, block-level reductions, or a two-pass strategy with a shared staging buffer per program. Memory must scale linearly with model size.

Do you want me to generate a fix using atomic operations?

fla/layers/stickbreaking_attn.py (2)

43-52: Validate head/count divisibility and gate unsupported GQA

We still allow configurations where hidden_size is not divisible by num_heads, or num_heads is not divisible by num_kv_heads. Worse, the kernel does not implement grouped KV, so mismatched head counts silently mis-shard tensors. Please add the explicit guards and reject unsupported GQA/MQA setups.

         self.hidden_size = hidden_size
         self.num_heads = num_heads
         if num_kv_heads is None:
             self.num_kv_heads = self.num_heads
         else:
             self.num_kv_heads = num_kv_heads
-        self.num_kv_groups = self.num_heads // self.num_kv_heads
-        self.head_dim = self.hidden_size // self.num_heads
+        if self.hidden_size % self.num_heads != 0:
+            raise ValueError(
+                f"hidden_size ({self.hidden_size}) must be divisible by num_heads ({self.num_heads})."
+            )
+        if self.num_heads % self.num_kv_heads != 0:
+            raise ValueError(
+                f"num_heads ({self.num_heads}) must be divisible by num_kv_heads ({self.num_kv_heads})."
+            )
+        if self.num_kv_heads != self.num_heads:
+            raise NotImplementedError(
+                "StickBreakingAttention does not support grouped KV (num_kv_heads != num_heads) yet."
+            )
+        self.num_kv_groups = self.num_heads // self.num_kv_heads
+        self.head_dim = self.hidden_size // self.num_heads

77-104: Do not ignore attention_mask

Shape-checking and then dropping the mask leaves padded tokens in play, so outputs/gradients depend on garbage data. Either wire the mask into cu_seqlens/logits or clearly reject masked batches to avoid corrupt results.

         if attention_mask is not None:
             assert len(attention_mask.shape) == 2, (
                 "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
                 "for padding purposes (0 indicating padding). "
                 "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
             )
+            if not torch.all(attention_mask.to(torch.bool)):
+                raise NotImplementedError(
+                    "Padding masks are not yet supported. Pack sequences via `cu_seqlens` before calling "
+                    "StickBreakingAttention."
+                )
🧹 Nitpick comments (6)
sba_code/stickbreaking_attention/sb_varlen/softplus.py (2)

1-1: Remove unused import.

torch is imported but never used in this file.

Apply this diff:

-import torch
 import triton
 from triton import language as tl

35-52: Consider consolidating duplicate softplus implementations.

This file implements softplus with inline assembly, while fla/ops/stickbreaking_attn/parallel.py (lines 390-392) has a simpler implementation: tl.where(x < 15.0, tl.math.log2(1 + tl.math.exp2(x)), x). Having two implementations makes maintenance harder and can lead to inconsistencies.

Can you clarify whether the inline-assembly path provides measurable performance benefits over the simpler implementation? If not, consider using the simpler version consistently across the codebase.
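
If the simpler formulation turns out to be sufficient, a single shared helper could be as small as the sketch below (the formula is the one already quoted from parallel.py; the module placement is an assumption, and the inline-asm path may still be preferable on some GPUs):

import triton
import triton.language as tl


@triton.jit
def softplus(x):
    # log2-domain softplus: log2(1 + 2^x); for x >= 15 the correction term
    # is negligible in fp32, so fall back to the identity to avoid overflow.
    return tl.where(x < 15.0, tl.math.log2(1 + tl.math.exp2(x)), x)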

fla/ops/stickbreaking_attn/parallel.py (1)

390-392: Remove the unused is_compiling parameter.

The is_compiling parameter is declared but never used in the function body.

Apply this diff:

 @triton.jit
-def softplus(x, is_compiling: tl.constexpr = False):
+def softplus(x):
     return tl.where(x < 15.0, tl.math.log2(1 + tl.math.exp2(x)), x)
sba_code/tests/test_varlen.py (2)

4-4: Remove unused import.

torch.nn.functional as F is imported but never used in the test file.

Apply this diff:

 import torch
 import pytest
 import math
-from torch.nn import functional as F
 from stickbreaking_attention.sb_varlen import sb_attn_varlen

16-16: Consider adding strict=True to zip for safer iteration.

For Python 3.10+, using strict=True ensures all iterables have the same length and raises an error if they don't, preventing silent bugs.

Apply this diff:

-    for q_chunk, k_chunk, v_chunk in zip(q.split(splits, 1), k.split(splits, 1), v.split(splits, 1)):
+    for q_chunk, k_chunk, v_chunk in zip(q.split(splits, 1), k.split(splits, 1), v.split(splits, 1), strict=True):
sba_code/stickbreaking_attention/sb_attn/__init__.py (1)

57-60: Remove or document the unused zero_start parameter.

The zero_start parameter is accepted but never used in the function. Either remove it or document why it's present (e.g., for API compatibility).

If it's for compatibility with sb_attn_varlen, consider adding a docstring explaining this.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between b15eaf6 and 93c98b3.

📒 Files selected for processing (23)
  • fla/layers/stickbreaking_attn.py (1 hunks)
  • fla/ops/stickbreaking_attn/naive.py (1 hunks)
  • fla/ops/stickbreaking_attn/parallel.py (1 hunks)
  • sba_code/.gitignore (1 hunks)
  • sba_code/LICENSE (1 hunks)
  • sba_code/README.md (1 hunks)
  • sba_code/benchmarks/attn.py (1 hunks)
  • sba_code/benchmarks/varlen.py (1 hunks)
  • sba_code/load_model_with_dolomite_demo.py (1 hunks)
  • sba_code/setup.py (1 hunks)
  • sba_code/stickbreaking_attention/__init__.py (1 hunks)
  • sba_code/stickbreaking_attention/sb_attn/__init__.py (1 hunks)
  • sba_code/stickbreaking_attention/sb_attn/sb_bwd.py (1 hunks)
  • sba_code/stickbreaking_attention/sb_attn/sb_fwd.py (1 hunks)
  • sba_code/stickbreaking_attention/sb_ref.py (1 hunks)
  • sba_code/stickbreaking_attention/sb_varlen/__init__.py (1 hunks)
  • sba_code/stickbreaking_attention/sb_varlen/sb_varlen_bwd.py (1 hunks)
  • sba_code/stickbreaking_attention/sb_varlen/sb_varlen_fwd.py (1 hunks)
  • sba_code/stickbreaking_attention/sb_varlen/softplus.py (1 hunks)
  • sba_code/stickbreaking_attention/utils.py (1 hunks)
  • sba_code/tests/test_attn.py (1 hunks)
  • sba_code/tests/test_varlen.py (1 hunks)
  • tests/ops/test_stickbreaking_attn.py (1 hunks)
✅ Files skipped from review due to trivial changes (1)
  • sba_code/LICENSE
🚧 Files skipped from review as they are similar to previous changes (2)
  • tests/ops/test_stickbreaking_attn.py
  • fla/ops/stickbreaking_attn/naive.py
🧰 Additional context used
🧬 Code graph analysis (15)
sba_code/stickbreaking_attention/sb_attn/__init__.py (4)
sba_code/stickbreaking_attention/sb_attn/sb_bwd.py (1)
  • _bwd (163-205)
sba_code/stickbreaking_attention/sb_attn/sb_fwd.py (1)
  • _fwd (139-174)
sba_code/stickbreaking_attention/sb_varlen/__init__.py (3)
  • StickBreakingAttention (24-68)
  • forward (27-46)
  • backward (49-68)
fla/ops/stickbreaking_attn/parallel.py (2)
  • forward (681-693)
  • backward (696-699)
sba_code/stickbreaking_attention/sb_ref.py (1)
fla/modules/activations.py (1)
  • logsigmoid (200-201)
fla/ops/stickbreaking_attn/parallel.py (2)
fla/ops/utils/index.py (1)
  • prepare_chunk_indices (114-119)
fla/layers/stickbreaking_attn.py (1)
  • forward (68-108)
sba_code/benchmarks/varlen.py (5)
sba_code/stickbreaking_attention/sb_varlen/__init__.py (1)
  • sb_attn_varlen (71-78)
fla/modules/rotary.py (1)
  • rotate_half (16-22)
sba_code/stickbreaking_attention/sb_ref.py (1)
  • stickbreaking (8-25)
sba_code/tests/test_varlen.py (1)
  • ref_fwd (10-27)
sba_code/benchmarks/attn.py (2)
  • tri_fwdbwd (13-20)
  • flash_fwdbwd (22-30)
sba_code/stickbreaking_attention/sb_varlen/softplus.py (1)
fla/ops/stickbreaking_attn/parallel.py (1)
  • softplus (391-392)
sba_code/stickbreaking_attention/sb_varlen/sb_varlen_fwd.py (3)
fla/ops/stickbreaking_attn/parallel.py (4)
  • softplus (391-392)
  • load_kv (331-343)
  • compute_block (347-387)
  • backward (696-699)
sba_code/stickbreaking_attention/sb_varlen/softplus.py (1)
  • softplus (36-52)
sba_code/stickbreaking_attention/utils.py (1)
  • custom_op (20-39)
sba_code/stickbreaking_attention/sb_varlen/__init__.py (2)
sba_code/stickbreaking_attention/sb_varlen/sb_varlen_fwd.py (1)
  • varlen_fwd (413-451)
sba_code/stickbreaking_attention/sb_varlen/sb_varlen_bwd.py (1)
  • varlen_bwd (484-539)
sba_code/tests/test_attn.py (3)
sba_code/stickbreaking_attention/sb_attn/__init__.py (1)
  • sb_attn (57-60)
sba_code/stickbreaking_attention/sb_ref.py (1)
  • stickbreaking (8-25)
sba_code/tests/test_varlen.py (3)
  • test_varlen (75-110)
  • assert_close (43-58)
  • ref_fwd (10-27)
sba_code/benchmarks/attn.py (3)
sba_code/stickbreaking_attention/sb_attn/__init__.py (1)
  • sb_attn (57-60)
fla/modules/rotary.py (1)
  • rotate_half (16-22)
sba_code/benchmarks/varlen.py (3)
  • tri_fwdbwd (42-53)
  • flash_fwdbwd (55-72)
  • fun_ (119-121)
sba_code/stickbreaking_attention/__init__.py (2)
sba_code/stickbreaking_attention/sb_attn/__init__.py (1)
  • sb_attn (57-60)
sba_code/stickbreaking_attention/sb_varlen/__init__.py (1)
  • sb_attn_varlen (71-78)
sba_code/tests/test_varlen.py (2)
sba_code/stickbreaking_attention/sb_varlen/__init__.py (2)
  • sb_attn_varlen (71-78)
  • backward (49-68)
sba_code/stickbreaking_attention/sb_ref.py (1)
  • stickbreaking (8-25)
sba_code/stickbreaking_attention/sb_attn/sb_fwd.py (3)
sba_code/stickbreaking_attention/utils.py (1)
  • custom_op (20-39)
sba_code/stickbreaking_attention/sb_varlen/sb_varlen_fwd.py (2)
  • _forward_one_row (82-215)
  • _forward (233-410)
sba_code/stickbreaking_attention/sb_varlen/softplus.py (1)
  • softplus (36-52)
sba_code/stickbreaking_attention/sb_varlen/sb_varlen_bwd.py (2)
sba_code/stickbreaking_attention/sb_varlen/sb_varlen_fwd.py (2)
  • compute_block (27-78)
  • load_kv (11-23)
sba_code/stickbreaking_attention/utils.py (1)
  • custom_op (20-39)
sba_code/stickbreaking_attention/sb_attn/sb_bwd.py (3)
sba_code/stickbreaking_attention/utils.py (1)
  • custom_op (20-39)
sba_code/stickbreaking_attention/sb_varlen/sb_varlen_bwd.py (2)
  • _backward_one_row (298-481)
  • _backward (92-294)
sba_code/stickbreaking_attention/sb_varlen/sb_varlen_fwd.py (2)
  • compute_block (27-78)
  • load_kv (11-23)
fla/layers/stickbreaking_attn.py (3)
fla/modules/layernorm.py (1)
  • RMSNorm (1063-1110)
fla/ops/stickbreaking_attn/parallel.py (2)
  • parallel_stickbreaking_attn (702-711)
  • forward (681-693)
fla/models/stickbreaking_attn/modeling_stickbreaking_attn.py (3)
  • forward (62-99)
  • forward (167-240)
  • forward (275-340)
🪛 Flake8 (7.3.0)
sba_code/load_model_with_dolomite_demo.py

[error] 5-5: redefinition of unused 'hf_models' from line 2

(F811)

sba_code/benchmarks/varlen.py

[error] 2-2: 'pytest' imported but unused

(F401)


[error] 8-8: 'transformers.models.llama.modeling_llama.apply_rotary_pos_emb' imported but unused

(F401)


[error] 89-89: f-string is missing placeholders

(F541)


[error] 111-111: do not assign a lambda expression, use a def

(E731)


[error] 113-113: do not assign a lambda expression, use a def

(E731)


[error] 117-117: do not assign a lambda expression, use a def

(E731)

sba_code/stickbreaking_attention/sb_varlen/softplus.py

[error] 1-1: 'torch' imported but unused

(F401)

sba_code/stickbreaking_attention/sb_varlen/sb_varlen_fwd.py

[error] 224-224: unexpected indentation (comment)

(E116)


[error] 225-225: unexpected indentation (comment)

(E116)


[error] 226-226: unexpected indentation (comment)

(E116)


[error] 227-227: unexpected indentation (comment)

(E116)

sba_code/stickbreaking_attention/sb_varlen/__init__.py

[error] 7-7: 'torch.nn.functional as F' imported but unused

(F401)

sba_code/benchmarks/attn.py

[error] 2-2: 'pytest' imported but unused

(F401)


[error] 4-4: 'torch.nn.functional as F' imported but unused

(F401)


[error] 9-9: 'transformers.models.llama.modeling_llama.apply_rotary_pos_emb' imported but unused

(F401)


[error] 57-57: f-string is missing placeholders

(F541)


[error] 76-76: do not assign a lambda expression, use a def

(E731)


[error] 79-79: do not assign a lambda expression, use a def

(E731)


[error] 82-82: do not assign a lambda expression, use a def

(E731)

sba_code/stickbreaking_attention/__init__.py

[error] 1-1: '.sb_attn.sb_attn' imported but unused

(F401)


[error] 2-2: '.sb_varlen.sb_attn_varlen' imported but unused

(F401)

sba_code/tests/test_varlen.py

[error] 4-4: 'torch.nn.functional as F' imported but unused

(F401)

sba_code/stickbreaking_attention/sb_attn/sb_fwd.py

[error] 7-7: '..sb_varlen.softplus.softplus' imported but unused

(F401)

sba_code/stickbreaking_attention/sb_varlen/sb_varlen_bwd.py

[error] 14-14: missing whitespace around parameter equals

(E252)


[error] 14-14: missing whitespace around parameter equals

(E252)


[error] 43-43: expected an indented block (comment)

(E115)


[error] 44-44: indentation is not a multiple of 4

(E111)


[error] 44-44: over-indented

(E117)


[error] 45-45: indentation is not a multiple of 4

(E111)


[error] 46-46: indentation is not a multiple of 4

(E111)


[error] 47-47: indentation is not a multiple of 4

(E111)


[error] 48-48: indentation is not a multiple of 4

(E111)


[error] 49-49: indentation is not a multiple of 4

(E111)


[error] 50-50: indentation is not a multiple of 4

(E111)


[error] 51-51: indentation is not a multiple of 4

(E111)


[error] 58-58: continuation line under-indented for visual indent

(E128)


[error] 58-58: missing whitespace around parameter equals

(E252)


[error] 58-58: missing whitespace around parameter equals

(E252)


[error] 502-502: local variable 'N_count' is assigned to but never used

(F841)

sba_code/stickbreaking_attention/sb_attn/sb_bwd.py

[error] 7-7: '..sb_varlen.sb_varlen_fwd.compute_block' imported but unused

(F401)


[error] 7-7: '..sb_varlen.sb_varlen_fwd.load_kv' imported but unused

(F401)

🪛 GitHub Actions: lint
sba_code/benchmarks/varlen.py

[error] 110-110: E731 Do not assign a lambda expression, use a def. (pre-commit: ruff-check)


[error] 112-112: E731 Do not assign a lambda expression, use a def. (pre-commit: ruff-check)


[error] 116-116: E731 Do not assign a lambda expression, use a def. (pre-commit: ruff-check)

sba_code/setup.py

[error] 1-1: End-of-file fixer modified file (end-of-file-fixer).

sba_code/tests/test_attn.py

[error] 1-1: Trailing whitespace found and removed by pre-commit (trailing-whitespace).

sba_code/benchmarks/attn.py

[error] 75-75: E731 Do not assign a lambda expression, use a def. (pre-commit: ruff-check)


[error] 78-78: E731 Do not assign a lambda expression, use a def. (pre-commit: ruff-check)


[error] 81-81: E731 Do not assign a lambda expression, use a def. (pre-commit: ruff-check)

sba_code/tests/test_varlen.py

[error] 1-1: Trailing whitespace found and removed by pre-commit (trailing-whitespace).

sba_code/stickbreaking_attention/sb_varlen/sb_varlen_bwd.py

[error] 501-501: F841 Local variable N_count is assigned to but never used. (pre-commit: ruff-check)


[error] 1-1: Trailing whitespace found and removed by pre-commit (trailing-whitespace).


[error] 1-1: End-of-file fixer modified file (end-of-file-fixer).

sba_code/README.md

[error] 1-1: Trailing whitespace found and removed by pre-commit (trailing-whitespace).

🪛 markdownlint-cli2 (0.18.1)
sba_code/README.md

4-4: Link text should be descriptive

(MD059, descriptive-link-text)

🪛 Ruff (0.14.3)
sba_code/stickbreaking_attention/sb_attn/__init__.py

57-57: Unused function argument: zero_start

(ARG001)

fla/ops/stickbreaking_attn/parallel.py

31-31: Unused function argument: batch_size

(ARG001)


132-132: Unused function argument: NO_N_MASK

(ARG001)


136-136: Unused function argument: no_grad

(ARG001)


138-138: Unused function argument: return_attention

(ARG001)


391-391: Unused function argument: is_compiling

(ARG001)

sba_code/benchmarks/varlen.py

24-24: zip() without an explicit strict= parameter

Add explicit value for parameter strict=

(B905)


26-26: Unpacked variable rem is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


37-37: Unused function argument: do

(ARG001)


42-42: Unused function argument: do

(ARG001)


47-47: Unpacked variable rem is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


55-55: Unused function argument: do

(ARG001)


89-89: f-string without any placeholders

Remove extraneous f prefix

(F541)


111-111: Do not assign a lambda expression, use a def

Rewrite fun as a def

(E731)


113-113: Do not assign a lambda expression, use a def

Rewrite fun as a def

(E731)


117-117: Do not assign a lambda expression, use a def

Rewrite fun as a def

(E731)


121-121: Unpacked variable dq is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


121-121: Unpacked variable dk is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


121-121: Unpacked variable dv is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

sba_code/stickbreaking_attention/utils.py

21-21: PEP 484 prohibits implicit Optional

Convert to T | None

(RUF013)


22-22: PEP 484 prohibits implicit Optional

Convert to T | None

(RUF013)

sba_code/stickbreaking_attention/sb_varlen/sb_varlen_fwd.py

113-113: Unused function argument: NO_N_MASK

(ARG001)


117-117: Unused function argument: no_grad

(ARG001)


262-262: Unused function argument: batch_size

(ARG001)


263-263: Unused function argument: token_size

(ARG001)

sba_code/benchmarks/attn.py

13-13: Unused function argument: do

(ARG001)


17-17: Unpacked variable rem is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


22-22: Unused function argument: do

(ARG001)


32-32: Unused function argument: do

(ARG001)


57-57: f-string without any placeholders

Remove extraneous f prefix

(F541)


76-76: Do not assign a lambda expression, use a def

Rewrite fun as a def

(E731)


79-79: Do not assign a lambda expression, use a def

Rewrite fun as a def

(E731)


82-82: Do not assign a lambda expression, use a def

Rewrite fun as a def

(E731)


87-87: Unpacked variable dq is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


87-87: Unpacked variable dk is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


87-87: Unpacked variable dv is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

sba_code/tests/test_varlen.py

16-16: zip() without an explicit strict= parameter

Add explicit value for parameter strict=

(B905)

sba_code/stickbreaking_attention/sb_attn/sb_fwd.py

53-53: Unused function argument: batch_size

(ARG001)


56-56: Unused function argument: num_heads

(ARG001)

sba_code/stickbreaking_attention/sb_varlen/sb_varlen_bwd.py

14-14: Do not perform function call tl.constexpr in argument defaults; instead, perform the call within the function, or read the default from a module-level singleton variable

(B008)


57-57: Unused function argument: Lock_ptr

(ARG001)


57-57: Unused function argument: Count_ptr

(ARG001)


58-58: Unused function argument: EVICTION_POLICY

(ARG001)


133-133: Unused function argument: batch_size

(ARG001)


134-134: Unused function argument: token_size

(ARG001)


138-138: Unused function argument: BLOCK_CSL

(ARG001)


141-141: Unused function argument: NO_N_MASK

(ARG001)


502-502: Local variable N_count is assigned to but never used

Remove assignment to unused variable N_count

(F841)

sba_code/stickbreaking_attention/sb_attn/sb_bwd.py

69-69: Unused function argument: batch_size

(ARG001)


72-72: Unused function argument: num_heads

(ARG001)


76-76: Unused function argument: NO_N_MASK

(ARG001)

fla/layers/stickbreaking_attn.py

39-41: Avoid specifying long messages outside the exception class

(TRY003)


73-73: Unused method argument: output_attentions

(ARG002)


85-85: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: check-pt-python-compatibility
🔇 Additional comments (3)
sba_code/.gitignore (1)

1-156: LGTM! Comprehensive and standard .gitignore structure.

This .gitignore file is well-formed and follows Python project best practices. It covers Python artifacts, packaging/distribution outputs, virtual environments, IDE configurations, test/coverage reports, type checker caches, and various build system artifacts. The patterns are syntactically correct and appropriately scoped.

As a defensive practice, having a local .gitignore in the sba_code/ sub-package is good, even though it may partially overlap with a root-level .gitignore in the main repository.

sba_code/stickbreaking_attention/__init__.py (1)

1-2: LGTM!

The imports are re-exported to form the public API of the package. The static analysis warnings about unused imports are false positives in this context.

sba_code/stickbreaking_attention/sb_ref.py (1)

7-25: LGTM!

The reference implementation is clear and follows the stick-breaking attention algorithm correctly. This provides a good baseline for testing the optimized implementations.
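
For readers skimming the diff, here is a minimal sketch of the recurrence the reference computes (my paraphrase of the formulation, not the file's exact code): β = sigmoid(q·k/√d), and each key receives β times whatever stick mass the keys closer to the query have left over.

import torch
import torch.nn.functional as F


def stickbreaking_naive(q, k, v):
    # q, k, v: [T, D]; O(T^2) loops kept for clarity only.
    T, D = q.shape
    logits = q @ k.t() / D ** 0.5
    log_beta = F.logsigmoid(logits)
    log_rest = F.logsigmoid(-logits)
    att = torch.zeros(T, T)
    for i in range(T):
        remaining = 0.0  # log of the stick mass not yet claimed
        for j in range(i - 1, -1, -1):  # nearest key claims mass first
            att[i, j] = (log_beta[i, j] + remaining).exp()
            remaining = remaining + log_rest[i, j]
    return att @ v


out = stickbreaking_naive(torch.randn(8, 16), torch.randn(8, 16), torch.randn(8, 16))
print(out.shape)  # torch.Size([8, 16])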

Comment on lines 74 to 80
position_ids = torch.arange(q.size(1), device=device, dtype=torch.int32)[None, :]
if provider == "triton":
fun = lambda: tri_fwdbwd(do, q, k, v)
elif provider == "flash":
rope = LlamaRotaryEmbedding(dim=head_dim).to(device)
fun = lambda: flash_fwdbwd(rope, position_ids, do, q, k, v)
elif provider == "triton_flash":

⚠️ Potential issue | 🔴 Critical

Fix rotary position IDs dtype before runtime crash.

LlamaRotaryEmbedding indexes its cached cos/sin tables with position_ids. Passing a CUDA tensor in torch.int32 raises Index tensor must have dtype long, so this benchmark will crash as soon as you hit the flash branch. Please build position_ids with dtype=torch.long (and keep it on CUDA). A minimal change is to drop the explicit dtype or replace it with dtype=torch.long.

🧰 Tools
🪛 Flake8 (7.3.0)

[error] 76-76: do not assign a lambda expression, use a def

(E731)


[error] 79-79: do not assign a lambda expression, use a def

(E731)

🪛 GitHub Actions: lint

[error] 75-75: E731 Do not assign a lambda expression, use a def. (pre-commit: ruff-check)


[error] 78-78: E731 Do not assign a lambda expression, use a def. (pre-commit: ruff-check)

🪛 Ruff (0.14.3)

76-76: Do not assign a lambda expression, use a def

Rewrite fun as a def

(E731)


79-79: Do not assign a lambda expression, use a def

Rewrite fun as a def

(E731)

🤖 Prompt for AI Agents
In sba_code/benchmarks/attn.py around lines 74 to 80, position_ids is created
with dtype=torch.int32 which causes LlamaRotaryEmbedding to throw "Index tensor
must have dtype long" when used on CUDA; change the creation of position_ids to
use dtype=torch.long (or omit dtype so it defaults to long) while keeping it on
the same device so the flash branch can index the cached cos/sin tables without
crashing.
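
A self-contained sketch of the dtype fix (seq_len stands in for q.size(1) in the benchmark):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
seq_len = 4096  # stands in for q.size(1) in the benchmark
# torch.long is required so LlamaRotaryEmbedding can index its cos/sin cache.
position_ids = torch.arange(seq_len, device=device, dtype=torch.long)[None, :]
assert position_ids.dtype == torch.long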

Comment on lines 75 to 83
if provider == "triton":
fun = lambda: tri_fwdbwd(do, q, k, v)
elif provider == "flash":
rope = LlamaRotaryEmbedding(dim=head_dim).to(device)
fun = lambda: flash_fwdbwd(rope, position_ids, do, q, k, v)
elif provider == "triton_flash":
rope = LlamaRotaryEmbedding(dim=head_dim).to(device)
fun = lambda: triton_flash_fwdbwd(rope, position_ids, do, q, k, v)


⚠️ Potential issue | 🟠 Major

Refactor the provider dispatch into defs to unblock lint.

CI is currently red on this file (E731) because fun is assigned three different lambda expressions. Please rewrite these branches as small local functions (def run_triton(): ...; fun = run_triton, etc.) so the linter passes again.

🧰 Tools
🪛 Flake8 (7.3.0)

[error] 76-76: do not assign a lambda expression, use a def

(E731)


[error] 79-79: do not assign a lambda expression, use a def

(E731)


[error] 82-82: do not assign a lambda expression, use a def

(E731)

🪛 GitHub Actions: lint

[error] 75-75: E731 Do not assign a lambda expression, use a def. (pre-commit: ruff-check)


[error] 78-78: E731 Do not assign a lambda expression, use a def. (pre-commit: ruff-check)


[error] 81-81: E731 Do not assign a lambda expression, use a def. (pre-commit: ruff-check)

🪛 Ruff (0.14.3)

76-76: Do not assign a lambda expression, use a def

Rewrite fun as a def

(E731)


79-79: Do not assign a lambda expression, use a def

Rewrite fun as a def

(E731)


82-82: Do not assign a lambda expression, use a def

Rewrite fun as a def

(E731)

🤖 Prompt for AI Agents
In sba_code/benchmarks/attn.py around lines 75 to 83, the three branches
currently assign different lambda expressions to fun (triggering E731); replace
each lambda with a small local def: define run_triton() that calls
tri_fwdbwd(do, q, k, v), define run_flash() that constructs or closes over rope
= LlamaRotaryEmbedding(dim=head_dim).to(device) (or create rope before the def)
and calls flash_fwdbwd(rope, position_ids, do, q, k, v), and define
run_triton_flash() similarly for triton_flash_fwdbwd; then set fun = run_triton
/ run_flash / run_triton_flash in each branch so the linter no longer sees
lambda assignments.

Comment on lines 107 to 118
do = torch.randn((num_heads, total_length, head_dim), device=device, dtype=dtype)
position_ids = torch.arange(q.size(1), device=device, dtype=torch.int32)[None, :]

if provider== "reference":
fun = lambda: ref_fwdbwd(do, q, k, v, lengths)
elif provider == "triton":
fun = lambda: tri_fwdbwd(do, q, k, v, lengths)
elif provider == "flash":
config = LlamaConfig(max_position_embeddings=length)
rope = LlamaRotaryEmbedding(config).to(device)
fun = lambda: flash_fwdbwd(rope, position_ids, do, q, k, v, lengths)
if bwd:

⚠️ Potential issue | 🔴 Critical

Use correct rotary positions (total tokens, dtype torch.long).

Here we build position_ids with q.size(1) and dtype=torch.int32. In this benchmark q.shape == (total_length, num_heads, head_dim), so q.size(1) is the head count, not the token index. As a result every token shares the same tiny set of angles, and the rotary embedding is completely wrong. Worse, LlamaRotaryEmbedding indexes its cache with position_ids, and a CUDA torch.int32 tensor will raise Index tensor must have dtype long. Please construct per-token positions (e.g. torch.arange(total_length, device=device, dtype=torch.long)[None, :], or restart per sequence via lengths) and keep them in torch.long so rotary is well-defined.

🧰 Tools
🪛 Flake8 (7.3.0)

[error] 111-111: do not assign a lambda expression, use a def

(E731)


[error] 113-113: do not assign a lambda expression, use a def

(E731)


[error] 117-117: do not assign a lambda expression, use a def

(E731)

🪛 GitHub Actions: lint

[error] 110-110: E731 Do not assign a lambda expression, use a def. (pre-commit: ruff-check)


[error] 112-112: E731 Do not assign a lambda expression, use a def. (pre-commit: ruff-check)


[error] 116-116: E731 Do not assign a lambda expression, use a def. (pre-commit: ruff-check)

🪛 Ruff (0.14.3)

111-111: Do not assign a lambda expression, use a def

Rewrite fun as a def

(E731)


113-113: Do not assign a lambda expression, use a def

Rewrite fun as a def

(E731)


117-117: Do not assign a lambda expression, use a def

Rewrite fun as a def

(E731)

🤖 Prompt for AI Agents
In sba_code/benchmarks/varlen.py around lines 107 to 118, the position_ids are
built from q.size(1) with dtype torch.int32 which is wrong because q.shape is
(total_length, num_heads, head_dim) so q.size(1) is num_heads not token
positions, and int32 will break LlamaRotaryEmbedding indexing; replace this with
per-token positions of length total_length and dtype torch.long (e.g.
torch.arange(total_length, device=device, dtype=torch.long)[None, :]) or, if
sequences restart, construct per-sequence positions using lengths, ensuring the
resulting tensor indexes the rotary cache correctly and matches the token
dimension of q.
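
A hedged sketch of per-sequence positions for the packed layout, restarting at zero for each sequence (the concrete lengths below are placeholders):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
lengths = [512, 1024, 2048]  # placeholder per-sequence lengths
# Per-token positions over the packed total_length tokens, restarting per
# sequence, in torch.long so the rotary cache can be indexed.
position_ids = torch.cat(
    [torch.arange(n, device=device, dtype=torch.long) for n in lengths]
)[None, :]
print(position_ids.shape)  # torch.Size([1, 3584])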

Comment on lines 111 to 118
fun = lambda: ref_fwdbwd(do, q, k, v, lengths)
elif provider == "triton":
fun = lambda: tri_fwdbwd(do, q, k, v, lengths)
elif provider == "flash":
config = LlamaConfig(max_position_embeddings=length)
rope = LlamaRotaryEmbedding(config).to(device)
fun = lambda: flash_fwdbwd(rope, position_ids, do, q, k, v, lengths)
if bwd:

⚠️ Potential issue | 🟠 Major

Replace the provider lambdas with named callables.

Same lint failure as in attn.py: assigning lambdas to fun violates E731 and is currently breaking CI. Please convert each branch into a small def (or partial) and assign that function instead.

🧰 Tools
🪛 Flake8 (7.3.0)

[error] 111-111: do not assign a lambda expression, use a def

(E731)


[error] 113-113: do not assign a lambda expression, use a def

(E731)


[error] 117-117: do not assign a lambda expression, use a def

(E731)

🪛 GitHub Actions: lint

[error] 112-112: E731 Do not assign a lambda expression, use a def. (pre-commit: ruff-check)


[error] 116-116: E731 Do not assign a lambda expression, use a def. (pre-commit: ruff-check)

🪛 Ruff (0.14.3)

111-111: Do not assign a lambda expression, use a def

Rewrite fun as a def

(E731)


113-113: Do not assign a lambda expression, use a def

Rewrite fun as a def

(E731)


117-117: Do not assign a lambda expression, use a def

Rewrite fun as a def

(E731)

🤖 Prompt for AI Agents
In sba_code/benchmarks/varlen.py around lines 111 to 118, the code assigns
anonymous lambdas to the variable `fun` for each provider branch which triggers
lint error E731; replace each lambda with a named callable (either a small def
function or functools.partial) and assign that named function to `fun`. For
example, create distinct functions like `_ref_fun()`, `_tri_fun()`, and
`_flash_fun()` that capture the same closed-over variables (or use partial to
bind args) and set `fun = _ref_fun` / `fun = _tri_fun` / `fun = _flash_fun`
accordingly so no lambda is assigned. Ensure the flash branch still constructs
`config` and `rope` before defining the callable and leave the subsequent `if
bwd:` logic unchanged.
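
A self-contained sketch of the functools.partial variant mentioned in the prompt (the benchmark function is stubbed here so the snippet runs on its own):

from functools import partial


def tri_fwdbwd(do, q, k, v, lengths):
    # stub standing in for the benchmark's real function
    return "triton", len(lengths)


do = q = k = v = None
lengths = [512, 1024]
provider = "triton"

if provider == "triton":
    # named callable bound with partial; no lambda assignment, so E731 passes
    fun = partial(tri_fwdbwd, do, q, k, v, lengths)

print(fun())  # ('triton', 2)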

Comment on lines 4 to 5
def read(fname):
return open(os.path.join(os.path.dirname(__file__), fname)).read()

⚠️ Potential issue | 🟠 Major

Close file handle properly.

The file is opened but never closed, causing a resource leak.

Apply this diff:

 def read(fname):
-    return open(os.path.join(os.path.dirname(__file__), fname)).read()
+    with open(os.path.join(os.path.dirname(__file__), fname)) as f:
+        return f.read()
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def read(fname):
return open(os.path.join(os.path.dirname(__file__), fname)).read()
def read(fname):
with open(os.path.join(os.path.dirname(__file__), fname)) as f:
return f.read()
🤖 Prompt for AI Agents
In sba_code/setup.py around lines 4 to 5, the read() function opens a file but
never closes it; change it to use a context manager (with open(... ) as f:) and
return f.read() so the file handle is closed automatically; also specify a mode
('r') and an explicit encoding (e.g., 'utf-8') for portability.

description = "Triton implementation of Stick-breaking attention",
license = "Apache License",
keywords = "triton pytorch llm stickbreaking attention",
url = "https://github.com/shawntan/scattermoe",

⚠️ Potential issue | 🟠 Major

Update the project URL to the correct repository.

The URL points to shawntan/scattermoe instead of the actual stick-breaking attention repository.

Please update line 15 with the correct project URL.

🤖 Prompt for AI Agents
In sba_code/setup.py around line 15, the project URL currently points to
"https://github.com/shawntan/scattermoe"; update that value to the correct
repository URL for the stick-breaking attention project (e.g. replace with
"https://github.com/shawntan/stick-breaking-attention") so the package metadata
references the proper upstream repo.

Comment on lines 173 to 175
dk = torch.zeros_like(k, dtype=torch.bfloat16)
dv = torch.zeros_like(v, dtype=torch.bfloat16)


⚠️ Potential issue | 🟠 Major

Stop coercing K/V gradients to bfloat16

Gradients must come back in the same dtype as the forward tensors. For fp16 or fp32 Q/K/V this cast to bfloat16 causes autograd to error out (type mismatch) or silently downcast precision. Reuse the original dtype instead of forcing bfloat16.

-    dk = torch.zeros_like(k, dtype=torch.bfloat16)
-    dv = torch.zeros_like(v, dtype=torch.bfloat16)
+    dk = torch.zeros_like(k)
+    dv = torch.zeros_like(v)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
dk = torch.zeros_like(k, dtype=torch.bfloat16)
dv = torch.zeros_like(v, dtype=torch.bfloat16)
dk = torch.zeros_like(k)
dv = torch.zeros_like(v)
🤖 Prompt for AI Agents
In sba_code/stickbreaking_attention/sb_attn/sb_bwd.py around lines 173 to 175,
the backward code currently creates dk/dv with dtype=torch.bfloat16 which forces
gradients to bfloat16 and breaks/autodownsamples fp16/fp32 tensors; change the
allocation to reuse the original forward dtypes (e.g., torch.zeros_like(k) /
torch.zeros_like(v) or torch.zeros_like(k, dtype=k.dtype) and
torch.zeros_like(v, dtype=v.dtype)) so gradients match the forward tensors'
dtype and device.

Comment on lines 315 to 319
O_head_seq_ptr = O_ptr + stride_oh * head_id + stride_om * seq_start_offset
R_head_seq_ptr = R_ptr + stride_rh * head_id + stride_rm * seq_start_offset
A_head_seq_ptr = A_ptr + stride_ah * head_id + stride_am * seq_start_offset
W_head_seq_ptr = W_ptr + stride_wh * head_id + stride_am * seq_start_offset
_forward_one_row(

⚠️ Potential issue | 🔴 Critical

Fix incorrect stride when computing W_head_seq_ptr.

Inside both head-block branches, W_head_seq_ptr is advanced with stride_am, i.e. the stride for the accumulator tensor A_ptr. As soon as return_attention=True, this writes attention weights using the wrong stride (and can walk into the A_ptr buffer), so the returned attention matrix is corrupted and may stomp unrelated memory. Please use the dedicated stride_wm for the W tensor when applying the sequence offset.

-            W_head_seq_ptr = W_ptr + stride_wh * head_id + stride_am * seq_start_offset
+            W_head_seq_ptr = W_ptr + stride_wh * head_id + stride_wm * seq_start_offset
...
-            W_head_seq_ptr = W_ptr + stride_wh * head_id + stride_am * seq_start_offset
+            W_head_seq_ptr = W_ptr + stride_wh * head_id + stride_wm * seq_start_offset

Also applies to: 367-371

🤖 Prompt for AI Agents
In sba_code/stickbreaking_attention/sb_varlen/sb_varlen_fwd.py around lines
315-319 (also apply same fix at 367-371): W_head_seq_ptr is computed using
stride_am (accumulator stride) instead of the W tensor's sequence stride
stride_wm, which causes attention writes to use the wrong stride and can corrupt
A_ptr or other memory; fix by replacing the use of stride_am with stride_wm when
adding seq_start_offset to W_ptr so W_head_seq_ptr uses stride_wh * head_id +
stride_wm * seq_start_offset (mirror the pattern used for O/R/A but with the
correct W sequence stride).

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 1bf8ba2 and 3d037ff.

📒 Files selected for processing (2)
  • fla/ops/stickbreaking_attn/parallel.py (1 hunks)
  • fla/ops/stickbreaking_attn/softplus.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
fla/ops/stickbreaking_attn/parallel.py (2)
fla/ops/stickbreaking_attn/softplus.py (1)
  • softplus (38-49)
fla/ops/utils/index.py (1)
  • prepare_chunk_indices (114-119)
🪛 Ruff (0.14.3)
fla/ops/stickbreaking_attn/parallel.py

94-94: Unused function argument: NO_N_MASK

(ARG001)


98-98: Unused function argument: no_grad

(ARG001)


100-100: Unused function argument: return_attention

(ARG001)


344-344: Unused function argument: B

(ARG001)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: Test H100 (PyTorch 2.7) / test-models
  • GitHub Check: check-pt-python-compatibility
🔇 Additional comments (1)
fla/ops/stickbreaking_attn/softplus.py (1)

8-49: Inline softplus port looks great — matches the referenced kernel’s log₂-domain math and keeps the JIT-time constants nicely cached, so I don’t see any issues here.

Comment on lines +623 to +665
dk = torch.zeros((M_count, B, T, H, D), dtype=k.dtype, device=k.device)
dv = torch.zeros((M_count, B, T, H, D), dtype=v.dtype, device=v.device)

BLOCK_D = triton.next_power_of_2(D)

NO_M_MASK = (T % BT) == 0
NO_N_MASK = (T % BS) == 0
if cu_seqlens is not None:
NO_M_MASK = False
NO_N_MASK = False

parallel_stickbreaking_attn_bwd_kernel[grid](
do,
dr,
neg_log_acc,
q,
k,
v,
dq,
dk,
dv,
CU_ptr=cu_seqlens if cu_seqlens is not None else q,
CI_ptr=CI if CI is not None else q,
scale=scale,
B=B,
T=T,
head_size=D,
H=H,
BT=BT,
BS=BS,
BLOCK_D=BLOCK_D,
NO_D_MASK=D == BLOCK_D,
NO_M_MASK=NO_M_MASK,
NO_N_MASK=NO_N_MASK,
ALLOW_TF32=ALLOW_TF32,
acc_dtype=tl.float32,
IS_VARLEN=cu_seqlens is not None,
)

dk_final = dk.sum(0)
dv_final = dv.sum(0)

return dq.to(q.dtype), dk_final, dv_final

⚠️ Potential issue | 🔴 Critical

Rewrite DK/DV accumulation to avoid 64× memory blow-up — allocating dk/dv as (M_count, B, T, H, D) temporarily multiplies the KV footprint by ~T / BT. For a single advertised run (batch = 8, seq = 4096, heads = 32, dim = 64, BT = 64) this comes out to ~8.6 GB per tensor in fp16, so you’ll OOM before the backward kernel even launches. Please accumulate directly into the [B, T, H, D] buffers (e.g. atomic adds or a per-block staging slab that you immediately fold into the final grad) so memory scales linearly with the model size.
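
As a rough illustration only (placeholder names, not a drop-in patch for the kernel above), the per-block K gradients could be folded straight into a float32 [B, T, H, D] buffer with atomic adds, which removes the M_count dimension entirely:

import triton
import triton.language as tl


@triton.jit
def store_dk_block(
    DK_ptr,                      # float32 [B, T, H, D] gradient buffer
    dk_block,                    # [BLOCK_N, BLOCK_D] partial gradient tile
    stride_b, stride_t, stride_h, stride_d,
    batch_id, head_id,
    n_offsets,                   # [BLOCK_N] token indices of this K/V block
    d_offsets,                   # [BLOCK_D] feature indices
    n_mask, d_mask,
):
    ptrs = (
        DK_ptr
        + batch_id * stride_b
        + n_offsets[:, None] * stride_t
        + head_id * stride_h
        + d_offsets[None, :] * stride_d
    )
    # Several M-blocks touch the same K/V rows, so the writes must be atomic.
    tl.atomic_add(ptrs, dk_block, mask=n_mask[:, None] & d_mask[None, :])

The buffer would be cast back to k.dtype after the launch; the trade-off versus a staged two-pass reduction is that floating-point atomics make the summation order, and hence the gradients' low-order bits, non-deterministic.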

@sustcsonglin sustcsonglin force-pushed the main branch 2 times, most recently from 1700f8d to f4082b3 on November 11, 2025 16:49

Development

Successfully merging this pull request may close these issues.

[RFC] Add stick-breaking attention

2 participants