diff --git a/.ci/docker/ci_commit_pins/optimum-executorch.txt b/.ci/docker/ci_commit_pins/optimum-executorch.txt index df87f35a69d..156ff2f3c82 100644 --- a/.ci/docker/ci_commit_pins/optimum-executorch.txt +++ b/.ci/docker/ci_commit_pins/optimum-executorch.txt @@ -1 +1 @@ -d03e90c2cd9048e6d9a75285c0355f033cd016fc +0123293118efb08ac4ffc4fefe9d330201465c93 diff --git a/backends/cuda/cuda_backend.py b/backends/cuda/cuda_backend.py index f0d3a000ec0..9204ecaecda 100644 --- a/backends/cuda/cuda_backend.py +++ b/backends/cuda/cuda_backend.py @@ -68,7 +68,8 @@ def get_custom_passes(cls, compile_specs: List[CompileSpec]) -> List[typing.Any] ) triton_kernel_mode = mode - return [ReplaceEdgeOpWithTritonOpPass()] if triton_kernel_mode == "ON" else [] + # return [ReplaceEdgeOpWithTritonOpPass()] if triton_kernel_mode == "ON" else [] + return [ReplaceEdgeOpWithTritonOpPass()] @classmethod def get_aoti_compile_options( @@ -134,20 +135,20 @@ def get_aoti_compile_options( return options - @classmethod - def get_extra_aoti_compile_context_manager(cls): - """ - Return SDPA MATH backend context manager for CUDA compilation. - - This context manager plays as a fallback solution for any remaining PyTorch SDPA - operations to use the MATH backend (decomposed SDPA) during AOTInductor compilation. - - Note: - - If SDPA ops are replaced with Triton kernels by ReplaceEdgeOpWithTritonOpPass, - this context manager will have no effect on those ops (they are no longer - PyTorch SDPA ops). - - If SDPA ops are NOT replaced (e.g., when triton_kernel_mode="OFF"), this - context manager will force them to use the MATH backend, causing them to - be automatically decomposed during compilation. - """ - return torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) + # @classmethod + # def get_extra_aoti_compile_context_manager(cls): + # """ + # Return SDPA MATH backend context manager for CUDA compilation. + + # This context manager plays as a fallback solution for any remaining PyTorch SDPA + # operations to use the MATH backend (decomposed SDPA) during AOTInductor compilation. + + # Note: + # - If SDPA ops are replaced with Triton kernels by ReplaceEdgeOpWithTritonOpPass, + # this context manager will have no effect on those ops (they are no longer + # PyTorch SDPA ops). + # - If SDPA ops are NOT replaced (e.g., when triton_kernel_mode="OFF"), this + # context manager will force them to use the MATH backend, causing them to + # be automatically decomposed during compilation. 
+ # """ + # return torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) diff --git a/backends/cuda/triton/kernels/sdpa.py b/backends/cuda/triton/kernels/sdpa.py index 7e8eb1444df..e3c444f093e 100644 --- a/backends/cuda/triton/kernels/sdpa.py +++ b/backends/cuda/triton/kernels/sdpa.py @@ -22,17 +22,22 @@ from torch.library import triton_op, wrap_triton -def _next_power_of_2(n: int) -> int: - """Round up to the next power of 2.""" - if n <= 0: - return 1 - if n & (n - 1) == 0: - return n +def _is_power_of_2(n: int) -> bool: + """Check if n is a power of 2.""" + return n > 0 and (n & (n - 1)) == 0 - power = 1 - while power < n: - power <<= 1 - return power + +def _next_power_of_2(x: int) -> int: + """Get the next power of 2 >= x, clamped to [16, 256].""" + if x <= 16: + return 16 + if x <= 32: + return 32 + if x <= 64: + return 64 + if x <= 128: + return 128 + return 256 def _validate_qkv_shapes( @@ -77,31 +82,21 @@ def _validate_qkv_shapes( return B_q, H_q, L_q, L_kv_k, D_q, D_k -@triton.autotune( - configs=[ - triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_stages=4, num_warps=8), - triton.Config({"BLOCK_M": 128, "BLOCK_N": 256}, num_stages=4, num_warps=8), - triton.Config({"BLOCK_M": 64, "BLOCK_N": 256}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_M": 128, "BLOCK_N": 64}, num_stages=3, num_warps=4), - triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_stages=3, num_warps=4), - triton.Config({"BLOCK_M": 64, "BLOCK_N": 32}, num_stages=1, num_warps=2), - triton.Config({"BLOCK_M": 32, "BLOCK_N": 64}, num_stages=1, num_warps=2), - ], - key=["L_Q", "L_KV", "HEAD_DIM"], -) +# ============================================================================== +# Non-power-of-2 HEAD_DIM kernel +# ============================================================================== @triton.jit -def _sdpa_fwd_kernel( +def _sdpa_fwd_kernel_non_pow2( q_ptr, k_ptr, v_ptr, - mask_ptr, o_ptr, + mask_ptr, B, H, - L_Q, # Query sequence length - L_KV, # Key/Value sequence length - HEAD_DIM, # Actual head dimension (may not be power of 2) + LQ, + LK, + HEAD_DIM, stride_qb, stride_qh, stride_ql, @@ -114,118 +109,401 @@ def _sdpa_fwd_kernel( stride_vh, stride_vl, stride_vd, - stride_mb, - stride_mh, - stride_ml, - stride_mn, stride_ob, stride_oh, stride_ol, stride_od, - sm_scale, - IS_CAUSAL: tl.constexpr, - HAS_MASK: tl.constexpr, + stride_mb, + stride_mh, + stride_mlq, + stride_mlk, + scale, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, - HEAD_DIM_CE: tl.constexpr, # Rounded up for tl.arange + BLOCK_D: tl.constexpr, + HAS_MASK: tl.constexpr, + IS_CAUSAL: tl.constexpr, ): """ - Fused SDPA kernel that handles different sequence lengths for Q and K/V. - - Q shape: [B, H, L_Q, D] - K/V shape: [B, H, L_KV, D] - Output shape: [B, H, L_Q, D] + SDPA forward kernel for non-power-of-2 HEAD_DIM. + Uses dynamic masking to handle arbitrary head dimensions. 
""" - # Program IDs - pid_m = tl.program_id(axis=0) # along query length - pid_hz = tl.program_id(axis=1) # flattened batch*head - off_b = pid_hz // H - off_h = pid_hz % H - # Compute ranges for queries - start_m = pid_m * BLOCK_M - offs_m = start_m + tl.arange(0, BLOCK_M) - offs_d = tl.arange(0, HEAD_DIM_CE) - mask_m = offs_m < L_Q # Mask based on query length - # Base pointers for this (b, h) - q_base = q_ptr + off_b * stride_qb + off_h * stride_qh - k_base = k_ptr + off_b * stride_kb + off_h * stride_kh - v_base = v_ptr + off_b * stride_vb + off_h * stride_vh - o_base = o_ptr + off_b * stride_ob + off_h * stride_oh - # Mask base pointer (if provided) - if HAS_MASK: - mask_base = mask_ptr + off_b * stride_mb + off_h * stride_mh - # Mask for actual head dimension (HEAD_DIM may not be power of 2) - mask_d = offs_d < HEAD_DIM - # Make head-dim addresses compiler-friendly - offs_d_ctg = tl.max_contiguous(tl.multiple_of(offs_d, 16), HEAD_DIM_CE) - # Load Q tile [BLOCK_M, HEAD_DIM] - coalesced along HEAD_DIM - q_ptrs = q_base + (offs_m[:, None] * stride_ql + offs_d_ctg[None, :] * stride_qd) - q = tl.load(q_ptrs, mask=mask_m[:, None] & mask_d[None, :], other=0.0) - q = q.to(tl.bfloat16) - # Initialize accumulators and softmax stats - acc = tl.zeros((BLOCK_M, HEAD_DIM_CE), dtype=tl.float32) + pid_m = tl.program_id(axis=0) + pid_bh = tl.program_id(axis=1) + + b = pid_bh // H + h = pid_bh % H + + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_D) + + d_mask = offs_d < HEAD_DIM + q_row_mask = offs_m < LQ + + q_base = q_ptr + b * stride_qb + h * stride_qh + k_base = k_ptr + b * stride_kb + h * stride_kh + v_base = v_ptr + b * stride_vb + h * stride_vh + o_base = o_ptr + b * stride_ob + h * stride_oh + + q_ptrs = q_base + (offs_m[:, None] * stride_ql + offs_d[None, :] * stride_qd) + q = tl.load(q_ptrs, mask=q_row_mask[:, None] & d_mask[None, :], other=0.0) + + acc = tl.zeros((BLOCK_M, BLOCK_D), dtype=tl.float32) m_i = tl.full((BLOCK_M,), -float("inf"), dtype=tl.float32) - l_i = tl.zeros((BLOCK_M,), dtype=tl.float32) - # Convert to base-2 scale for exp2 - qk_scale = sm_scale * 1.4426950408889634 - # Loop over keys/values along L_KV dimension (not L_Q!) 
- for start_n in tl.range(0, L_KV, BLOCK_N): - offs_n = start_n + tl.arange(0, BLOCK_N) - mask_n = offs_n < L_KV # Mask based on key/value length - # Load K tile [BLOCK_N, HEAD_DIM] (contiguous along HEAD_DIM) - k_ptrs = k_base + ( - offs_n[:, None] * stride_kl + offs_d_ctg[None, :] * stride_kd - ) - k = tl.load(k_ptrs, mask=mask_n[:, None] & mask_d[None, :], other=0.0) - k = k.to(tl.bfloat16) - # Compute attention logits [BLOCK_M, BLOCK_N] = Q[BM,D] @ K[BN,D]^T - qk = tl.dot(q, tl.trans(k)).to(tl.float32) - qk = qk * qk_scale - # Apply causal mask if needed - # For causal masking with different lengths: position i can attend to position j if i >= j + l_i = tl.full((BLOCK_M,), 1.0, dtype=tl.float32) + + qk_scale_log2 = scale * 1.4426950408889634 + + if HAS_MASK: + mask_b_base = mask_ptr + b * stride_mb + + for start_n in tl.range(0, LK, BLOCK_N, num_stages=2): + kn = start_n + offs_n + kv_col_mask = kn < LK + + k_ptrs = k_base + (kn[:, None] * stride_kl + offs_d[None, :] * stride_kd) + k = tl.load(k_ptrs, mask=kv_col_mask[:, None] & d_mask[None, :], other=0.0) + + qk = tl.dot(q, tl.trans(k)) + qk = qk * qk_scale_log2 + if IS_CAUSAL: - causal_mask = offs_m[:, None] >= offs_n[None, :] - qk = tl.where(causal_mask, qk, -float("inf")) - # Apply attention mask if provided + row_abs = offs_m[:, None] + col_abs = kn[None, :] + causal_mask = col_abs > row_abs + qk = tl.where(causal_mask, -float("inf"), qk) + if HAS_MASK: - # Load mask tile [BLOCK_M, BLOCK_N] - # Mask shape should be [B, H, L_Q, L_KV] - mask_ptrs = mask_base + ( - offs_m[:, None] * stride_ml + offs_n[None, :] * stride_mn + mask_ptrs = ( + mask_b_base + offs_m[:, None] * stride_mlq + kn[None, :] * stride_mlk ) - attn_mask = tl.load( - mask_ptrs, - mask=mask_m[:, None] & mask_n[None, :], - other=0.0, - ) - # Convert boolean mask to additive mask (-inf for False, 0 for True) - qk = tl.where(attn_mask, qk, -float("inf")) - # Apply OOB masks for both rows and cols - qk = tl.where(mask_n[None, :], qk, -float("inf")) - qk = tl.where(mask_m[:, None], qk, -float("inf")) - # Online softmax + tile_valid = q_row_mask[:, None] & kv_col_mask[None, :] + keep = tl.load(mask_ptrs, mask=tile_valid, other=True) + qk = tl.where(keep, qk, -float("inf")) + + qk = tl.where(kv_col_mask[None, :], qk, -float("inf")) + m_ij = tl.maximum(m_i, tl.max(qk, 1)) p = tl.math.exp2(qk - m_ij[:, None]) l_ij = tl.sum(p, 1) alpha = tl.math.exp2(m_i - m_ij) - # Load V tile [BLOCK_N, HEAD_DIM] (contiguous along HEAD_DIM) - v_ptrs = v_base + ( - offs_n[:, None] * stride_vl + offs_d_ctg[None, :] * stride_vd - ) - v = tl.load(v_ptrs, mask=mask_n[:, None] & mask_d[None, :], other=0.0) - v = v.to(tl.bfloat16) - # Update accumulator + acc = acc * alpha[:, None] - p_bf16 = p.to(tl.bfloat16) - acc = tl.dot(p_bf16, v, acc) - # Update softmax stats + + v_ptrs = v_base + (kn[:, None] * stride_vl + offs_d[None, :] * stride_vd) + v = tl.load(v_ptrs, mask=kv_col_mask[:, None] & d_mask[None, :], other=0.0) + + acc = tl.dot(p.to(v.dtype), v, acc) + + l_i = l_i * alpha + l_ij + m_i = m_ij + + out = acc / l_i[:, None] + o_ptrs = o_base + (offs_m[:, None] * stride_ol + offs_d[None, :] * stride_od) + tl.store(o_ptrs, out.to(tl.bfloat16), mask=q_row_mask[:, None] & d_mask[None, :]) + + +# ============================================================================== +# Power-of-2 HEAD_DIM kernels +# ============================================================================== +@triton.jit +def _sdpa_fwd_kernel_body( + Q_ptr, + K_ptr, + V_ptr, + O_ptr, + Mask_ptr, + B, + H, + Lq, + Lk, + 
stride_qb, + stride_qh, + stride_qm, + stride_qd, + stride_kb, + stride_kh, + stride_kn, + stride_kd, + stride_vb, + stride_vh, + stride_vn, + stride_vd, + stride_ob, + stride_oh, + stride_om, + stride_od, + stride_mb, + stride_mq, + stride_mk, + sm_scale: tl.float32, + HAS_MASK: tl.constexpr, + IS_CAUSAL: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + HEAD_DIM: tl.constexpr, +): + """ + Shared kernel body for SDPA forward pass. + """ + pid_m = tl.program_id(axis=0) + pid_bh = tl.program_id(axis=1) + b = pid_bh // H + h = pid_bh % H + + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n_init = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, HEAD_DIM) + + q_ptrs = Q_ptr + ( + b * stride_qb + + h * stride_qh + + (offs_m[:, None] * stride_qm) + + (offs_d[None, :] * stride_qd) + ) + q_mask = (offs_m[:, None] < Lq) & (offs_d[None, :] < HEAD_DIM) + q = tl.load(q_ptrs, mask=q_mask, other=0.0).to(tl.bfloat16) + + m_i = tl.full([BLOCK_M], -float("inf"), dtype=tl.float32) + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) + + for start_n in tl.range(0, Lk, BLOCK_N): + offs_n = start_n + offs_n_init + + k_ptrs = K_ptr + ( + b * stride_kb + + h * stride_kh + + (offs_n[:, None] * stride_kn) + + (offs_d[None, :] * stride_kd) + ) + k_mask = (offs_n[:, None] < Lk) & (offs_d[None, :] < HEAD_DIM) + k = tl.load(k_ptrs, mask=k_mask, other=0.0).to(tl.bfloat16) + + qk = tl.dot(q, tl.trans(k)).to(tl.float32) * sm_scale + + if HAS_MASK: + mask_ptrs = Mask_ptr + ( + b * stride_mb + + (offs_m[:, None] * stride_mq) + + (offs_n[None, :] * stride_mk) + ) + mn_mask = (offs_m[:, None] < Lq) & (offs_n[None, :] < Lk) + mask_block = tl.load(mask_ptrs, mask=mn_mask, other=False) + qk = tl.where(mask_block, qk, -float("inf")) + + if IS_CAUSAL: + abs_m = offs_m[:, None] + abs_n = offs_n[None, :] + causal = abs_n > abs_m + qk = tl.where(causal, -float("inf"), qk) + + m_ij = tl.maximum(m_i, tl.max(qk, axis=1)) + p_f32 = tl.exp(qk - m_ij[:, None]) + l_ij = tl.sum(p_f32, axis=1) + alpha = tl.exp(m_i - m_ij) + + v_ptrs = V_ptr + ( + b * stride_vb + + h * stride_vh + + (offs_n[:, None] * stride_vn) + + (offs_d[None, :] * stride_vd) + ) + v_mask = (offs_n[:, None] < Lk) & (offs_d[None, :] < HEAD_DIM) + v = tl.load(v_ptrs, mask=v_mask, other=0.0).to(tl.bfloat16) + + p_bf16 = p_f32.to(tl.bfloat16) + acc = acc * alpha[:, None] + tl.dot(p_bf16, v) l_i = l_i * alpha + l_ij m_i = m_ij - # Normalize accumulator by softmax denominator - acc = acc / l_i[:, None] - # Store output [BLOCK_M, HEAD_DIM] - shape matches query - o_ptrs = o_base + (offs_m[:, None] * stride_ol + offs_d_ctg[None, :] * stride_od) - tl.store(o_ptrs, acc.to(tl.bfloat16), mask=mask_m[:, None] & mask_d[None, :]) + + inv_l_i = tl.where(l_i > 0, 1.0 / l_i, 0.0) + acc = acc * inv_l_i[:, None] + + o_ptrs = O_ptr + ( + b * stride_ob + + h * stride_oh + + (offs_m[:, None] * stride_om) + + (offs_d[None, :] * stride_od) + ) + o_mask = (offs_m[:, None] < Lq) & (offs_d[None, :] < HEAD_DIM) + tl.store(o_ptrs, acc.to(tl.bfloat16), mask=o_mask) + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 256}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 32}, num_warps=4, num_stages=2), + ], + key=["Lq", "Lk", "HEAD_DIM", "HAS_MASK", 
"IS_CAUSAL"], +) +@triton.jit +def _sdpa_fwd_kernel_m64( + Q_ptr, + K_ptr, + V_ptr, + O_ptr, + Mask_ptr, + B, + H, + Lq, + Lk, + stride_qb, + stride_qh, + stride_qm, + stride_qd, + stride_kb, + stride_kh, + stride_kn, + stride_kd, + stride_vb, + stride_vh, + stride_vn, + stride_vd, + stride_ob, + stride_oh, + stride_om, + stride_od, + stride_mb, + stride_mq, + stride_mk, + sm_scale: tl.float32, + HAS_MASK: tl.constexpr, + IS_CAUSAL: tl.constexpr, + HEAD_DIM: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + """ + SDPA kernel with BLOCK_M=64 optimizations. + """ + _sdpa_fwd_kernel_body( + Q_ptr, + K_ptr, + V_ptr, + O_ptr, + Mask_ptr, + B, + H, + Lq, + Lk, + stride_qb, + stride_qh, + stride_qm, + stride_qd, + stride_kb, + stride_kh, + stride_kn, + stride_kd, + stride_vb, + stride_vh, + stride_vn, + stride_vd, + stride_ob, + stride_oh, + stride_om, + stride_od, + stride_mb, + stride_mq, + stride_mk, + sm_scale, + HAS_MASK=HAS_MASK, + IS_CAUSAL=IS_CAUSAL, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + HEAD_DIM=HEAD_DIM, + ) + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_M": 32, "BLOCK_N": 64}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_M": 32, "BLOCK_N": 128}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_M": 32, "BLOCK_N": 256}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_M": 32, "BLOCK_N": 32}, num_warps=4, num_stages=2), + ], + key=["Lq", "Lk", "HEAD_DIM", "HAS_MASK", "IS_CAUSAL"], +) +@triton.jit +def _sdpa_fwd_kernel_m32( + Q_ptr, + K_ptr, + V_ptr, + O_ptr, + Mask_ptr, + B, + H, + Lq, + Lk, + stride_qb, + stride_qh, + stride_qm, + stride_qd, + stride_kb, + stride_kh, + stride_kn, + stride_kd, + stride_vb, + stride_vh, + stride_vn, + stride_vd, + stride_ob, + stride_oh, + stride_om, + stride_od, + stride_mb, + stride_mq, + stride_mk, + sm_scale: tl.float32, + HAS_MASK: tl.constexpr, + IS_CAUSAL: tl.constexpr, + HEAD_DIM: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + """ + SDPA kernel with BLOCK_M=32 optimizations for small workloads. + """ + _sdpa_fwd_kernel_body( + Q_ptr, + K_ptr, + V_ptr, + O_ptr, + Mask_ptr, + B, + H, + Lq, + Lk, + stride_qb, + stride_qh, + stride_qm, + stride_qd, + stride_kb, + stride_kh, + stride_kn, + stride_kd, + stride_vb, + stride_vh, + stride_vn, + stride_vd, + stride_ob, + stride_oh, + stride_om, + stride_od, + stride_mb, + stride_mq, + stride_mk, + sm_scale, + HAS_MASK=HAS_MASK, + IS_CAUSAL=IS_CAUSAL, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + HEAD_DIM=HEAD_DIM, + ) @triton_op("triton::sdpa", mutates_args={}) @@ -240,14 +518,13 @@ def sdpa( enable_gqa: bool = False, ) -> torch.Tensor: """ - Triton fused Scaled Dot-Product Attention with support for different sequence lengths. + Triton fused Scaled Dot-Product Attention with optimized dual-kernel approach. 
Args: - query: Query tensor with szie [B, H, L_q, D] and dtype torch.bfloat16 + query: Query tensor with size [B, H, L_q, D] and dtype torch.bfloat16 key: Key tensor [B, H, L_kv, D] and dtype torch.bfloat16 value: Value tensor [B, H, L_kv, D] and dtype torch.bfloat16 - attn_mask: Optional attention mask [B, H, L_q, L_kv] or - broadcastable shape (2D: [L_q, L_kv] or 3D: [B, L_q, L_kv]) + attn_mask: Optional attention mask [B, H, L_q, L_kv] with dtype torch.bool dropout_p: must be 0.0 (others are not supported) is_causal: whether to apply causal masking scale: attention scale (default: 1/sqrt(D)) @@ -282,83 +559,199 @@ def sdpa( # Validate and get dimensions B, H, L_q, L_kv, D_q, D_kv = _validate_qkv_shapes(query, key, value) D = D_q # Head dimension + + # Enforce causal masking constraint + if is_causal: + if L_q != L_kv: + raise RuntimeError( + f"Causal masking requires L_q == L_kv; got L_q={L_q}, L_kv={L_kv}." + ) + # Allocate output with query shape - out = torch.empty_like(query) + out = torch.empty((B, H, L_q, D), device=query.device, dtype=query.dtype) + # Element-wise strides - sqb, sqh, sql, sqd = query.stride() - skb, skh, skl, skd = key.stride() - svb, svh, svl, svd = value.stride() - sob, soh, sol, sod = out.stride() - - # Grid: tile queries (M) and batch*heads axis - def grid(META): - return ( - triton.cdiv(L_q, META["BLOCK_M"]), # Based on query length - B * H, - ) + stride_qb, stride_qh, stride_qm, stride_qd = query.stride() + stride_kb, stride_kh, stride_kn, stride_kd = key.stride() + stride_vb, stride_vh, stride_vn, stride_vd = value.stride() + stride_ob, stride_oh, stride_om, stride_od = out.stride() # Scale factor for SDPA sm_scale = 1.0 / math.sqrt(D) if scale == 0.0 else scale + # Handle attention mask - has_mask = attn_mask is not None - if has_mask: - # Expand mask to [B, H, L_q, L_kv] if needed - if attn_mask.dim() == 2: - # [L_q, L_kv] -> [B, H, L_q, L_kv] - attn_mask = attn_mask.unsqueeze(0).unsqueeze(0).expand(B, H, -1, -1) - elif attn_mask.dim() == 3: - # [B, L_q, L_kv] -> [B, H, L_q, L_kv] - attn_mask = attn_mask.unsqueeze(1).expand(-1, H, -1, -1) - - # Validate mask shape - if attn_mask.shape != (B, H, L_q, L_kv): - # Try to expand if broadcastable - attn_mask = attn_mask.expand(B, H, L_q, L_kv) - - smb, smh, sml, smn = attn_mask.stride() + HAS_MASK = attn_mask is not None + Mask_ptr = 0 + stride_mb = stride_mq = stride_mk = 0 + if HAS_MASK: + if attn_mask.dtype != torch.bool: + raise RuntimeError("attn_mask must have dtype torch.bool") + if not attn_mask.is_cuda: + raise RuntimeError("attn_mask must be a CUDA tensor") + if ( + attn_mask.shape[0] != B + or attn_mask.shape[2] != L_q + or attn_mask.shape[3] != L_kv + ): + raise RuntimeError( + f"attn_mask shape mismatch: expected [B={B}, H, L_q={L_q}, L_kv={L_kv}], " + f"got {attn_mask.shape}" + ) + Mask_ptr = attn_mask + stride_mb = attn_mask.stride(0) + stride_mq = attn_mask.stride(2) + stride_mk = attn_mask.stride(3) + + # Grid configuration + def grid(meta): + return (triton.cdiv(L_q, meta["BLOCK_M"]), B * H) + + # Select kernel based on whether HEAD_DIM is power of 2 + if _is_power_of_2(D): + # Use power-of-2 optimized kernels with autotune + # Dynamic kernel selection based on workload + total_ctas_m64 = ((L_q + 63) // 64) * (B * H) + threshold = 4 * 84 # Heuristic threshold for kernel selection + use_small_block = total_ctas_m64 < threshold + + if use_small_block: + wrap_triton(_sdpa_fwd_kernel_m32)[grid]( + query, + key, + value, + out, + Mask_ptr if HAS_MASK else 0, + B, + H, + L_q, + L_kv, + stride_qb, + 
stride_qh, + stride_qm, + stride_qd, + stride_kb, + stride_kh, + stride_kn, + stride_kd, + stride_vb, + stride_vh, + stride_vn, + stride_vd, + stride_ob, + stride_oh, + stride_om, + stride_od, + stride_mb, + stride_mq, + stride_mk, + sm_scale, + HAS_MASK=HAS_MASK, + IS_CAUSAL=is_causal, + HEAD_DIM=D, + ) + else: + wrap_triton(_sdpa_fwd_kernel_m64)[grid]( + query, + key, + value, + out, + Mask_ptr if HAS_MASK else 0, + B, + H, + L_q, + L_kv, + stride_qb, + stride_qh, + stride_qm, + stride_qd, + stride_kb, + stride_kh, + stride_kn, + stride_kd, + stride_vb, + stride_vh, + stride_vn, + stride_vd, + stride_ob, + stride_oh, + stride_om, + stride_od, + stride_mb, + stride_mq, + stride_mk, + sm_scale, + HAS_MASK=HAS_MASK, + IS_CAUSAL=is_causal, + HEAD_DIM=D, + ) else: - # Dummy strides and mask - smb, smh, sml, smn = 0, 0, 0, 0 - attn_mask = torch.empty(0, dtype=torch.bool, device=query.device) - # Round up head dimension to next power of 2 for tile.arange in Triton kernel - HEAD_DIM_CE = _next_power_of_2(D) - # Launch kernel - wrap_triton(_sdpa_fwd_kernel)[grid]( - query, - key, - value, - attn_mask, - out, - B, - H, - L_q, # Query sequence length - L_kv, # Key/Value sequence length - D, # Actual head dimension - sqb, - sqh, - sql, - sqd, - skb, - skh, - skl, - skd, - svb, - svh, - svl, - svd, - smb, - smh, - sml, - smn, - sob, - soh, - sol, - sod, - sm_scale, - IS_CAUSAL=is_causal, - HAS_MASK=has_mask, - HEAD_DIM_CE=HEAD_DIM_CE, # Rounded to power of 2 - ) + # Use non-power-of-2 kernel with dynamic HEAD_DIM masking + BLOCK_D = _next_power_of_2(D) + + if BLOCK_D >= 256: + BLOCK_N = 64 + else: + BLOCK_N = 128 + + BLOCK_M = 32 + num_warps = 4 + num_stages = 2 + + # Handle mask for non-pow2 kernel (different stride layout) + if HAS_MASK: + mask_ptr = attn_mask + stride_mb_np2 = attn_mask.stride(0) + stride_mh_np2 = attn_mask.stride(1) + stride_mlq_np2 = attn_mask.stride(2) + stride_mlk_np2 = attn_mask.stride(3) + else: + mask_ptr = torch.empty((1,), device=query.device, dtype=torch.bool) + stride_mb_np2 = stride_mh_np2 = stride_mlq_np2 = stride_mlk_np2 = 0 + + def grid_non_pow2(meta): + return (triton.cdiv(L_q, meta["BLOCK_M"]), B * H) + + wrap_triton(_sdpa_fwd_kernel_non_pow2)[grid_non_pow2]( + query, + key, + value, + out, + mask_ptr, + B, + H, + L_q, + L_kv, + D, + stride_qb, + stride_qh, + stride_qm, + stride_qd, + stride_kb, + stride_kh, + stride_kn, + stride_kd, + stride_vb, + stride_vh, + stride_vn, + stride_vd, + stride_ob, + stride_oh, + stride_om, + stride_od, + stride_mb_np2, + stride_mh_np2, + stride_mlq_np2, + stride_mlk_np2, + sm_scale, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + BLOCK_D=BLOCK_D, + HAS_MASK=HAS_MASK, + IS_CAUSAL=is_causal, + num_warps=num_warps, + num_stages=num_stages, + ) + return out
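
Both the power-of-2 kernel body and the non-power-of-2 kernel accumulate the softmax in the streaming ("online") form, rescaling the running accumulator whenever a new key block raises the row maximum. A plain PyTorch sketch of that update, for reference only (the function name and block size below are illustrative, not part of this diff):

import torch

def online_softmax_attention(q, k, v, block_n=64, scale=None):
    # q: [M, D], k/v: [N, D]; fp32 throughout to mirror the kernels' accumulators.
    M, D = q.shape
    scale = D ** -0.5 if scale is None else scale
    m = torch.full((M,), float("-inf"))   # running row maxima (m_i)
    l = torch.zeros(M)                    # running softmax denominators (l_i)
    acc = torch.zeros(M, D)               # running weighted sums of V
    for s in range(0, k.shape[0], block_n):
        kb, vb = k[s:s + block_n], v[s:s + block_n]
        qk = (q @ kb.T) * scale                        # logits for this key block
        m_new = torch.maximum(m, qk.max(dim=1).values)
        p = torch.exp(qk - m_new[:, None])             # unnormalized block probabilities
        alpha = torch.exp(m - m_new)                   # rescale factor for earlier blocks
        acc = acc * alpha[:, None] + p @ vb
        l = l * alpha + p.sum(dim=1)
        m = m_new
    return acc / l[:, None]                # final normalization, as in the kernels

Feeding the same q, k, v through torch.softmax(q @ k.T * scale, dim=-1) @ v gives the same result up to floating-point error. The non-power-of-2 kernel performs this same update in base 2 (tl.math.exp2 with the scale pre-multiplied by log2(e)), which is numerically equivalent.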
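The dispatch between the BLOCK_M=64 and BLOCK_M=32 kernels counts how many CTAs the BLOCK_M=64 grid would launch and falls back to the smaller tile when that count is below 4 * 84 = 336. The 84 presumably approximates the SM count of the target GPU (the diff does not say), so the threshold reads as roughly four waves of CTAs: grids small enough to leave the device under-occupied at BLOCK_M=64 get the BLOCK_M=32 kernel instead. A hypothetical helper mirroring the arithmetic:

def selects_small_block(B: int, H: int, L_q: int, threshold: int = 4 * 84) -> bool:
    # Same computation as in sdpa(): query-axis CTAs at BLOCK_M=64 times batch*heads.
    total_ctas_m64 = ((L_q + 63) // 64) * (B * H)
    return total_ctas_m64 < threshold

assert selects_small_block(B=1, H=32, L_q=128)        # 2 * 32 = 64 CTAs   -> BLOCK_M=32
assert not selects_small_block(B=4, H=32, L_q=2048)   # 32 * 128 = 4096 CTAs -> BLOCK_M=64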
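A minimal smoke test for the rewritten op, assuming the module is importable as executorch.backends.cuda.triton.kernels.sdpa (inferred from the file path, not shown in the diff), a CUDA device with Triton is available, and the positional argument order matches the docstring (query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa):

import torch
import torch.nn.functional as F

from executorch.backends.cuda.triton.kernels.sdpa import sdpa  # assumed import path

B, H, L, D = 2, 8, 256, 64  # power-of-2 head dim exercises the m32/m64 kernels
q = torch.randn(B, H, L, D, device="cuda", dtype=torch.bfloat16)
k = torch.randn(B, H, L, D, device="cuda", dtype=torch.bfloat16)
v = torch.randn(B, H, L, D, device="cuda", dtype=torch.bfloat16)

# Causal, no mask; scale=0.0 selects the default 1/sqrt(D) in the wrapper.
out = sdpa(q, k, v, None, 0.0, True, 0.0, False)
ref = F.scaled_dot_product_attention(q.float(), k.float(), v.float(), is_causal=True)

# bf16 inputs with fp32 accumulation call for loose tolerances against the fp32 reference.
torch.testing.assert_close(out.float(), ref, atol=2e-2, rtol=2e-2)

Repeating the check with a non-power-of-2 head dim (e.g. D=96) exercises the _sdpa_fwd_kernel_non_pow2 path instead.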