diff --git a/jax/_src/cudnn/scaled_matmul_stablehlo.py b/jax/_src/cudnn/scaled_matmul_stablehlo.py
index 01e8a681e899..29449b30fe4c 100644
--- a/jax/_src/cudnn/scaled_matmul_stablehlo.py
+++ b/jax/_src/cudnn/scaled_matmul_stablehlo.py
@@ -23,6 +23,7 @@
 from jax._src import core
 from jax._src import dispatch
 from jax._src import dtypes
+from jax._src import lax as lax_internal
 from jax._src import numpy as jnp
 from jax._src import tree_util
 from jax._src.custom_derivatives import custom_vjp
@@ -104,6 +105,23 @@ def _scaled_matmul_gpu_lowering(
   return [out.result]
 
 
+def _scaled_matmul_rocm_lowering(
+    ctx, a, b, a_scales, b_scales, preferred_element_type
+):
+  def _scaled_dot_lowering_impl(lhs, rhs, lhs_scales, rhs_scales):
+    return lax_internal.scaled_dot(
+        lhs,
+        rhs,
+        lhs_scale=lhs_scales,
+        rhs_scale=rhs_scales,
+        dimension_numbers=(((2,), (2,)), ((0,), (0,))),
+        preferred_element_type=preferred_element_type,
+    )
+  return mlir.lower_fun(_scaled_dot_lowering_impl, multiple_results=False)(
+      ctx, a, b, a_scales, b_scales
+  )
+
+
 def _scaled_matmul_abstract(a, b, a_scale, b_scale, *, preferred_element_type):
   batch, non_contracting_lhs, contracting_lhs = a.shape
   _, non_contracting_rhs, _ = b.shape
@@ -122,6 +140,11 @@ def _scaled_matmul_abstract(a, b, a_scale, b_scale, *, preferred_element_type):
     _scaled_matmul_gpu_lowering,
     platform="cuda",
 )
+mlir.register_lowering(
+    _scaled_matmul_p,
+    _scaled_matmul_rocm_lowering,
+    platform="rocm",
+)
 
 _scaled_matmul_p_wrapper = core.Primitive("scaled_matmul_wrapper")
 _scaled_matmul_p_wrapper.multiple_results = True