[Stick-Breaking Attention] Add Model #599
base: main
Conversation
Note: Other AI code review bot(s) detected. CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Adds Stick-Breaking Attention support: Triton and naive kernel implementations, a new attention layer, HF Transformers config and model wrappers (encoder and causal LM), ops exports, and tests for kernels and modeling behavior.

Changes
Sequence Diagram(s)

sequenceDiagram
participant Caller
participant Layer as StickBreakingAttention
participant Kernels as parallel_stickbreaking_attn / naive
participant ModelOut
Caller->>Layer: forward(hidden_states, attention_mask?, cu_seqlens, ...)
activate Layer
Layer->>Layer: project Q, K, V
Layer->>Layer: optional q/k RMSNorm
alt use_cache requested
Layer-->>Caller: emit warning (KV caching unsupported)
end
Layer->>Kernels: call parallel_stickbreaking_attn(q,k,v,scale,cu_seqlens)
activate Kernels
Kernels->>Kernels: compute logits (log-space), apply triangular mask
Kernels->>Kernels: compute log_z/log_beta → attention weights
Kernels-->>Layer: return (o, rem)
deactivate Kernels
Layer->>Layer: reshape & output projection
Layer-->>ModelOut: return (output, None, past_key_values)
deactivate Layer
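To make the "log_z/log_beta → attention weights" step in the diagram concrete, here is a minimal PyTorch sketch of stick-breaking attention weights. It is a reading aid only, not the PR's actual implementation (which uses a fused Triton kernel); the [B, T, H, D] layout and the inv_temp, attend_current, and rem names follow the kernel excerpts quoted later in this thread.

```python
import torch
import torch.nn.functional as F

def stickbreaking_attn_sketch(q, k, v, inv_temp, attend_current=False):
    # q, k, v: [B, T, H, D]
    T = q.shape[1]
    logits = torch.einsum('bthd,bshd->bhts', q, k).float() * inv_temp
    # mask out non-attendable positions; attend_current=True keeps the diagonal
    mask = torch.ones(T, T, device=q.device).triu(1 if attend_current else 0).bool()
    log_z = F.logsigmoid(logits).masked_fill(mask, -1e5)     # log p_ij, with p_ij = sigmoid(logit_ij)
    log_beta = F.logsigmoid(-logits).masked_fill(mask, 0.0)  # log (1 - p_ij)
    # stick-breaking: A_ij = p_ij * prod over keys between j and the query of (1 - p_ik), in log space
    cum_weight = torch.ones(T, T, device=q.device).tril(-1)
    att = (log_z + torch.einsum('bhij,jk->bhik', log_beta, cum_weight)).exp()
    o = torch.einsum('bhts,bshd->bthd', att, v.float())
    rem = 1 - att.sum(dim=-1).transpose(1, 2)                # leftover stick mass (output layout [B, T, H] assumed)
    return o.to(q.dtype), rem.to(q.dtype)
```

Unlike softmax attention, each query's weight on a key is a sigmoid scaled by the product of (1 - sigmoid) over the keys closer to the query, so the weights need not sum to one; the leftover mass is what the kernels return as rem.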
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Summary of Changes

Hello @Nathancgy, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces the novel Stick-Breaking Attention model, providing an efficient and scalable attention mechanism. The implementation adheres to the project's style guidelines, supports variable-length sequences, and has been benchmarked for competitive throughput. It includes both a naive reference implementation and an optimized Triton kernel, alongside comprehensive tests to ensure correctness and stability across various scenarios.

Highlights
Code Review
This pull request introduces the Stick-Breaking Attention model, including the attention layer, model definition, configuration, and optimized Triton kernels. The implementation is comprehensive and includes both a naive version for reference and a parallelized version for performance, along with corresponding tests.
My review focuses on improving code clarity and correctness. I've identified a critical issue with a type hint in a constructor, and several medium-severity issues related to unused parameters, dead code, and confusing code constructs. These changes will make the new model's code cleaner and more maintainable.
def __init__(
    self,
    config: StickBreakingAttentionConfig
) -> StickBreakingAttentionModel:
fla/layers/stickbreaking_attn.py (Outdated)
| window_size: Optional[int] = None, | ||
| rope_theta: Optional[float] = None, # sba doesn't use RoPE | ||
| max_position_embeddings: Optional[int] = None, | ||
| layer_idx: int | None = None, | ||
| ): | ||
| super().__init__() | ||
|
|
||
| if sb_attn is None: | ||
| raise ImportError( | ||
| "StickBreakingAttention kernels are not available. Ensure Triton is installed and ops are importable." | ||
| ) | ||
|
|
||
| self.hidden_size = hidden_size | ||
| self.num_heads = num_heads | ||
| if num_kv_heads is None: | ||
| self.num_kv_heads = self.num_heads | ||
| else: | ||
| self.num_kv_heads = num_kv_heads | ||
| self.num_kv_groups = self.num_heads // self.num_kv_heads | ||
| self.head_dim = self.hidden_size // self.num_heads | ||
| self.kv_dim = self.num_kv_heads * self.head_dim | ||
| self.qkv_bias = qkv_bias | ||
| self.qk_norm = qk_norm | ||
|
|
||
| self.window_size = window_size | ||
| self.rope_theta = rope_theta | ||
| self.max_position_embeddings = max_position_embeddings | ||
| self.layer_idx = layer_idx |
The parameters window_size, rope_theta, max_position_embeddings, and layer_idx are accepted in the constructor (lines 34-37) and stored as attributes (lines 58-61), but they are never used within the StickBreakingAttention class. This can be misleading and adds unnecessary clutter to the API. It's best to remove these unused parameters and their assignments to keep the code clean and maintainable.
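For illustration only, a trimmed constructor might keep just the parameters the layer consumes. This is a hypothetical sketch, not the PR's code; the defaults shown are placeholders.

```python
from typing import Optional

def __init__(
    self,
    hidden_size: int,
    num_heads: int,
    num_kv_heads: Optional[int] = None,
    qkv_bias: bool = False,  # placeholder default
    qk_norm: bool = False,   # placeholder default
    # window_size, rope_theta, max_position_embeddings, layer_idx removed
) -> None:
    ...
```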
| window_size: Optional[int] = None, | ||
| max_position_embeddings: int = 2048, | ||
| hidden_ratio: Optional[int] = 4, | ||
| intermediate_size: Optional[int] = None, | ||
| hidden_act: str = "swish", | ||
| initializer_range: float = 0.02, | ||
| elementwise_affine: Optional[bool] = True, | ||
| norm_eps: float = 1e-6, | ||
| use_cache: bool = True, | ||
| pad_token_id: Optional[int] = None, | ||
| bos_token_id: int = 1, | ||
| eos_token_id: int = 2, | ||
| tie_word_embeddings: bool = False, | ||
| fuse_norm: bool = True, | ||
| fuse_swiglu: bool = True, | ||
| fuse_cross_entropy: bool = True, | ||
| fuse_linear_cross_entropy: bool = False, | ||
| use_l2warp: bool = False, | ||
| vocab_size: int = 32000, | ||
| **kwargs, | ||
| ): | ||
| self.hidden_size = hidden_size | ||
| self.num_hidden_layers = num_hidden_layers | ||
| self.num_heads = num_heads | ||
| self.num_kv_heads = num_kv_heads | ||
| self.qkv_bias = qkv_bias | ||
| self.qk_norm = qk_norm | ||
| self.window_size = window_size |
window_size=config.window_size,
rope_theta=None,
max_position_embeddings=config.max_position_embeddings,
layer_idx=layer_idx
fla/ops/stickbreaking_attn/naive.py (Outdated)
def _tril_mask(T: int, strict: bool = True, device=None) -> torch.Tensor:
    i = torch.arange(T, device=device).view(1, 1, T, 1)
    j = torch.arange(T, device=device).view(1, 1, 1, T)
    return (j < i) if strict else (j <= i)
batch_id = 0 if IS_VARLEN else tl.program_id(0)
head_pid = tl.program_id(1)
prog_id = tl.program_id(2)
tl.num_programs(2)
Actionable comments posted: 2
🧹 Nitpick comments (25)
fla/models/stickbreaking_attn/configuration_stickbreaking_attn.py (3)
11-12: Annotate class attributes with ClassVar to satisfy linters.

Use typing.ClassVar for model_type and keys_to_ignore_at_inference (RUF012).

Apply this diff:

-class StickBreakingAttentionConfig(PretrainedConfig):
+class StickBreakingAttentionConfig(PretrainedConfig):
@@
-    model_type = 'stickbreaking_attn'
-    keys_to_ignore_at_inference = ['past_key_values']
+    model_type: ClassVar[str] = 'stickbreaking_attn'
+    keys_to_ignore_at_inference: ClassVar[List[str]] = ['past_key_values']
4-4: Add missing imports for ClassVar/List.

Needed for the ClassVar annotations above.

Apply this diff:

-from typing import Optional
+from typing import Optional, ClassVar, List
73-77: Add stacklevel to warnings and fix grammar.

Include stacklevel=2 (B028) and fix “can improves” → “can improve”.

Apply this diff:

-        warnings.warn(
-            "`fuse_linear_cross_entropy` is enabled, which can improves memory efficiency "
-            "at the potential cost of reduced precision. "
-            "If you observe issues like loss divergence, consider disabling this setting."
-        )
+        warnings.warn(
+            "`fuse_linear_cross_entropy` is enabled, which can improve memory efficiency "
+            "at the potential cost of reduced precision. "
+            "If you observe issues like loss divergence, consider disabling this setting.",
+            stacklevel=2,
+        )

fla/ops/__init__.py (1)
29-31: Gracefully handle missing Triton/compiled ops at import time.

Package import will fail if the parallel backend isn’t available. Make the export optional to allow CPU-only environments to import higher-level modules.

Apply this diff:

-from .stickbreaking_attn.naive import sb_attn_naive
-from .stickbreaking_attn.parallel import sb_attn
+try:
+    from .stickbreaking_attn.parallel import sb_attn
+except Exception:
+    sb_attn = None
+from .stickbreaking_attn.naive import sb_attn_naive

tests/ops/test_stickbreaking_attn.py (3)
52-55: Also verify the attend_current=True path.

Add a second forward/backward block with attend_current=True to cover diagonal-included semantics.
Apply this diff:
- # Triton fused - tri_o, tri_rem = sb_attn(q, k, v, inv_temp=inv_temp, attend_current=False) + # Triton fused + tri_o, tri_rem = sb_attn(q, k, v, inv_temp=inv_temp, attend_current=False) (tri_o * do).sum().backward(retain_graph=True) (tri_rem * dr).sum().backward() tri_dq, q.grad = q.grad.clone(), None tri_dk, k.grad = k.grad.clone(), None tri_dv, v.grad = v.grad.clone(), None + + # Triton fused (attend_current=True) + q.grad = k.grad = v.grad = None + tri_o_diag, tri_rem_diag = sb_attn(q, k, v, inv_temp=inv_temp, attend_current=True) + (tri_o_diag * do).sum().backward(retain_graph=True) + (tri_rem_diag * dr).sum().backward() + tri_dq_diag, q.grad = q.grad.clone(), None + tri_dk_diag, k.grad = k.grad.clone(), None + tri_dv_diag, v.grad = v.grad.clone(), None
60-65: Add checks for attend_current=True parity.

Compare naïve vs fused also for attend_current=True to lock in both behaviors.
Apply this diff:
- assert_close(" o", ref_o, tri_o, 0.008) - assert_close("rem", ref_rem, tri_rem, 0.02) - assert_close("dq", ref_dq, tri_dq, 0.02) - assert_close("dk", ref_dk, tri_dk, 0.02) - assert_close("dv", ref_dv, tri_dv, 0.02) + assert_close(" o", ref_o, tri_o, 0.008) + assert_close("rem", ref_rem, tri_rem, 0.02) + assert_close("dq", ref_dq, tri_dq, 0.02) + assert_close("dk", ref_dk, tri_dk, 0.02) + assert_close("dv", ref_dv, tri_dv, 0.02) + + # Reference (naive, attend_current=True) + q.grad = k.grad = v.grad = None + ref_o_diag, ref_rem_diag = sb_attn_naive(q, k, v, inv_temp, attend_current=True) + (ref_o_diag * do).sum().backward(retain_graph=True) + (ref_rem_diag * dr).sum().backward() + ref_dq_diag, q.grad = q.grad.clone(), None + ref_dk_diag, k.grad = k.grad.clone(), None + ref_dv_diag, v.grad = v.grad.clone(), None + + assert_close(" o(diag)", ref_o_diag, tri_o_diag, 0.008) + assert_close("rem(diag)", ref_rem_diag, tri_rem_diag, 0.02) + assert_close("dq(diag)", ref_dq_diag, tri_dq_diag, 0.02) + assert_close("dk(diag)", ref_dk_diag, tri_dk_diag, 0.02) + assert_close("dv(diag)", ref_dv_diag, tri_dv_diag, 0.02)
12-21: Consider adding a varlen test using cu_seqlens.

A small param set with padded sequences and cu_seqlens would catch IS_VARLEN paths.
Would you like me to add a minimal varlen test that builds cu_seqlens from an attention_mask and compares fused vs naive on concatenated sequences?
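For reference, a sketch of what such a test could build. This assumes the sb_attn, sb_attn_naive, and assert_close signatures quoted elsewhere in this review; the helper name build_cu_seqlens and the packed [1, total_len, H, D] layout are assumptions, not part of the PR.

```python
import torch

def build_cu_seqlens(attention_mask: torch.Tensor) -> torch.Tensor:
    # attention_mask: [B, T] with 1 = valid token, 0 = padding
    lengths = attention_mask.int().sum(dim=1)
    cu_seqlens = torch.zeros(lengths.numel() + 1, dtype=torch.int32, device=attention_mask.device)
    cu_seqlens[1:] = torch.cumsum(lengths, dim=0)
    return cu_seqlens

# Illustrative usage: two sequences of lengths 5 and 3 packed into one [1, 8, H, D] batch,
# then compared per-sequence against the naive reference.
# cu_seqlens = build_cu_seqlens(mask)   # -> tensor([0, 5, 8], dtype=torch.int32)
# tri_o, tri_rem = sb_attn(q, k, v, inv_temp=inv_temp, cu_seqlens=cu_seqlens)
# for i in range(len(cu_seqlens) - 1):
#     s, e = cu_seqlens[i], cu_seqlens[i + 1]
#     ref_o, ref_rem = sb_attn_naive(q[:, s:e], k[:, s:e], v[:, s:e], inv_temp)
#     assert_close(" o", ref_o, tri_o[:, s:e], 0.008)
```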
fla/ops/stickbreaking_attn/naive.py (4)
9-13: Remove or use _tril_mask; currently dead code.

Apply this diff (remove it if unused):

-def _tril_mask(T: int, strict: bool = True, device=None) -> torch.Tensor:
-    i = torch.arange(T, device=device).view(1, 1, T, 1)
-    j = torch.arange(T, device=device).view(1, 1, 1, T)
-    return (j < i) if strict else (j <= i)
+
34-36: Silence Ruff: unused unpacked variables.

Prefix the unused ones with underscores.

Apply this diff:

-    B, T, H, D = q.shape
+    _B, T, _H, _D = q.shape
40-45: Fix misleading comments and simplify mask construction.

Comments invert semantics; also avoid creating an all-ones tensor first.

Apply this diff:

-    if attend_current:
-        mask = torch.ones(T, T, device=q.device).triu(1).bool()  # exclude diagonal
-    else:
-        mask = torch.ones(T, T, device=q.device).triu(0).bool()  # include diagonal
+    # attend_current=True: allow diagonal (strictly causal, incl diag)
+    # attend_current=False: exclude diagonal (strictly causal, strictly lower)
+    mask = torch.triu(torch.ones(T, T, device=q.device, dtype=torch.bool), diagonal=1 if attend_current else 0)
     mask = mask.unsqueeze(0).unsqueeze(0)  # [1, 1, T, T]
37-55: Keep math in float32; convert at outputs only.

Avoids bf16 underflow/precision loss in the reference path; improves parity with the fused kernel.

Apply this diff:

-    logits = torch.einsum('bthd,bshd->bhts', q, k) * inv_temp
-    logits = logits.float()
+    logits = torch.einsum('bthd,bshd->bhts', q, k).to(torch.float32) * inv_temp
@@
-    log_z = torch.nn.functional.logsigmoid(logits).masked_fill(mask, -1e5).to(orig_dtype)
-    log_beta = torch.nn.functional.logsigmoid(-logits).masked_fill(mask, 0).to(orig_dtype)
+    log_z = torch.nn.functional.logsigmoid(logits).masked_fill(mask, -1e5)
+    log_beta = torch.nn.functional.logsigmoid(-logits).masked_fill(mask, 0.0)
@@
-    cum_weight = torch.ones(T, T, device=q.device).tril(-1)
+    cum_weight = torch.tril(torch.ones(T, T, device=q.device, dtype=log_beta.dtype), diagonal=-1)
@@
-    re_cum_log_beta = torch.einsum("bhij,jk->bhik", log_beta, cum_weight.to(log_beta))
+    re_cum_log_beta = torch.einsum("bhij,jk->bhik", log_beta, cum_weight)
@@
-    return o.to(orig_dtype), rem.to(orig_dtype)
+    return o.to(orig_dtype), rem.to(orig_dtype)

fla/ops/stickbreaking_attn/parallel.py (2)
573-612: Add lightweight input validation and ensure contiguity before launching kernels.

Prevents subtle stride bugs and catches shape mismatches early.
Apply this diff:
- batch_size, token_size, num_heads, dim_size = q.size() + assert q.shape == k.shape == v.shape, "q, k, v must have the same shape [B, T, H, D]" + assert q.dim() == 4, "q, k, v must be rank-4 tensors [B, T, H, D]" + # Ensure expected memory layout for pointer arithmetic + if not q.is_contiguous(): q = q.contiguous() + if not k.is_contiguous(): k = k.contiguous() + if not v.is_contiguous(): v = v.contiguous() + batch_size, token_size, num_heads, dim_size = q.size()
399-401: Drop unused is_compiling arg (or use it).

Currently unused; remove to silence linters.
Apply this diff:
-@triton.jit -def softplus(x, is_compiling: tl.constexpr = False): +@triton.jit +def softplus(x): return tl.where(x < 15.0, tl.math.log2(1 + tl.math.exp2(x)), x)Note: also remove corresponding parameter at call sites if present.
fla/layers/stickbreaking_attn.py (4)
41-45: Remove ineffective kernel-availability check.

Import would already fail if sb_attn is missing; sb_attn is never None here.
Apply this diff:
- if sb_attn is None: - raise ImportError( - "StickBreakingAttention kernels are not available. Ensure Triton is installed and ops are importable." - ) + # Triton kernels are imported at module import-time; failures surface earlier.
89-93: Add stacklevel to the warning for accurate caller attribution.

Apply this diff:
- if use_cache: - warnings.warn( - "StickBreakingAttention does not support KV cache yet; falling back to use_cache=False.") + if use_cache: + warnings.warn( + "StickBreakingAttention does not support KV cache yet; falling back to use_cache=False.", + stacklevel=2 + )
96-101: q/k normalization is fine. Consider folding norms for perf (optional).

If RMSNorm is cheap enough, this is OK. If perf matters, consider fusing q_proj+k_proj with norm upstream.
103-107: Respect attention_mask by deriving cu_seqlens when not provided.

Currently attention_mask is validated but unused. Provide a fallback to varlen automatically.
Apply this diff:
- cu_seqlens = kwargs.get('cu_seqlens', None) + cu_seqlens = kwargs.get('cu_seqlens', None) + if cu_seqlens is None and attention_mask is not None: + # Build cu_seqlens from 0/1 padding mask: shape [B,T] with 1=valid + lengths = attention_mask.to(torch.int32).sum(dim=1) + cu_seqlens = torch.zeros((lengths.numel() + 1,), dtype=torch.int32, device=lengths.device) + cu_seqlens[1:] = torch.cumsum(lengths, dim=0) o, _rem = sb_attn(q, k, v, inv_temp=inv_temp, attend_current=attend_current, cu_seqlens=cu_seqlens)If this is intentional (mask not supported), consider throwing if attention_mask is provided without cu_seqlens.
fla/models/stickbreaking_attn/__init__.py (1)
11-13: Guard HF Auto registration for older Transformers (<4.33.0)

exist_ok was added in Transformers v4.33.0 (PR #25779, Aug 29 2023). Add a try/except TypeError fallback to call register() without exist_ok for older versions.
File: fla/models/stickbreaking_attn/__init__.py (lines 11-13)
-AutoConfig.register(StickBreakingAttentionConfig.model_type, StickBreakingAttentionConfig, exist_ok=True) -AutoModel.register(StickBreakingAttentionConfig, StickBreakingAttentionModel, exist_ok=True) -AutoModelForCausalLM.register(StickBreakingAttentionConfig, StickBreakingAttentionForCausalLM, exist_ok=True) +try: + AutoConfig.register(StickBreakingAttentionConfig.model_type, StickBreakingAttentionConfig, exist_ok=True) + AutoModel.register(StickBreakingAttentionConfig, StickBreakingAttentionModel, exist_ok=True) + AutoModelForCausalLM.register(StickBreakingAttentionConfig, StickBreakingAttentionForCausalLM, exist_ok=True) +except TypeError: + # Fallback for older transformers without `exist_ok` + AutoConfig.register(StickBreakingAttentionConfig.model_type, StickBreakingAttentionConfig) + AutoModel.register(StickBreakingAttentionConfig, StickBreakingAttentionModel) + AutoModelForCausalLM.register(StickBreakingAttentionConfig, StickBreakingAttentionForCausalLM)fla/models/stickbreaking_attn/modeling_stickbreaking_attn.py (7)
7-11: Import Unpack from typing (and add ClassVar) instead of transformers.processing_utils.
Unpack is a typing construct; importing it from transformers.processing_utils is incorrect and can break type checking. Also add ClassVar here for class attribute annotations below.

-from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, ClassVar, List, Optional, Tuple, Union, Unpack
-
-if TYPE_CHECKING:
-    from transformers.processing_utils import Unpack

Also applies to: 24-26
108-113: Annotate mutable class attributes with ClassVar to satisfy linters and intent.

This addresses RUF012 and clarifies that these are class-level constants.
class StickBreakingAttentionPreTrainedModel(PreTrainedModel): config_class = StickBreakingAttentionConfig - base_model_prefix = 'model' - supports_gradient_checkpointing = True - _no_split_modules = ['StickBreakingAttentionBlock'] - _supports_cache_class = True + base_model_prefix: ClassVar[str] = 'model' + supports_gradient_checkpointing: ClassVar[bool] = True + _no_split_modules: ClassVar[List[str]] = ['StickBreakingAttentionBlock'] + _supports_cache_class: ClassVar[bool] = True @@ class StickBreakingAttentionForCausalLM(StickBreakingAttentionPreTrainedModel, FLAGenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys: ClassVar[List[str]] = ["lm_head.weight"]Also applies to: 249-251
44-45: Respect config.elementwise_affine in RMSNorm instantiation.

The config exposes elementwise_affine but it’s not passed, changing intended behavior for users toggling that flag.

-        self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
+        self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(
+            config.hidden_size, elementwise_affine=config.elementwise_affine, eps=config.norm_eps
+        )
@@
-        self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
+        self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(
+            config.hidden_size, elementwise_affine=config.elementwise_affine, eps=config.norm_eps
+        )

Also applies to: 57-58
184-189: Add stacklevel to warnings.warn for accurate caller context.

Improves debuggability and addresses B028.

-        warnings.warn(
-            "`sba` does not support output attention weights now, so `output_attentions` is set to `False`."
-        )
+        warnings.warn(
+            "`sba` does not support output attention weights now, so `output_attentions` is set to `False`.",
+            stacklevel=2,
+        )
182-182: Correct return type annotation for base model forward.

The base model returns BaseModelOutputWithPast, not CausalLMOutputWithPast.

-    ) -> Union[Tuple, CausalLMOutputWithPast]:
+    ) -> Union[Tuple, BaseModelOutputWithPast]:
335-336: Use tuple unpacking instead of tuple concatenation.

Concise and matches the Ruff suggestion (RUF005).

-        if not return_dict:
-            output = (logits,) + outputs[1:]
-            return (loss,) + output if loss is not None else output
+        if not return_dict:
+            output = (logits, *outputs[1:])
+            return (loss, *output) if loss is not None else output
66-75: Tighten block forward return type annotation.

The block can return 1–3 elements depending on flags. Reflect this for clarity.

-    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+    ) -> Tuple[
+        torch.FloatTensor,
+        Optional[torch.FloatTensor],
+        Optional[Tuple[torch.FloatTensor]],
+    ]:
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (14)
fla/__init__.py (3 hunks)
fla/layers/__init__.py (2 hunks)
fla/layers/stickbreaking_attn.py (1 hunks)
fla/models/__init__.py (2 hunks)
fla/models/stickbreaking_attn/__init__.py (1 hunks)
fla/models/stickbreaking_attn/configuration_stickbreaking_attn.py (1 hunks)
fla/models/stickbreaking_attn/modeling_stickbreaking_attn.py (1 hunks)
fla/ops/__init__.py (2 hunks)
fla/ops/stickbreaking_attn/__init__.py (1 hunks)
fla/ops/stickbreaking_attn/naive.py (1 hunks)
fla/ops/stickbreaking_attn/parallel.py (1 hunks)
tests/models/test_modeling_stickbreaking_attn.py (1 hunks)
tests/models/test_modeling_utils.py (1 hunks)
tests/ops/test_stickbreaking_attn.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (11)
fla/ops/__init__.py (2)
fla/ops/stickbreaking_attn/naive.py (1)
sb_attn_naive(15-57)fla/ops/stickbreaking_attn/parallel.py (1)
sb_attn(705-713)
fla/models/stickbreaking_attn/__init__.py (2)
fla/models/stickbreaking_attn/configuration_stickbreaking_attn.py (1)
StickBreakingAttentionConfig(9-85)fla/models/stickbreaking_attn/modeling_stickbreaking_attn.py (2)
StickBreakingAttentionForCausalLM(247-344)StickBreakingAttentionModel(144-244)
fla/layers/stickbreaking_attn.py (3)
fla/modules/layernorm.py (1)
RMSNorm(1064-1111)fla/ops/stickbreaking_attn/parallel.py (2)
sb_attn(705-713)forward(682-696)fla/models/stickbreaking_attn/modeling_stickbreaking_attn.py (3)
forward(66-103)forward(171-244)forward(279-344)
tests/ops/test_stickbreaking_attn.py (3)
fla/ops/stickbreaking_attn/parallel.py (2)
sb_attn(705-713)backward(699-702)fla/ops/stickbreaking_attn/naive.py (1)
sb_attn_naive(15-57)fla/utils.py (1)
assert_close(82-93)
tests/models/test_modeling_stickbreaking_attn.py (2)
fla/models/stickbreaking_attn/configuration_stickbreaking_attn.py (1)
StickBreakingAttentionConfig(9-85)tests/models/test_modeling_base.py (2)
run_test_generation(67-126)run_test_model_forward_backward(27-61)
fla/ops/stickbreaking_attn/__init__.py (2)
fla/ops/stickbreaking_attn/naive.py (1)
sb_attn_naive(15-57)fla/ops/stickbreaking_attn/parallel.py (1)
sb_attn(705-713)
fla/__init__.py (2)
fla/layers/stickbreaking_attn.py (1)
StickBreakingAttention(25-112)fla/models/stickbreaking_attn/modeling_stickbreaking_attn.py (2)
StickBreakingAttentionForCausalLM(247-344)StickBreakingAttentionModel(144-244)
fla/layers/__init__.py (1)
fla/layers/stickbreaking_attn.py (1)
StickBreakingAttention(25-112)
fla/models/stickbreaking_attn/modeling_stickbreaking_attn.py (8)
fla/layers/stickbreaking_attn.py (2)
StickBreakingAttention(25-112)forward(72-112)fla/models/stickbreaking_attn/configuration_stickbreaking_attn.py (1)
StickBreakingAttentionConfig(9-85)fla/models/utils.py (1)
FLAGenerationMixin(385-462)fla/modules/fused_cross_entropy.py (1)
FusedCrossEntropyLoss(344-419)fla/modules/fused_linear_cross_entropy.py (1)
FusedLinearCrossEntropyLoss(493-567)fla/modules/mlp.py (1)
GatedMLP(26-69)fla/modules/layernorm.py (1)
RMSNorm(1064-1111)fla/models/modeling_layers.py (1)
GradientCheckpointingLayer(11-71)
fla/ops/stickbreaking_attn/parallel.py (2)
fla/ops/utils/index.py (1)
prepare_chunk_indices(116-121)fla/layers/stickbreaking_attn.py (1)
forward(72-112)
fla/models/__init__.py (2)
fla/models/stickbreaking_attn/configuration_stickbreaking_attn.py (1)
StickBreakingAttentionConfig(9-85)fla/models/stickbreaking_attn/modeling_stickbreaking_attn.py (2)
StickBreakingAttentionForCausalLM(247-344)StickBreakingAttentionModel(144-244)
🪛 Ruff (0.13.1)
fla/layers/stickbreaking_attn.py
42-44: Avoid specifying long messages outside the exception class
(TRY003)
77-77: Unused method argument: output_attentions
(ARG002)
90-90: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
fla/ops/stickbreaking_attn/naive.py
34-34: Unpacked variable B is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
34-34: Unpacked variable H is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
34-34: Unpacked variable D is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
fla/models/stickbreaking_attn/configuration_stickbreaking_attn.py
12-12: Mutable class attributes should be annotated with typing.ClassVar
(RUF012)
69-71: Avoid specifying long messages outside the exception class
(TRY003)
73-73: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
fla/models/stickbreaking_attn/modeling_stickbreaking_attn.py
111-111: Mutable class attributes should be annotated with typing.ClassVar
(RUF012)
184-184: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
194-194: Avoid specifying long messages outside the exception class
(TRY003)
196-196: Avoid specifying long messages outside the exception class
(TRY003)
249-249: Mutable class attributes should be annotated with typing.ClassVar
(RUF012)
335-335: Consider (logits, *outputs[1:]) instead of concatenation
Replace with (logits, *outputs[1:])
(RUF005)
336-336: Consider (loss, *output) instead of concatenation
Replace with (loss, *output)
(RUF005)
fla/ops/stickbreaking_attn/parallel.py
34-34: Unused function argument: batch_size
(ARG001)
136-136: Unused function argument: NO_N_MASK
(ARG001)
140-140: Unused function argument: no_grad
(ARG001)
142-142: Unused function argument: return_attention
(ARG001)
399-399: Unused function argument: is_compiling
(ARG001)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: Test H100 (PyTorch 2.7) / test-models
- GitHub Check: check-pt-python-compatibility
🔇 Additional comments (15)
fla/models/stickbreaking_attn/configuration_stickbreaking_attn.py (1)
68-71: Good defensive check for mutually exclusive flags.

The exclusivity guard between
fuse_cross_entropyandfuse_linear_cross_entropyis correct.fla/ops/__init__.py (1)
55-55: Exports look consistent.
'sb_attn', 'sb_attn_naive'align with new modules.tests/models/test_modeling_utils.py (1)
27-27: Correctly excluded from generation.

Adding
StickBreakingAttentionConfigtoGENERATION_UNSUPPORTEDmatches the current lack of KV cache support.fla/layers/__init__.py (2)
33-33: Expose StickBreakingAttention in layers API.

Import is correct and aligns with the public surface.
64-65: __all__ updated appropriately.

Exporting
StickBreakingAttentionis consistent.fla/ops/stickbreaking_attn/__init__.py (1)
1-9: Clean re-exports for ops.

Simple and correct; matches package structure.
fla/models/__init__.py (2)
34-38: Models API exposure looks good.

Imports for the StickBreakingAttention family match the modeling package.
70-71: __all__ includes new StickBreakingAttention symbols.

Consistent with imports.
fla/__init__.py (3)
29-31: Layers import list updated correctly.

Adding
StickBreakingAttention to the high-level API is consistent with other attention types.
79-81: Models import list updated correctly.

Matches new modeling classes.
112-113: __all__ now exports the StickBreakingAttention family.

Looks good.
tests/models/test_modeling_stickbreaking_attn.py (2)
14-24: Parametrization covers key cases.

BF16 variants and D=128 (skipped on non-Hopper) provide reasonable coverage.
40-57: Generation test will be skipped by design.

Since
StickBreakingAttentionConfigis inGENERATION_UNSUPPORTED, this test will skip. Confirm that’s intended until KV cache is implemented.fla/models/stickbreaking_attn/__init__.py (1)
16-16: LGTM: public exports are consistent with registrations.fla/ops/stickbreaking_attn/parallel.py (1)
49-66: Varlen path: confirm correctness with multi-sample cu_seqlens.

Batch is forced to 0 in IS_VARLEN. Verify that prepare_chunk_indices and CU_ptr indexing assume [B,T,...] layout and not flattened [N,...]. Add a dedicated varlen test.
I can add a small test that builds cu_seqlens for a batch with unequal lengths and validates parity against a padded dense run. Do you want me to push that?
self.num_kv_groups = self.num_heads // self.num_kv_heads
self.head_dim = self.hidden_size // self.num_heads
self.kv_dim = self.num_kv_heads * self.head_dim
self.qkv_bias = qkv_bias
Validate head divisibility and GQA/MQA support.
Avoid silent mis-sharding. Either implement KV grouping or gate it for now.
Apply this diff:
- self.num_kv_groups = self.num_heads // self.num_kv_heads
- self.head_dim = self.hidden_size // self.num_heads
+ if self.hidden_size % self.num_heads != 0:
+ raise ValueError(f"hidden_size ({self.hidden_size}) must be divisible by num_heads ({self.num_heads}).")
+ if self.num_heads % self.num_kv_heads != 0:
+ raise ValueError(f"num_heads ({self.num_heads}) must be divisible by num_kv_heads ({self.num_kv_heads}).")
+ self.num_kv_groups = self.num_heads // self.num_kv_heads
+    self.head_dim = self.hidden_size // self.num_heads

If GQA/MQA is not yet supported by the kernel, also add:
+ if self.num_kv_heads != self.num_heads:
+    raise NotImplementedError("GQA/MQA (num_kv_heads != num_heads) is not supported yet for StickBreakingAttention.")
logits = None if self.config.fuse_linear_cross_entropy else self.lm_head(hidden_states[:, -logits_to_keep:])

loss = None
if labels is not None:
    if getattr(self, 'criterion', None) is None:
        if self.config.fuse_linear_cross_entropy:
            criterion = FusedLinearCrossEntropyLoss(use_l2warp=self.config.use_l2warp)
        elif self.config.fuse_cross_entropy:
            criterion = FusedCrossEntropyLoss(inplace_backward=True)
        else:
            criterion = nn.CrossEntropyLoss()
    else:
        criterion = self.criterion
    labels = labels.to(hidden_states.device)
    labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
    if self.config.fuse_linear_cross_entropy:
        loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
    else:
        loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))
    loss = l2_warp(loss, logits) if self.config.use_l2warp else loss
Fix loss/labels alignment when logits_to_keep > 0 (shape and indexing bug).
When labels is provided and logits_to_keep > 0, logits is sliced to the last K tokens but labels remains length T after the shift. Reshaping logits to labels.numel() also silently misaligns data. This will produce incorrect training signals or runtime errors under common usages where only the last K logits are kept.
Align the labels with the kept logits and reshape based on logits’ actual size.
loss = None
if labels is not None:
if getattr(self, 'criterion', None) is None:
if self.config.fuse_linear_cross_entropy:
criterion = FusedLinearCrossEntropyLoss(use_l2warp=self.config.use_l2warp)
elif self.config.fuse_cross_entropy:
criterion = FusedCrossEntropyLoss(inplace_backward=True)
else:
criterion = nn.CrossEntropyLoss()
else:
criterion = self.criterion
- labels = labels.to(hidden_states.device)
- labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
+ labels = labels.to(hidden_states.device)
+ # Next-token prediction: shift labels left and pad last position with ignore_index
+ shifted_labels = torch.cat(
+ (labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)),
+ dim=-1,
+ )
if self.config.fuse_linear_cross_entropy:
- loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
+ # Fused linear CE expects [B, T, H] and [B, T]
+ loss = criterion(hidden_states, shifted_labels, self.lm_head.weight, self.lm_head.bias)
else:
- loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))
+ # Keep labels only where logits are computed
+ if logits_to_keep:
+ shifted_labels = shifted_labels[:, -logits_to_keep:]
+ # logits: [B, K, V], shifted_labels: [B, K]
+ loss = criterion(logits.reshape(-1, logits.size(-1)), shifted_labels.reshape(-1))
loss = l2_warp(loss, logits) if self.config.use_l2warp else loss📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| logits = None if self.config.fuse_linear_cross_entropy else self.lm_head(hidden_states[:, -logits_to_keep:]) | |
| loss = None | |
| if labels is not None: | |
| if getattr(self, 'criterion', None) is None: | |
| if self.config.fuse_linear_cross_entropy: | |
| criterion = FusedLinearCrossEntropyLoss(use_l2warp=self.config.use_l2warp) | |
| elif self.config.fuse_cross_entropy: | |
| criterion = FusedCrossEntropyLoss(inplace_backward=True) | |
| else: | |
| criterion = nn.CrossEntropyLoss() | |
| else: | |
| criterion = self.criterion | |
| labels = labels.to(hidden_states.device) | |
| labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1) | |
| if self.config.fuse_linear_cross_entropy: | |
| loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias) | |
| else: | |
| loss = criterion(logits.view(labels.numel(), -1), labels.view(-1)) | |
| loss = l2_warp(loss, logits) if self.config.use_l2warp else loss | |
| logits = None if self.config.fuse_linear_cross_entropy else self.lm_head(hidden_states[:, -logits_to_keep:]) | |
| loss = None | |
| if labels is not None: | |
| if getattr(self, 'criterion', None) is None: | |
| if self.config.fuse_linear_cross_entropy: | |
| criterion = FusedLinearCrossEntropyLoss(use_l2warp=self.config.use_l2warp) | |
| elif self.config.fuse_cross_entropy: | |
| criterion = FusedCrossEntropyLoss(inplace_backward=True) | |
| else: | |
| criterion = nn.CrossEntropyLoss() | |
| else: | |
| criterion = self.criterion | |
| labels = labels.to(hidden_states.device) | |
| # Next-token prediction: shift labels left and pad last position with ignore_index | |
| shifted_labels = torch.cat( | |
| (labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), | |
| dim=-1, | |
| ) | |
| if self.config.fuse_linear_cross_entropy: | |
| # Fused linear CE expects [B, T, H] and [B, T] | |
| loss = criterion(hidden_states, shifted_labels, self.lm_head.weight, self.lm_head.bias) | |
| else: | |
| # Keep labels only where logits are computed | |
| if logits_to_keep: | |
| shifted_labels = shifted_labels[:, -logits_to_keep:] | |
| # logits: [B, K, V], shifted_labels: [B, K] | |
| loss = criterion(logits.reshape(-1, logits.size(-1)), shifted_labels.reshape(-1)) | |
| loss = l2_warp(loss, logits) if self.config.use_l2warp else loss |
🤖 Prompt for AI Agents
In fla/models/stickbreaking_attn/modeling_stickbreaking_attn.py around lines 313
to 333, the labels are not trimmed to match logits when logits_to_keep > 0
causing misalignment and wrong loss shapes; after shifting labels (dropping
first token and appending ignore_index) slice the shifted labels to keep only
the last K positions matching logits (e.g. labels = labels[...,
-logits.size(1):]) and use logits.view(-1, logits.size(-1)) or
logits.numel()/logits.size(-1) to reshape consistently; ensure labels are on the
same device before slicing and apply the same l2_warp logic only after correct
alignment so loss uses logits.shape and labels.numel() that correspond to each
other.
a4e1a1c to
90a0fea
Compare
Actionable comments posted: 0
♻️ Duplicate comments (3)
fla/models/stickbreaking_attn/modeling_stickbreaking_attn.py (3)
47-56: Remove unused parameters from StickBreakingAttention instantiation.

The parameters window_size, max_position_embeddings, and layer_idx are passed to StickBreakingAttention but are not used by that layer (they are stored but never referenced in the forward pass). This creates unnecessary API surface and potential confusion.

Based on past review comments, these parameters should be removed from the StickBreakingAttention layer itself, and consequently from this instantiation as well.
150-150: Fix incorrect return type annotation on __init__.

The __init__ method should have a return type of None, not -> StickBreakingAttentionModel.

Apply this diff:

     def __init__(
         self,
         config: StickBreakingAttentionConfig
-    ) -> StickBreakingAttentionModel:
+    ):
314-333: CRITICAL: Fix loss/labels shape mismatch when logits_to_keep > 0.

When logits_to_keep > 0, logits are sliced to [B, K, V] (Line 314) but labels remain at full sequence length [B, T] after shifting (Line 328). Line 332 attempts logits.view(labels.numel(), -1), which creates a shape mismatch since B*K ≠ B*T, causing either runtime errors or silently incorrect loss computation.

The shifted labels must be sliced to match the logits when
logits_to_keep > 0:loss = None if labels is not None: if getattr(self, 'criterion', None) is None: if self.config.fuse_linear_cross_entropy: criterion = FusedLinearCrossEntropyLoss(use_l2warp=self.config.use_l2warp) elif self.config.fuse_cross_entropy: criterion = FusedCrossEntropyLoss(inplace_backward=True) else: criterion = nn.CrossEntropyLoss() else: criterion = self.criterion labels = labels.to(hidden_states.device) - labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1) + # Shift labels for next-token prediction + shifted_labels = torch.cat( + (labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), + dim=1 + ) if self.config.fuse_linear_cross_entropy: - loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias) + loss = criterion(hidden_states, shifted_labels, self.lm_head.weight, self.lm_head.bias) else: - loss = criterion(logits.view(labels.numel(), -1), labels.view(-1)) + # Slice labels to match logits when only keeping last K positions + if logits_to_keep > 0: + shifted_labels = shifted_labels[:, -logits_to_keep:] + loss = criterion(logits.reshape(-1, logits.size(-1)), shifted_labels.reshape(-1)) loss = l2_warp(loss, logits) if self.config.use_l2warp else loss
🧹 Nitpick comments (4)
fla/layers/stickbreaking_attn.py (2)
75-75: Remove unused output_attentions parameter.

The output_attentions parameter is accepted but never used in the forward method. Since the implementation always returns None for attentions (Line 108), this parameter should be removed from the signature to avoid confusion.

Apply this diff:
def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, - output_attentions: bool = False, use_cache: bool = False, attend_current: bool = False, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
87-90: Add stacklevel=2 to the warnings.warn call.

The warning should include stacklevel=2 so the warning message points to the caller’s location rather than this module.

Apply this diff:
if use_cache: warnings.warn( - "StickBreakingAttention does not support KV cache yet; falling back to use_cache=False.") + "StickBreakingAttention does not support KV cache yet; falling back to use_cache=False.", + stacklevel=2) use_cache = Falsefla/models/stickbreaking_attn/modeling_stickbreaking_attn.py (2)
184-188: Add stacklevel=2 to the warnings.warn call.

The warning should include stacklevel=2 so the warning message points to the caller’s location rather than this module.

Apply this diff:
if output_attentions: warnings.warn( - "`sba` does not support output attention weights now, so `output_attentions` is set to `False`." + "`sba` does not support output attention weights now, so `output_attentions` is set to `False`.", + stacklevel=2 ) output_attentions = False
112-112: Annotate mutable class attributes with ClassVar.

The _no_split_modules (Line 112) and _tied_weights_keys (Line 250) are class-level configuration attributes that should be annotated with typing.ClassVar to clarify they are not instance attributes.

Add the import and annotations:
+from typing import ClassVar + class StickBreakingAttentionPreTrainedModel(PreTrainedModel): config_class = StickBreakingAttentionConfig base_model_prefix = 'model' supports_gradient_checkpointing = True - _no_split_modules = ['StickBreakingAttentionBlock'] + _no_split_modules: ClassVar[list[str]] = ['StickBreakingAttentionBlock'] _supports_cache_class = Trueclass StickBreakingAttentionForCausalLM(StickBreakingAttentionPreTrainedModel, FLAGenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys: ClassVar[list[str]] = ["lm_head.weight"]Also applies to: 250-250
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
fla/__init__.py (3 hunks)
fla/layers/stickbreaking_attn.py (1 hunks)
fla/models/stickbreaking_attn/modeling_stickbreaking_attn.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
fla/__init__.py (2)
fla/layers/stickbreaking_attn.py (1)
StickBreakingAttention(25-108)fla/models/stickbreaking_attn/modeling_stickbreaking_attn.py (2)
StickBreakingAttentionForCausalLM(248-345)StickBreakingAttentionModel(145-245)
fla/layers/stickbreaking_attn.py (3)
fla/modules/layernorm.py (1)
RMSNorm(1063-1110)fla/ops/stickbreaking_attn/parallel.py (2)
sb_attn(705-713)forward(682-696)fla/models/stickbreaking_attn/modeling_stickbreaking_attn.py (3)
forward(67-104)forward(172-245)forward(280-345)
fla/models/stickbreaking_attn/modeling_stickbreaking_attn.py (8)
fla/layers/stickbreaking_attn.py (2)
StickBreakingAttention(25-108)forward(70-108)fla/models/stickbreaking_attn/configuration_stickbreaking_attn.py (1)
StickBreakingAttentionConfig(9-85)fla/models/utils.py (1)
FLAGenerationMixin(384-461)fla/modules/fused_cross_entropy.py (1)
FusedCrossEntropyLoss(343-418)fla/modules/fused_linear_cross_entropy.py (1)
FusedLinearCrossEntropyLoss(491-565)fla/modules/mlp.py (1)
GatedMLP(25-68)fla/modules/layernorm.py (1)
RMSNorm(1063-1110)fla/models/modeling_layers.py (1)
GradientCheckpointingLayer(10-70)
🪛 Ruff (0.14.3)
fla/layers/stickbreaking_attn.py
41-43: Avoid specifying long messages outside the exception class
(TRY003)
75-75: Unused method argument: output_attentions
(ARG002)
88-88: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
fla/models/stickbreaking_attn/modeling_stickbreaking_attn.py
112-112: Mutable class attributes should be annotated with typing.ClassVar
(RUF012)
185-185: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
195-195: Avoid specifying long messages outside the exception class
(TRY003)
197-197: Avoid specifying long messages outside the exception class
(TRY003)
250-250: Mutable class attributes should be annotated with typing.ClassVar
(RUF012)
336-336: Consider (logits, *outputs[1:]) instead of concatenation
Replace with (logits, *outputs[1:])
(RUF005)
337-337: Consider (loss, *output) instead of concatenation
Replace with (loss, *output)
(RUF005)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Test H100 (PyTorch 2.7) / test-models
🔇 Additional comments (1)
fla/__init__.py (1)
29-29: LGTM! Clean API surface additions.

The new StickBreakingAttention exports follow the established pattern and are correctly integrated into both the imports and the __all__ list.

Also applies to: 78-79, 111-111
Actionable comments posted: 2
♻️ Duplicate comments (2)
fla/layers/stickbreaking_attn.py (1)
44-53: Validate head divisibility and unsupported kv-head ratios
StickBreakingAttention currently accepts any hidden_size, num_heads, and num_kv_heads, but the implementation relies on exact divisibility and on the kernels supporting only the num_kv_heads == num_heads case. If the user supplies a configuration where hidden_size % num_heads != 0, einops.rearrange will crash with a cryptic runtime error. Likewise, when num_heads % num_kv_heads != 0 or simply num_kv_heads != num_heads, the naive and fused kernels receive tensors with incompatible head dimensions and immediately fail. We should fail fast with a clear error (or explicitly gate unsupported GQA/MQA) instead of letting these misconfigurations escape.

Suggestion:
if num_kv_heads is None: self.num_kv_heads = self.num_heads else: self.num_kv_heads = num_kv_heads - self.num_kv_groups = self.num_heads // self.num_kv_heads - self.head_dim = self.hidden_size // self.num_heads - self.kv_dim = self.num_kv_heads * self.head_dim + + if self.hidden_size % self.num_heads != 0: + raise ValueError( + f'hidden_size ({self.hidden_size}) must be divisible by num_heads ({self.num_heads}).' + ) + if self.num_heads % self.num_kv_heads != 0: + raise ValueError( + f'num_heads ({self.num_heads}) must be divisible by num_kv_heads ({self.num_kv_heads}).' + ) + if self.num_kv_heads != self.num_heads: + raise NotImplementedError( + 'GQA/MQA (num_kv_heads != num_heads) is not yet supported for StickBreakingAttention.' + ) + + self.num_kv_groups = self.num_heads // self.num_kv_heads + self.head_dim = self.hidden_size // self.num_heads + self.kv_dim = self.num_kv_heads * self.head_dimfla/models/stickbreaking_attn/modeling_stickbreaking_attn.py (1)
309-327: Fix logits_to_keep slicing and label alignment

When logits_to_keep > 0, we slice the hidden states before the head but keep the labels at full length. The subsequent logits.view(labels.numel(), -1) either crashes (size mismatch) or, in the fused path, we still compute loss on the full sequence instead of the last K positions. This breaks training for anyone using partial logits.

We need to slice both the hidden states and labels to the same window before calling the criterion:
- logits = None if self.config.fuse_linear_cross_entropy else self.lm_head(hidden_states[:, -logits_to_keep:]) + if logits_to_keep: + hidden_states_to_score = hidden_states[:, -logits_to_keep:, :] + else: + hidden_states_to_score = hidden_states + + logits = None if self.config.fuse_linear_cross_entropy else self.lm_head(hidden_states_to_score) @@ - labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1) - if self.config.fuse_linear_cross_entropy: - loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias) + labels = torch.cat( + (labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), + dim=-1, + ) + if logits_to_keep: + labels = labels[:, -logits_to_keep:] + if self.config.fuse_linear_cross_entropy: + loss = criterion(hidden_states_to_score, labels, self.lm_head.weight, self.lm_head.bias) else: - loss = criterion(logits.view(labels.numel(), -1), labels.view(-1)) + loss = criterion(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (10)
fla/layers/stickbreaking_attn.py (1 hunks)
fla/models/__init__.py (2 hunks)
fla/models/stickbreaking_attn/__init__.py (1 hunks)
fla/models/stickbreaking_attn/configuration_stickbreaking_attn.py (1 hunks)
fla/models/stickbreaking_attn/modeling_stickbreaking_attn.py (1 hunks)
fla/ops/stickbreaking_attn/__init__.py (1 hunks)
fla/ops/stickbreaking_attn/naive.py (1 hunks)
fla/ops/stickbreaking_attn/parallel.py (1 hunks)
tests/models/test_modeling_stickbreaking_attn.py (1 hunks)
tests/ops/test_stickbreaking_attn.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (5)
- fla/ops/stickbreaking_attn/parallel.py
- fla/models/stickbreaking_attn/configuration_stickbreaking_attn.py
- fla/models/__init__.py
- tests/models/test_modeling_stickbreaking_attn.py
- fla/ops/stickbreaking_attn/__init__.py
if attention_mask is not None:
    assert len(attention_mask.shape) == 2, (
        "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
        "for padding purposes (0 indicating padding). "
        "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
    )

if use_cache:
    warnings.warn(
        "StickBreakingAttention does not support KV cache yet; falling back to use_cache=False.")
Do not silently ignore attention_mask
We assert that an attention_mask is 2-D, but we never use it afterward. When padding is present in the batch, the kernel still allocates stick mass to those padded positions, so the outputs (and gradients) depend on garbage tokens. That violates the huggingface contract for attention layers and breaks training/inference with padded batches.
Until the kernel can consume the mask (e.g. by translating it to cu_seqlens or masking logits), we should not silently continue when zeros are present. Please either wire up proper masking or explicitly fail to avoid corrupt results. For example:
if attention_mask is not None:
assert len(attention_mask.shape) == 2, (
"Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
"for padding purposes (0 indicating padding). "
"Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
)
+ valid = attention_mask.to(dtype=torch.bool)
+ if not torch.all(valid):
+ raise NotImplementedError(
+ "Padding masks are not yet supported. Please supply packed sequences via `cu_seqlens` "
+ "or implement masking before calling StickBreakingAttention."
🤖 Prompt for AI Agents
In fla/layers/stickbreaking_attn.py around lines 79-88, the code asserts
attention_mask is 2-D but never uses it so padded positions still receive stick
mass; either wire the mask into the kernel (e.g., convert [batch, seq_len] mask
into cu_seqlens or apply it to attention logits before the stick-breaking
kernel) or explicitly fail when any padding is present. Implement one of two
fixes: (1) translate the 0/1 attention_mask to the kernel-friendly format and
ensure masked positions are excluded from stick mass allocation, or (2) check
for any zeros in attention_mask and raise a clear ValueError (or
NotImplementedError) stating that StickBreakingAttention does not support padded
masks yet. Ensure the failure case includes guidance to use a mask-free batch or
a compatible attention implementation.
use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

if input_ids is not None and inputs_embeds is not None:
    raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is None and inputs_embeds is None:
    raise ValueError("You have to specify either input_ids or inputs_embeds")

if use_cache and not isinstance(past_key_values, Cache):
    past_key_values = Cache.from_legacy_cache(past_key_values)

if inputs_embeds is None:
    inputs_embeds = self.embeddings(input_ids)

hidden_states = inputs_embeds

all_hidden_states = () if output_hidden_states else None
all_attns = () if output_attentions else None
next_cache = None
Disable the KV cache until it is implemented
Higher-level generation will call this model with use_cache=True and only feed the newly generated token plus past_key_values. However, StickBreakingAttention immediately flips use_cache to False and just returns the incoming cache object unchanged. On the next decoding step we therefore attend over an empty history, producing incorrect logits. Either implement cache population or explicitly force use_cache=False here and drop the cache from the outputs so downstream code falls back to full-prefix recomputation.
-    use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
+    use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
+    if use_cache:
+        warnings.warn(
+            "StickBreakingAttention does not support KV cache yet; forcing use_cache=False.",
+            stacklevel=2,
+        )
+        use_cache = False
+        past_key_values = None

🤖 Prompt for AI Agents
In fla/models/stickbreaking_attn/modeling_stickbreaking_attn.py around lines 186
to 205, the KV cache handling is incorrect for generation: the model accepts
use_cache=True but immediately disables it in the attention layer and returns
the incoming cache unchanged, which causes future decoding steps to see an empty
history. Disable the KV cache end-to-end until it is implemented by forcing
use_cache = False here (ignore any incoming past_key_values), ensure
past_key_values is cleared/ignored (set to None or an empty Cache) before any
attention ops, and do not return a next_cache in the outputs (leave next_cache
as None) so downstream generation falls back to full-prefix recomputation.
Actionable comments posted: 1
♻️ Duplicate comments (2)
fla/layers/stickbreaking_attn.py (2)
49-53: Validate head divisibility and gate unsupported GQA
We're still doing integer floor division without guarding `hidden_size % num_heads` or `num_heads % num_kv_heads`. That yields silent shape corruption (and `num_kv_groups == 0` if heads < kv_heads) before the first forward pass. We need explicit `ValueError`s for those cases, and if the kernels can't handle grouped-KV yet, raise `NotImplementedError` when `num_kv_heads != num_heads`.

84-105: Don't ignore padding masks
We only assert the mask is 2-D, then discard it. Any padded batch (zeros present) will feed garbage tokens into the kernel, so outputs and gradients depend on padding noise. Either translate the 0/1 mask into `cu_seqlens` (or an explicit mask inside the kernel) or raise `NotImplementedError` until padding is supported. Silent failure here breaks HF interoperability; a sketch of the mask-to-`cu_seqlens` conversion follows below.
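A minimal sketch of the first option, assuming a standard 0-1 padding mask; the helper name `mask_to_varlen` and the packing step are illustrative, not part of the existing API:

```python
import torch

def mask_to_varlen(x: torch.Tensor, attention_mask: torch.Tensor):
    # x: [batch, seq_len, hidden]; attention_mask: [batch, seq_len], 1 = keep, 0 = pad
    lengths = attention_mask.sum(dim=-1, dtype=torch.int32)
    # cumulative boundaries [0, len_0, len_0 + len_1, ...] in the layout varlen kernels expect
    cu_seqlens = torch.nn.functional.pad(lengths.cumsum(0, dtype=torch.int32), (1, 0))
    # pack only the valid tokens into a single [1, total_tokens, hidden] tensor
    packed = x[attention_mask.bool()].unsqueeze(0)
    return packed, cu_seqlens
```

The outputs would then need to be scattered back into the padded layout after the attention call.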
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
- fla/layers/stickbreaking_attn.py (1 hunks)
- fla/ops/__init__.py (3 hunks)
- fla/ops/stickbreaking_attn/__init__.py (1 hunks)
- fla/ops/stickbreaking_attn/naive.py (1 hunks)
- fla/ops/stickbreaking_attn/parallel.py (1 hunks)
- tests/ops/test_stickbreaking_attn.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
- tests/ops/test_stickbreaking_attn.py
- fla/ops/__init__.py
🧰 Additional context used
🧬 Code graph analysis (4)
fla/ops/stickbreaking_attn/parallel.py (2)
- fla/ops/utils/index.py (1): prepare_chunk_indices (114-119)
- fla/layers/stickbreaking_attn.py (1): forward (68-110)

fla/layers/stickbreaking_attn.py (3)
- fla/modules/layernorm.py (1): RMSNorm (1063-1110)
- fla/ops/stickbreaking_attn/parallel.py (2): parallel_stickbreaking_attn (703-713), forward (680-694)
- fla/models/stickbreaking_attn/modeling_stickbreaking_attn.py (3): forward (62-99), forward (167-240), forward (275-340)

fla/ops/stickbreaking_attn/__init__.py (2)
- fla/ops/stickbreaking_attn/naive.py (1): naive_stickbreaking_attn (7-51)
- fla/ops/stickbreaking_attn/parallel.py (1): parallel_stickbreaking_attn (703-713)

fla/ops/stickbreaking_attn/naive.py (1)
- fla/modules/activations.py (1): logsigmoid (200-201)
🪛 Ruff (0.14.3)
fla/ops/stickbreaking_attn/parallel.py
32-32: Unused function argument: batch_size
(ARG001)
134-134: Unused function argument: NO_N_MASK
(ARG001)
138-138: Unused function argument: no_grad
(ARG001)
140-140: Unused function argument: return_attention
(ARG001)
397-397: Unused function argument: is_compiling
(ARG001)
fla/layers/stickbreaking_attn.py
39-41: Avoid specifying long messages outside the exception class
(TRY003)
73-73: Unused method argument: output_attentions
(ARG002)
86-86: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
fla/ops/stickbreaking_attn/__init__.py
4-7: __all__ is not sorted
Apply an isort-style sorting to __all__
(RUF022)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Test H100 (PyTorch 2.7) / test-models
Actionable comments posted: 13
♻️ Duplicate comments (4)
fla/ops/stickbreaking_attn/parallel.py (2)
52-52: Remove the no-op call to `tl.num_programs(2)`. The return value is not used, so this call has no effect and should be removed.
Apply this diff:
    batch_id = 0 if IS_VARLEN else tl.program_id(0)
    head_pid = tl.program_id(1)
    prog_id = tl.program_id(2)
-   tl.num_programs(2)
    if IS_VARLEN:
632-633: CRITICAL: Backward buffer allocation causes massive memory usage.
The `dk` and `dv` tensors are allocated with shape `(M_count, batch_size, token_size, num_heads, dim_size)`, where `M_count = triton.cdiv(token_size, BLOCK_M)`. For the advertised benchmark config (batch=8, seq=4096, heads=32, dim=64, BLOCK_M=64), this allocates:
- M_count = 64 blocks
- Memory per tensor ≈ 64 × 8 × 4096 × 32 × 64 × 2 bytes ≈ 8.6 GB in fp16
This is roughly 64× larger than the actual KV tensors and will OOM before the backward kernel runs, blocking any training at the advertised scale.
Solution: Accumulate gradients directly into `[batch_size, token_size, num_heads, dim_size]` buffers using atomic adds, block-level reductions, or a two-pass strategy with a shared staging buffer per program. Memory must scale linearly with model size.
Do you want me to generate a fix using atomic operations?
fla/layers/stickbreaking_attn.py (2)
43-52: Validate head/count divisibility and gate unsupported GQA. We still allow configurations where
hidden_sizeis not divisible bynum_heads, ornum_headsis not divisible bynum_kv_heads. Worse, the kernel does not implement grouped KV, so mismatched head counts silently mis-shard tensors. Please add the explicit guards and reject unsupported GQA/MQA setups.self.hidden_size = hidden_size self.num_heads = num_heads if num_kv_heads is None: self.num_kv_heads = self.num_heads else: self.num_kv_heads = num_kv_heads - self.num_kv_groups = self.num_heads // self.num_kv_heads - self.head_dim = self.hidden_size // self.num_heads + if self.hidden_size % self.num_heads != 0: + raise ValueError( + f"hidden_size ({self.hidden_size}) must be divisible by num_heads ({self.num_heads})." + ) + if self.num_heads % self.num_kv_heads != 0: + raise ValueError( + f"num_heads ({self.num_heads}) must be divisible by num_kv_heads ({self.num_kv_heads})." + ) + if self.num_kv_heads != self.num_heads: + raise NotImplementedError( + "StickBreakingAttention does not support grouped KV (num_kv_heads != num_heads) yet." + ) + self.num_kv_groups = self.num_heads // self.num_kv_heads + self.head_dim = self.hidden_size // self.num_heads
77-104: Do not ignoreattention_maskShape-checking and then dropping the mask leaves padded tokens in play, so outputs/gradients depend on garbage data. Either wire the mask into cu_seqlens/logits or clearly reject masked batches to avoid corrupt results.
if attention_mask is not None: assert len(attention_mask.shape) == 2, ( "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] " "for padding purposes (0 indicating padding). " "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed." ) + if not torch.all(attention_mask.to(torch.bool)): + raise NotImplementedError( + "Padding masks are not yet supported. Pack sequences via `cu_seqlens` before calling " + "StickBreakingAttention." + )
🧹 Nitpick comments (6)
sba_code/stickbreaking_attention/sb_varlen/softplus.py (2)
1-1: Remove unused import.
torchis imported but never used in this file.Apply this diff:
-import torch import triton from triton import language as tl
35-52: Consider consolidating duplicate softplus implementations.This file implements softplus with inline assembly, while
fla/ops/stickbreaking_attn/parallel.py(lines 390-392) has a simpler implementation:tl.where(x < 15.0, tl.math.log2(1 + tl.math.exp2(x)), x). Having two implementations makes maintenance harder and can lead to inconsistencies.Can you clarify whether the inline-assembly path provides measurable performance benefits over the simpler implementation? If not, consider using the simpler version consistently across the codebase.
fla/ops/stickbreaking_attn/parallel.py (1)
390-392: Remove the unusedis_compilingparameter.The
is_compilingparameter is declared but never used in the function body.Apply this diff:
@triton.jit -def softplus(x, is_compiling: tl.constexpr = False): +def softplus(x): return tl.where(x < 15.0, tl.math.log2(1 + tl.math.exp2(x)), x)sba_code/tests/test_varlen.py (2)
4-4: Remove unused import.
torch.nn.functional as Fis imported but never used in the test file.Apply this diff:
import torch import pytest import math -from torch.nn import functional as F from stickbreaking_attention.sb_varlen import sb_attn_varlen
16-16: Consider addingstrict=Trueto zip for safer iteration.For Python 3.10+, using
strict=Trueensures all iterables have the same length and raises an error if they don't, preventing silent bugs.Apply this diff:
- for q_chunk, k_chunk, v_chunk in zip(q.split(splits, 1), k.split(splits, 1), v.split(splits, 1)): + for q_chunk, k_chunk, v_chunk in zip(q.split(splits, 1), k.split(splits, 1), v.split(splits, 1), strict=True):sba_code/stickbreaking_attention/sb_attn/__init__.py (1)
57-60: Remove or document the unusedzero_startparameter.The
zero_startparameter is accepted but never used in the function. Either remove it or document why it's present (e.g., for API compatibility).If it's for compatibility with
sb_attn_varlen, consider adding a docstring explaining this.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (23)
- fla/layers/stickbreaking_attn.py (1 hunks)
- fla/ops/stickbreaking_attn/naive.py (1 hunks)
- fla/ops/stickbreaking_attn/parallel.py (1 hunks)
- sba_code/.gitignore (1 hunks)
- sba_code/LICENSE (1 hunks)
- sba_code/README.md (1 hunks)
- sba_code/benchmarks/attn.py (1 hunks)
- sba_code/benchmarks/varlen.py (1 hunks)
- sba_code/load_model_with_dolomite_demo.py (1 hunks)
- sba_code/setup.py (1 hunks)
- sba_code/stickbreaking_attention/__init__.py (1 hunks)
- sba_code/stickbreaking_attention/sb_attn/__init__.py (1 hunks)
- sba_code/stickbreaking_attention/sb_attn/sb_bwd.py (1 hunks)
- sba_code/stickbreaking_attention/sb_attn/sb_fwd.py (1 hunks)
- sba_code/stickbreaking_attention/sb_ref.py (1 hunks)
- sba_code/stickbreaking_attention/sb_varlen/__init__.py (1 hunks)
- sba_code/stickbreaking_attention/sb_varlen/sb_varlen_bwd.py (1 hunks)
- sba_code/stickbreaking_attention/sb_varlen/sb_varlen_fwd.py (1 hunks)
- sba_code/stickbreaking_attention/sb_varlen/softplus.py (1 hunks)
- sba_code/stickbreaking_attention/utils.py (1 hunks)
- sba_code/tests/test_attn.py (1 hunks)
- sba_code/tests/test_varlen.py (1 hunks)
- tests/ops/test_stickbreaking_attn.py (1 hunks)
✅ Files skipped from review due to trivial changes (1)
- sba_code/LICENSE
🚧 Files skipped from review as they are similar to previous changes (2)
- tests/ops/test_stickbreaking_attn.py
- fla/ops/stickbreaking_attn/naive.py
🧰 Additional context used
🧬 Code graph analysis (15)
sba_code/stickbreaking_attention/sb_attn/__init__.py (4)
sba_code/stickbreaking_attention/sb_attn/sb_bwd.py (1)
_bwd(163-205)sba_code/stickbreaking_attention/sb_attn/sb_fwd.py (1)
_fwd(139-174)sba_code/stickbreaking_attention/sb_varlen/__init__.py (3)
StickBreakingAttention(24-68)forward(27-46)backward(49-68)fla/ops/stickbreaking_attn/parallel.py (2)
forward(681-693)backward(696-699)
sba_code/stickbreaking_attention/sb_ref.py (1)
fla/modules/activations.py (1)
logsigmoid(200-201)
fla/ops/stickbreaking_attn/parallel.py (2)
fla/ops/utils/index.py (1)
prepare_chunk_indices(114-119)fla/layers/stickbreaking_attn.py (1)
forward(68-108)
sba_code/benchmarks/varlen.py (5)
sba_code/stickbreaking_attention/sb_varlen/__init__.py (1)
sb_attn_varlen(71-78)fla/modules/rotary.py (1)
rotate_half(16-22)sba_code/stickbreaking_attention/sb_ref.py (1)
stickbreaking(8-25)sba_code/tests/test_varlen.py (1)
ref_fwd(10-27)sba_code/benchmarks/attn.py (2)
tri_fwdbwd(13-20)flash_fwdbwd(22-30)
sba_code/stickbreaking_attention/sb_varlen/softplus.py (1)
fla/ops/stickbreaking_attn/parallel.py (1)
softplus(391-392)
sba_code/stickbreaking_attention/sb_varlen/sb_varlen_fwd.py (3)
fla/ops/stickbreaking_attn/parallel.py (4)
softplus(391-392)load_kv(331-343)compute_block(347-387)backward(696-699)sba_code/stickbreaking_attention/sb_varlen/softplus.py (1)
softplus(36-52)sba_code/stickbreaking_attention/utils.py (1)
custom_op(20-39)
sba_code/stickbreaking_attention/sb_varlen/__init__.py (2)
sba_code/stickbreaking_attention/sb_varlen/sb_varlen_fwd.py (1)
varlen_fwd(413-451)sba_code/stickbreaking_attention/sb_varlen/sb_varlen_bwd.py (1)
varlen_bwd(484-539)
sba_code/tests/test_attn.py (3)
sba_code/stickbreaking_attention/sb_attn/__init__.py (1)
sb_attn(57-60)sba_code/stickbreaking_attention/sb_ref.py (1)
stickbreaking(8-25)sba_code/tests/test_varlen.py (3)
test_varlen(75-110)assert_close(43-58)ref_fwd(10-27)
sba_code/benchmarks/attn.py (3)
sba_code/stickbreaking_attention/sb_attn/__init__.py (1)
sb_attn(57-60)fla/modules/rotary.py (1)
rotate_half(16-22)sba_code/benchmarks/varlen.py (3)
tri_fwdbwd(42-53)flash_fwdbwd(55-72)fun_(119-121)
sba_code/stickbreaking_attention/__init__.py (2)
sba_code/stickbreaking_attention/sb_attn/__init__.py (1)
sb_attn(57-60)sba_code/stickbreaking_attention/sb_varlen/__init__.py (1)
sb_attn_varlen(71-78)
sba_code/tests/test_varlen.py (2)
sba_code/stickbreaking_attention/sb_varlen/__init__.py (2)
sb_attn_varlen(71-78)backward(49-68)sba_code/stickbreaking_attention/sb_ref.py (1)
stickbreaking(8-25)
sba_code/stickbreaking_attention/sb_attn/sb_fwd.py (3)
sba_code/stickbreaking_attention/utils.py (1)
custom_op(20-39)sba_code/stickbreaking_attention/sb_varlen/sb_varlen_fwd.py (2)
_forward_one_row(82-215)_forward(233-410)sba_code/stickbreaking_attention/sb_varlen/softplus.py (1)
softplus(36-52)
sba_code/stickbreaking_attention/sb_varlen/sb_varlen_bwd.py (2)
sba_code/stickbreaking_attention/sb_varlen/sb_varlen_fwd.py (2)
compute_block(27-78)load_kv(11-23)sba_code/stickbreaking_attention/utils.py (1)
custom_op(20-39)
sba_code/stickbreaking_attention/sb_attn/sb_bwd.py (3)
sba_code/stickbreaking_attention/utils.py (1)
custom_op(20-39)sba_code/stickbreaking_attention/sb_varlen/sb_varlen_bwd.py (2)
_backward_one_row(298-481)_backward(92-294)sba_code/stickbreaking_attention/sb_varlen/sb_varlen_fwd.py (2)
compute_block(27-78)load_kv(11-23)
fla/layers/stickbreaking_attn.py (3)
fla/modules/layernorm.py (1)
RMSNorm(1063-1110)fla/ops/stickbreaking_attn/parallel.py (2)
parallel_stickbreaking_attn(702-711)forward(681-693)fla/models/stickbreaking_attn/modeling_stickbreaking_attn.py (3)
forward(62-99)forward(167-240)forward(275-340)
🪛 Flake8 (7.3.0)
sba_code/load_model_with_dolomite_demo.py
[error] 5-5: redefinition of unused 'hf_models' from line 2
(F811)
sba_code/benchmarks/varlen.py
[error] 2-2: 'pytest' imported but unused
(F401)
[error] 8-8: 'transformers.models.llama.modeling_llama.apply_rotary_pos_emb' imported but unused
(F401)
[error] 89-89: f-string is missing placeholders
(F541)
[error] 111-111: do not assign a lambda expression, use a def
(E731)
[error] 113-113: do not assign a lambda expression, use a def
(E731)
[error] 117-117: do not assign a lambda expression, use a def
(E731)
sba_code/stickbreaking_attention/sb_varlen/softplus.py
[error] 1-1: 'torch' imported but unused
(F401)
sba_code/stickbreaking_attention/sb_varlen/sb_varlen_fwd.py
[error] 224-224: unexpected indentation (comment)
(E116)
[error] 225-225: unexpected indentation (comment)
(E116)
[error] 226-226: unexpected indentation (comment)
(E116)
[error] 227-227: unexpected indentation (comment)
(E116)
sba_code/stickbreaking_attention/sb_varlen/__init__.py
[error] 7-7: 'torch.nn.functional as F' imported but unused
(F401)
sba_code/benchmarks/attn.py
[error] 2-2: 'pytest' imported but unused
(F401)
[error] 4-4: 'torch.nn.functional as F' imported but unused
(F401)
[error] 9-9: 'transformers.models.llama.modeling_llama.apply_rotary_pos_emb' imported but unused
(F401)
[error] 57-57: f-string is missing placeholders
(F541)
[error] 76-76: do not assign a lambda expression, use a def
(E731)
[error] 79-79: do not assign a lambda expression, use a def
(E731)
[error] 82-82: do not assign a lambda expression, use a def
(E731)
sba_code/stickbreaking_attention/__init__.py
[error] 1-1: '.sb_attn.sb_attn' imported but unused
(F401)
[error] 2-2: '.sb_varlen.sb_attn_varlen' imported but unused
(F401)
sba_code/tests/test_varlen.py
[error] 4-4: 'torch.nn.functional as F' imported but unused
(F401)
sba_code/stickbreaking_attention/sb_attn/sb_fwd.py
[error] 7-7: '..sb_varlen.softplus.softplus' imported but unused
(F401)
sba_code/stickbreaking_attention/sb_varlen/sb_varlen_bwd.py
[error] 14-14: missing whitespace around parameter equals
(E252)
[error] 14-14: missing whitespace around parameter equals
(E252)
[error] 43-43: expected an indented block (comment)
(E115)
[error] 44-44: indentation is not a multiple of 4
(E111)
[error] 44-44: over-indented
(E117)
[error] 45-45: indentation is not a multiple of 4
(E111)
[error] 46-46: indentation is not a multiple of 4
(E111)
[error] 47-47: indentation is not a multiple of 4
(E111)
[error] 48-48: indentation is not a multiple of 4
(E111)
[error] 49-49: indentation is not a multiple of 4
(E111)
[error] 50-50: indentation is not a multiple of 4
(E111)
[error] 51-51: indentation is not a multiple of 4
(E111)
[error] 58-58: continuation line under-indented for visual indent
(E128)
[error] 58-58: missing whitespace around parameter equals
(E252)
[error] 58-58: missing whitespace around parameter equals
(E252)
[error] 502-502: local variable 'N_count' is assigned to but never used
(F841)
sba_code/stickbreaking_attention/sb_attn/sb_bwd.py
[error] 7-7: '..sb_varlen.sb_varlen_fwd.compute_block' imported but unused
(F401)
[error] 7-7: '..sb_varlen.sb_varlen_fwd.load_kv' imported but unused
(F401)
🪛 GitHub Actions: lint
sba_code/benchmarks/varlen.py
[error] 110-110: E731 Do not assign a lambda expression, use a def. (pre-commit: ruff-check)
[error] 112-112: E731 Do not assign a lambda expression, use a def. (pre-commit: ruff-check)
[error] 116-116: E731 Do not assign a lambda expression, use a def. (pre-commit: ruff-check)
sba_code/setup.py
[error] 1-1: End-of-file fixer modified file (end-of-file-fixer).
sba_code/tests/test_attn.py
[error] 1-1: Trailing whitespace found and removed by pre-commit (trailing-whitespace).
sba_code/benchmarks/attn.py
[error] 75-75: E731 Do not assign a lambda expression, use a def. (pre-commit: ruff-check)
[error] 78-78: E731 Do not assign a lambda expression, use a def. (pre-commit: ruff-check)
[error] 81-81: E731 Do not assign a lambda expression, use a def. (pre-commit: ruff-check)
sba_code/tests/test_varlen.py
[error] 1-1: Trailing whitespace found and removed by pre-commit (trailing-whitespace).
sba_code/stickbreaking_attention/sb_varlen/sb_varlen_bwd.py
[error] 501-501: F841 Local variable N_count is assigned to but never used. (pre-commit: ruff-check)
[error] 1-1: Trailing whitespace found and removed by pre-commit (trailing-whitespace).
[error] 1-1: End-of-file fixer modified file (end-of-file-fixer).
sba_code/README.md
[error] 1-1: Trailing whitespace found and removed by pre-commit (trailing-whitespace).
🪛 markdownlint-cli2 (0.18.1)
sba_code/README.md
4-4: Link text should be descriptive
(MD059, descriptive-link-text)
🪛 Ruff (0.14.3)
sba_code/stickbreaking_attention/sb_attn/__init__.py
57-57: Unused function argument: zero_start
(ARG001)
fla/ops/stickbreaking_attn/parallel.py
31-31: Unused function argument: batch_size
(ARG001)
132-132: Unused function argument: NO_N_MASK
(ARG001)
136-136: Unused function argument: no_grad
(ARG001)
138-138: Unused function argument: return_attention
(ARG001)
391-391: Unused function argument: is_compiling
(ARG001)
sba_code/benchmarks/varlen.py
24-24: zip() without an explicit strict= parameter
Add explicit value for parameter strict=
(B905)
26-26: Unpacked variable rem is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
37-37: Unused function argument: do
(ARG001)
42-42: Unused function argument: do
(ARG001)
47-47: Unpacked variable rem is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
55-55: Unused function argument: do
(ARG001)
89-89: f-string without any placeholders
Remove extraneous f prefix
(F541)
111-111: Do not assign a lambda expression, use a def
Rewrite fun as a def
(E731)
113-113: Do not assign a lambda expression, use a def
Rewrite fun as a def
(E731)
117-117: Do not assign a lambda expression, use a def
Rewrite fun as a def
(E731)
121-121: Unpacked variable dq is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
121-121: Unpacked variable dk is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
121-121: Unpacked variable dv is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
sba_code/stickbreaking_attention/utils.py
21-21: PEP 484 prohibits implicit Optional
Convert to T | None
(RUF013)
22-22: PEP 484 prohibits implicit Optional
Convert to T | None
(RUF013)
sba_code/stickbreaking_attention/sb_varlen/sb_varlen_fwd.py
113-113: Unused function argument: NO_N_MASK
(ARG001)
117-117: Unused function argument: no_grad
(ARG001)
262-262: Unused function argument: batch_size
(ARG001)
263-263: Unused function argument: token_size
(ARG001)
sba_code/benchmarks/attn.py
13-13: Unused function argument: do
(ARG001)
17-17: Unpacked variable rem is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
22-22: Unused function argument: do
(ARG001)
32-32: Unused function argument: do
(ARG001)
57-57: f-string without any placeholders
Remove extraneous f prefix
(F541)
76-76: Do not assign a lambda expression, use a def
Rewrite fun as a def
(E731)
79-79: Do not assign a lambda expression, use a def
Rewrite fun as a def
(E731)
82-82: Do not assign a lambda expression, use a def
Rewrite fun as a def
(E731)
87-87: Unpacked variable dq is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
87-87: Unpacked variable dk is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
87-87: Unpacked variable dv is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
sba_code/tests/test_varlen.py
16-16: zip() without an explicit strict= parameter
Add explicit value for parameter strict=
(B905)
sba_code/stickbreaking_attention/sb_attn/sb_fwd.py
53-53: Unused function argument: batch_size
(ARG001)
56-56: Unused function argument: num_heads
(ARG001)
sba_code/stickbreaking_attention/sb_varlen/sb_varlen_bwd.py
14-14: Do not perform function call tl.constexpr in argument defaults; instead, perform the call within the function, or read the default from a module-level singleton variable
(B008)
57-57: Unused function argument: Lock_ptr
(ARG001)
57-57: Unused function argument: Count_ptr
(ARG001)
58-58: Unused function argument: EVICTION_POLICY
(ARG001)
133-133: Unused function argument: batch_size
(ARG001)
134-134: Unused function argument: token_size
(ARG001)
138-138: Unused function argument: BLOCK_CSL
(ARG001)
141-141: Unused function argument: NO_N_MASK
(ARG001)
502-502: Local variable N_count is assigned to but never used
Remove assignment to unused variable N_count
(F841)
sba_code/stickbreaking_attention/sb_attn/sb_bwd.py
69-69: Unused function argument: batch_size
(ARG001)
72-72: Unused function argument: num_heads
(ARG001)
76-76: Unused function argument: NO_N_MASK
(ARG001)
fla/layers/stickbreaking_attn.py
39-41: Avoid specifying long messages outside the exception class
(TRY003)
73-73: Unused method argument: output_attentions
(ARG002)
85-85: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: check-pt-python-compatibility
🔇 Additional comments (3)
sba_code/.gitignore (1)
1-156: LGTM! Comprehensive and standard.gitignorestructure.This
.gitignorefile is well-formed and follows Python project best practices. It covers Python artifacts, packaging/distribution outputs, virtual environments, IDE configurations, test/coverage reports, type checker caches, and various build system artifacts. The patterns are syntactically correct and appropriately scoped.As a defensive practice, having a local
.gitignorein thesba_code/sub-package is good, even though it may partially overlap with a root-level.gitignorein the main repository.sba_code/stickbreaking_attention/__init__.py (1)
1-2: LGTM!The imports are re-exported to form the public API of the package. The static analysis warnings about unused imports are false positives in this context.
sba_code/stickbreaking_attention/sb_ref.py (1)
7-25: LGTM!The reference implementation is clear and follows the stick-breaking attention algorithm correctly. This provides a good baseline for testing the optimized implementations.
sba_code/benchmarks/attn.py
Outdated
position_ids = torch.arange(q.size(1), device=device, dtype=torch.int32)[None, :]
if provider == "triton":
    fun = lambda: tri_fwdbwd(do, q, k, v)
elif provider == "flash":
    rope = LlamaRotaryEmbedding(dim=head_dim).to(device)
    fun = lambda: flash_fwdbwd(rope, position_ids, do, q, k, v)
elif provider == "triton_flash":
Fix rotary position IDs dtype before runtime crash.
LlamaRotaryEmbedding indexes its cached cos/sin tables with position_ids. Passing a CUDA tensor in torch.int32 raises Index tensor must have dtype long, so this benchmark will crash as soon as you hit the flash branch. Please build position_ids with dtype=torch.long (and keep it on CUDA). A minimal change is to drop the explicit dtype or replace it with dtype=torch.long.
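A minimal sketch of the change, reusing the benchmark's existing `q` and `device`:

```python
# int64 positions so LlamaRotaryEmbedding can index its cached cos/sin tables
position_ids = torch.arange(q.size(1), device=device, dtype=torch.long)[None, :]
```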
🧰 Tools
🪛 Flake8 (7.3.0)
[error] 76-76: do not assign a lambda expression, use a def
(E731)
[error] 79-79: do not assign a lambda expression, use a def
(E731)
🪛 GitHub Actions: lint
[error] 75-75: E731 Do not assign a lambda expression, use a def. (pre-commit: ruff-check)
[error] 78-78: E731 Do not assign a lambda expression, use a def. (pre-commit: ruff-check)
🪛 Ruff (0.14.3)
76-76: Do not assign a lambda expression, use a def
Rewrite fun as a def
(E731)
79-79: Do not assign a lambda expression, use a def
Rewrite fun as a def
(E731)
🤖 Prompt for AI Agents
In sba_code/benchmarks/attn.py around lines 74 to 80, position_ids is created
with dtype=torch.int32 which causes LlamaRotaryEmbedding to throw "Index tensor
must have dtype long" when used on CUDA; change the creation of position_ids to
use dtype=torch.long (or omit dtype so it defaults to long) while keeping it on
the same device so the flash branch can index the cached cos/sin tables without
crashing.
sba_code/benchmarks/attn.py
Outdated
| if provider == "triton": | ||
| fun = lambda: tri_fwdbwd(do, q, k, v) | ||
| elif provider == "flash": | ||
| rope = LlamaRotaryEmbedding(dim=head_dim).to(device) | ||
| fun = lambda: flash_fwdbwd(rope, position_ids, do, q, k, v) | ||
| elif provider == "triton_flash": | ||
| rope = LlamaRotaryEmbedding(dim=head_dim).to(device) | ||
| fun = lambda: triton_flash_fwdbwd(rope, position_ids, do, q, k, v) | ||
|
|
Refactor the provider dispatch into defs to unblock lint.
CI is currently red on this file (E731) because fun is assigned three different lambda expressions. Please rewrite these branches as small local functions (def run_triton(): ...; fun = run_triton, etc.) so the linter passes again.
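A sketch of the def-based rewrite; it reuses the benchmark's existing names (`tri_fwdbwd`, `flash_fwdbwd`, `triton_flash_fwdbwd`, `head_dim`, `position_ids`, `do`, `q`, `k`, `v`) and is only meant to show the shape of the change:

```python
if provider == "triton":
    def run_triton():
        return tri_fwdbwd(do, q, k, v)
    fun = run_triton
elif provider == "flash":
    rope = LlamaRotaryEmbedding(dim=head_dim).to(device)
    def run_flash():
        return flash_fwdbwd(rope, position_ids, do, q, k, v)
    fun = run_flash
elif provider == "triton_flash":
    rope = LlamaRotaryEmbedding(dim=head_dim).to(device)
    def run_triton_flash():
        return triton_flash_fwdbwd(rope, position_ids, do, q, k, v)
    fun = run_triton_flash
```

`functools.partial` would work equally well if closing over the benchmark locals feels too implicit.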
🧰 Tools
🪛 Flake8 (7.3.0)
[error] 76-76: do not assign a lambda expression, use a def
(E731)
[error] 79-79: do not assign a lambda expression, use a def
(E731)
[error] 82-82: do not assign a lambda expression, use a def
(E731)
🪛 GitHub Actions: lint
[error] 75-75: E731 Do not assign a lambda expression, use a def. (pre-commit: ruff-check)
[error] 78-78: E731 Do not assign a lambda expression, use a def. (pre-commit: ruff-check)
[error] 81-81: E731 Do not assign a lambda expression, use a def. (pre-commit: ruff-check)
🪛 Ruff (0.14.3)
76-76: Do not assign a lambda expression, use a def
Rewrite fun as a def
(E731)
79-79: Do not assign a lambda expression, use a def
Rewrite fun as a def
(E731)
82-82: Do not assign a lambda expression, use a def
Rewrite fun as a def
(E731)
🤖 Prompt for AI Agents
In sba_code/benchmarks/attn.py around lines 75 to 83, the three branches
currently assign different lambda expressions to fun (triggering E731); replace
each lambda with a small local def: define run_triton() that calls
tri_fwdbwd(do, q, k, v), define run_flash() that constructs or closes over rope
= LlamaRotaryEmbedding(dim=head_dim).to(device) (or create rope before the def)
and calls flash_fwdbwd(rope, position_ids, do, q, k, v), and define
run_triton_flash() similarly for triton_flash_fwdbwd; then set fun = run_triton
/ run_flash / run_triton_flash in each branch so the linter no longer sees
lambda assignments.
sba_code/benchmarks/varlen.py
Outdated
do = torch.randn((num_heads, total_length, head_dim), device=device, dtype=dtype)
position_ids = torch.arange(q.size(1), device=device, dtype=torch.int32)[None, :]

if provider== "reference":
    fun = lambda: ref_fwdbwd(do, q, k, v, lengths)
elif provider == "triton":
    fun = lambda: tri_fwdbwd(do, q, k, v, lengths)
elif provider == "flash":
    config = LlamaConfig(max_position_embeddings=length)
    rope = LlamaRotaryEmbedding(config).to(device)
    fun = lambda: flash_fwdbwd(rope, position_ids, do, q, k, v, lengths)
if bwd:
Use correct rotary positions (total tokens, dtype torch.long).
Here we build position_ids with q.size(1) and dtype=torch.int32. In this benchmark q.shape == (total_length, num_heads, head_dim), so q.size(1) is the head count, not the token index. As a result every token shares the same tiny set of angles, and the rotary embedding is completely wrong. Worse, LlamaRotaryEmbedding indexes its cache with position_ids, and a CUDA torch.int32 tensor will raise Index tensor must have dtype long. Please construct per-token positions (e.g. torch.arange(total_length, device=device, dtype=torch.long)[None, :], or restart per sequence via lengths) and keep them in torch.long so rotary is well-defined.
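A sketch of per-sequence positions, assuming `lengths` is the same per-sequence token-count container already used in this benchmark:

```python
# rotary positions restart at 0 for each packed sequence and stay int64
# so LlamaRotaryEmbedding can index its cos/sin cache
position_ids = torch.cat(
    [torch.arange(int(n), device=device, dtype=torch.long) for n in lengths]
)[None, :]
```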
🧰 Tools
🪛 Flake8 (7.3.0)
[error] 111-111: do not assign a lambda expression, use a def
(E731)
[error] 113-113: do not assign a lambda expression, use a def
(E731)
[error] 117-117: do not assign a lambda expression, use a def
(E731)
🪛 GitHub Actions: lint
[error] 110-110: E731 Do not assign a lambda expression, use a def. (pre-commit: ruff-check)
[error] 112-112: E731 Do not assign a lambda expression, use a def. (pre-commit: ruff-check)
[error] 116-116: E731 Do not assign a lambda expression, use a def. (pre-commit: ruff-check)
🪛 Ruff (0.14.3)
111-111: Do not assign a lambda expression, use a def
Rewrite fun as a def
(E731)
113-113: Do not assign a lambda expression, use a def
Rewrite fun as a def
(E731)
117-117: Do not assign a lambda expression, use a def
Rewrite fun as a def
(E731)
🤖 Prompt for AI Agents
In sba_code/benchmarks/varlen.py around lines 107 to 118, the position_ids are
built from q.size(1) with dtype torch.int32 which is wrong because q.shape is
(total_length, num_heads, head_dim) so q.size(1) is num_heads not token
positions, and int32 will break LlamaRotaryEmbedding indexing; replace this with
per-token positions of length total_length and dtype torch.long (e.g.
torch.arange(total_length, device=device, dtype=torch.long)[None, :]) or, if
sequences restart, construct per-sequence positions using lengths, ensuring the
resulting tensor indexes the rotary cache correctly and matches the token
dimension of q.
sba_code/benchmarks/varlen.py
Outdated
    fun = lambda: ref_fwdbwd(do, q, k, v, lengths)
elif provider == "triton":
    fun = lambda: tri_fwdbwd(do, q, k, v, lengths)
elif provider == "flash":
    config = LlamaConfig(max_position_embeddings=length)
    rope = LlamaRotaryEmbedding(config).to(device)
    fun = lambda: flash_fwdbwd(rope, position_ids, do, q, k, v, lengths)
if bwd:
Replace the provider lambdas with named callables.
Same lint failure as in attn.py: assigning lambdas to fun violates E731 and is currently breaking CI. Please convert each branch into a small def (or partial) and assign that function instead.
🧰 Tools
🪛 Flake8 (7.3.0)
[error] 111-111: do not assign a lambda expression, use a def
(E731)
[error] 113-113: do not assign a lambda expression, use a def
(E731)
[error] 117-117: do not assign a lambda expression, use a def
(E731)
🪛 GitHub Actions: lint
[error] 112-112: E731 Do not assign a lambda expression, use a def. (pre-commit: ruff-check)
[error] 116-116: E731 Do not assign a lambda expression, use a def. (pre-commit: ruff-check)
🪛 Ruff (0.14.3)
111-111: Do not assign a lambda expression, use a def
Rewrite fun as a def
(E731)
113-113: Do not assign a lambda expression, use a def
Rewrite fun as a def
(E731)
117-117: Do not assign a lambda expression, use a def
Rewrite fun as a def
(E731)
🤖 Prompt for AI Agents
In sba_code/benchmarks/varlen.py around lines 111 to 118, the code assigns
anonymous lambdas to the variable `fun` for each provider branch which triggers
lint error E731; replace each lambda with a named callable (either a small def
function or functools.partial) and assign that named function to `fun`. For
example, create distinct functions like `_ref_fun()`, `_tri_fun()`, and
`_flash_fun()` that capture the same closed-over variables (or use partial to
bind args) and set `fun = _ref_fun` / `fun = _tri_fun` / `fun = _flash_fun`
accordingly so no lambda is assigned. Ensure the flash branch still constructs
`config` and `rope` before defining the callable and leave the subsequent `if
bwd:` logic unchanged.
sba_code/setup.py
Outdated
def read(fname):
    return open(os.path.join(os.path.dirname(__file__), fname)).read()
Close file handle properly.
The file is opened but never closed, causing a resource leak.
Apply this diff:
 def read(fname):
-    return open(os.path.join(os.path.dirname(__file__), fname)).read()
+    with open(os.path.join(os.path.dirname(__file__), fname)) as f:
+        return f.read()

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
def read(fname):
    with open(os.path.join(os.path.dirname(__file__), fname)) as f:
        return f.read()
🤖 Prompt for AI Agents
In sba_code/setup.py around lines 4 to 5, the read() function opens a file but
never closes it; change it to use a context manager (with open(... ) as f:) and
return f.read() so the file handle is closed automatically; also specify a mode
('r') and an explicit encoding (e.g., 'utf-8') for portability.
sba_code/setup.py
Outdated
| description = "Triton implementation of Stick-breaking attention", | ||
| license = "Apache License", | ||
| keywords = "triton pytorch llm stickbreaking attention", | ||
| url = "https://github.com/shawntan/scattermoe", |
Update the project URL to the correct repository.
The URL points to shawntan/scattermoe instead of the actual stick-breaking attention repository.
Please update line 15 with the correct project URL.
🤖 Prompt for AI Agents
In sba_code/setup.py around line 15, the project URL currently points to
"https://github.com/shawntan/scattermoe"; update that value to the correct
repository URL for the stick-breaking attention project (e.g. replace with
"https://github.com/shawntan/stick-breaking-attention") so the package metadata
references the proper upstream repo.
dk = torch.zeros_like(k, dtype=torch.bfloat16)
dv = torch.zeros_like(v, dtype=torch.bfloat16)
Stop coercing K/V gradients to bfloat16
Gradients must come back in the same dtype as the forward tensors. For fp16 or fp32 Q/K/V this cast to bfloat16 causes autograd to error out (type mismatch) or silently downcast precision. Reuse the original dtype instead of forcing bfloat16.
- dk = torch.zeros_like(k, dtype=torch.bfloat16)
- dv = torch.zeros_like(v, dtype=torch.bfloat16)
+ dk = torch.zeros_like(k)
+ dv = torch.zeros_like(v)

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
dk = torch.zeros_like(k)
dv = torch.zeros_like(v)
🤖 Prompt for AI Agents
In sba_code/stickbreaking_attention/sb_attn/sb_bwd.py around lines 173 to 175,
the backward code currently creates dk/dv with dtype=torch.bfloat16 which forces
gradients to bfloat16 and breaks/autodownsamples fp16/fp32 tensors; change the
allocation to reuse the original forward dtypes (e.g., torch.zeros_like(k) /
torch.zeros_like(v) or torch.zeros_like(k, dtype=k.dtype) and
torch.zeros_like(v, dtype=v.dtype)) so gradients match the forward tensors'
dtype and device.
O_head_seq_ptr = O_ptr + stride_oh * head_id + stride_om * seq_start_offset
R_head_seq_ptr = R_ptr + stride_rh * head_id + stride_rm * seq_start_offset
A_head_seq_ptr = A_ptr + stride_ah * head_id + stride_am * seq_start_offset
W_head_seq_ptr = W_ptr + stride_wh * head_id + stride_am * seq_start_offset
_forward_one_row(
Fix incorrect stride when computing W_head_seq_ptr.
Inside both head-block branches, W_head_seq_ptr is advanced with stride_am, i.e. the stride for the accumulator tensor A_ptr. As soon as return_attention=True, this writes attention weights using the wrong stride (and can walk into the A_ptr buffer), so the returned attention matrix is corrupted and may stomp unrelated memory. Please use the dedicated stride_wm for the W tensor when applying the sequence offset.
- W_head_seq_ptr = W_ptr + stride_wh * head_id + stride_am * seq_start_offset
+ W_head_seq_ptr = W_ptr + stride_wh * head_id + stride_wm * seq_start_offset
...
- W_head_seq_ptr = W_ptr + stride_wh * head_id + stride_am * seq_start_offset
+ W_head_seq_ptr = W_ptr + stride_wh * head_id + stride_wm * seq_start_offset

Also applies to: 367-371
🤖 Prompt for AI Agents
In sba_code/stickbreaking_attention/sb_varlen/sb_varlen_fwd.py around lines
315-319 (also apply same fix at 367-371): W_head_seq_ptr is computed using
stride_am (accumulator stride) instead of the W tensor's sequence stride
stride_wm, which causes attention writes to use the wrong stride and can corrupt
A_ptr or other memory; fix by replacing the use of stride_am with stride_wm when
adding seq_start_offset to W_ptr so W_head_seq_ptr uses stride_wh * head_id +
stride_wm * seq_start_offset (mirror the pattern used for O/R/A but with the
correct W sequence stride).
Actionable comments posted: 1
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- fla/ops/stickbreaking_attn/parallel.py (1 hunks)
- fla/ops/stickbreaking_attn/softplus.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
fla/ops/stickbreaking_attn/parallel.py (2)
- fla/ops/stickbreaking_attn/softplus.py (1): softplus (38-49)
- fla/ops/utils/index.py (1): prepare_chunk_indices (114-119)
🪛 Ruff (0.14.3)
fla/ops/stickbreaking_attn/parallel.py
94-94: Unused function argument: NO_N_MASK
(ARG001)
98-98: Unused function argument: no_grad
(ARG001)
100-100: Unused function argument: return_attention
(ARG001)
344-344: Unused function argument: B
(ARG001)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: Test H100 (PyTorch 2.7) / test-models
- GitHub Check: check-pt-python-compatibility
🔇 Additional comments (1)
fla/ops/stickbreaking_attn/softplus.py (1)
8-49: Inline softplus port looks great — matches the referenced kernel’s log₂-domain math and keeps the JIT-time constants nicely cached, so I don’t see any issues here.
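For intuition, a scalar Python rendering of the same log₂-domain identity; the 15.0 cutoff mirrors the simpler `tl.where(x < 15.0, tl.math.log2(1 + tl.math.exp2(x)), x)` form referenced earlier in this thread and is an approximation threshold, not an exact constant from the ported file:

```python
import math

def softplus_log2(x: float) -> float:
    # log2-domain softplus: log2(1 + 2**x); for large x the +1 falls below
    # floating-point resolution, so the function passes x through unchanged
    return math.log2(1.0 + 2.0 ** x) if x < 15.0 else x
```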
dk = torch.zeros((M_count, B, T, H, D), dtype=k.dtype, device=k.device)
dv = torch.zeros((M_count, B, T, H, D), dtype=v.dtype, device=v.device)

BLOCK_D = triton.next_power_of_2(D)

NO_M_MASK = (T % BT) == 0
NO_N_MASK = (T % BS) == 0
if cu_seqlens is not None:
    NO_M_MASK = False
    NO_N_MASK = False

parallel_stickbreaking_attn_bwd_kernel[grid](
    do,
    dr,
    neg_log_acc,
    q,
    k,
    v,
    dq,
    dk,
    dv,
    CU_ptr=cu_seqlens if cu_seqlens is not None else q,
    CI_ptr=CI if CI is not None else q,
    scale=scale,
    B=B,
    T=T,
    head_size=D,
    H=H,
    BT=BT,
    BS=BS,
    BLOCK_D=BLOCK_D,
    NO_D_MASK=D == BLOCK_D,
    NO_M_MASK=NO_M_MASK,
    NO_N_MASK=NO_N_MASK,
    ALLOW_TF32=ALLOW_TF32,
    acc_dtype=tl.float32,
    IS_VARLEN=cu_seqlens is not None,
)

dk_final = dk.sum(0)
dv_final = dv.sum(0)

return dq.to(q.dtype), dk_final, dv_final
Rewrite DK/DV accumulation to avoid the 64× memory blow-up.
Allocating dk/dv as (M_count, B, T, H, D) temporarily multiplies the KV footprint by ~T / BT. For a single advertised run (batch = 8, seq = 4096, heads = 32, dim = 64, BT = 64) this comes out to ~8.6 GB per tensor in fp16, so you'll OOM before the backward kernel even launches. Please accumulate directly into the [B, T, H, D] buffers (e.g. atomic adds, or a per-block staging slab that is immediately folded into the final grad) so memory scales linearly with the model size; a host-side sketch of the accumulation pattern follows.
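A minimal host-side sketch of the accumulation pattern being requested; the real fix would live inside the Triton kernel (for example `tl.atomic_add` into a `[B, T, H, D]` buffer), so the toy shapes, the fp32 accumulator, and the helper name here are illustrative assumptions:

```python
import torch

def accumulate_block_grads(dk: torch.Tensor, blk_dk: torch.Tensor, n_start: int) -> None:
    # dk:     [B, T, H, D] final gradient buffer, linear in model size
    # blk_dk: [B, BS, H, D] partial gradient produced by one key/value block
    # host-side analogue of the per-block atomic add the kernel would perform
    dk[:, n_start:n_start + blk_dk.shape[1]] += blk_dk

B, T, H, D, BS = 2, 128, 4, 64, 32
k = torch.randn(B, T, H, D, dtype=torch.float16)
dk = torch.zeros(B, T, H, D, dtype=torch.float32)  # fp32 accumulator for precision
for n_start in range(0, T, BS):
    blk_dk = torch.randn(B, BS, H, D, dtype=torch.float32)  # stand-in for one block's partial grad
    accumulate_block_grads(dk, blk_dk, n_start)
dk_final = dk.to(k.dtype)  # cast back to the forward dtype once at the end
```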
python -m benchmarks.benchmark_training_throughput --name stickbreaking_attn --batch_size 8 --seq_len 4096

Summary by CodeRabbit
New Features
Tests