Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .ci/docker/ci_commit_pins/optimum-executorch.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
d03e90c2cd9048e6d9a75285c0355f033cd016fc
0123293118efb08ac4ffc4fefe9d330201465c93
37 changes: 19 additions & 18 deletions backends/cuda/cuda_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ def get_custom_passes(cls, compile_specs: List[CompileSpec]) -> List[typing.Any]
)
triton_kernel_mode = mode

return [ReplaceEdgeOpWithTritonOpPass()] if triton_kernel_mode == "ON" else []
# return [ReplaceEdgeOpWithTritonOpPass()] if triton_kernel_mode == "ON" else []
return [ReplaceEdgeOpWithTritonOpPass()]

@classmethod
def get_aoti_compile_options(
Expand Down Expand Up @@ -134,20 +135,20 @@ def get_aoti_compile_options(

return options

@classmethod
def get_extra_aoti_compile_context_manager(cls):
"""
Return SDPA MATH backend context manager for CUDA compilation.

This context manager plays as a fallback solution for any remaining PyTorch SDPA
operations to use the MATH backend (decomposed SDPA) during AOTInductor compilation.

Note:
- If SDPA ops are replaced with Triton kernels by ReplaceEdgeOpWithTritonOpPass,
this context manager will have no effect on those ops (they are no longer
PyTorch SDPA ops).
- If SDPA ops are NOT replaced (e.g., when triton_kernel_mode="OFF"), this
context manager will force them to use the MATH backend, causing them to
be automatically decomposed during compilation.
"""
return torch.nn.attention.sdpa_kernel([SDPBackend.MATH])
# @classmethod
# def get_extra_aoti_compile_context_manager(cls):
# """
# Return SDPA MATH backend context manager for CUDA compilation.

# This context manager plays as a fallback solution for any remaining PyTorch SDPA
# operations to use the MATH backend (decomposed SDPA) during AOTInductor compilation.

# Note:
# - If SDPA ops are replaced with Triton kernels by ReplaceEdgeOpWithTritonOpPass,
# this context manager will have no effect on those ops (they are no longer
# PyTorch SDPA ops).
# - If SDPA ops are NOT replaced (e.g., when triton_kernel_mode="OFF"), this
# context manager will force them to use the MATH backend, causing them to
# be automatically decomposed during compilation.
# """
# return torch.nn.attention.sdpa_kernel([SDPBackend.MATH])
Loading
Loading