Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions jax/_src/cudnn/scaled_matmul_stablehlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
from jax._src import lax as lax_internal
from jax._src import numpy as jnp
from jax._src import tree_util
from jax._src.custom_derivatives import custom_vjp
Expand Down Expand Up @@ -104,6 +105,23 @@ def _scaled_matmul_gpu_lowering(
return [out.result]


def _scaled_matmul_rocm_lowering(
    ctx, a, b, a_scales, b_scales, preferred_element_type
):
  """ROCm lowering for the scaled_matmul primitive.

  Instead of targeting a vendor custom call, the primitive is re-expressed
  as a ``lax.scaled_dot`` and lowered through the generic path via
  ``mlir.lower_fun``. Dimension numbers contract axis 2 of each operand and
  batch over axis 0, matching the (batch, non_contracting, contracting)
  operand layout used by this primitive.
  """
  # Contract dim 2 of lhs with dim 2 of rhs; batch over dim 0 of both.
  dim_numbers = (((2,), (2,)), ((0,), (0,)))

  def _expand_to_scaled_dot(lhs, rhs, lhs_scales, rhs_scales):
    # preferred_element_type is closed over from the lowering arguments.
    return lax_internal.scaled_dot(
        lhs,
        rhs,
        lhs_scale=lhs_scales,
        rhs_scale=rhs_scales,
        dimension_numbers=dim_numbers,
        preferred_element_type=preferred_element_type,
    )

  lowered = mlir.lower_fun(_expand_to_scaled_dot, multiple_results=False)
  return lowered(ctx, a, b, a_scales, b_scales)


def _scaled_matmul_abstract(a, b, a_scale, b_scale, *, preferred_element_type):
batch, non_contracting_lhs, contracting_lhs = a.shape
_, non_contracting_rhs, _ = b.shape
Expand All @@ -122,6 +140,11 @@ def _scaled_matmul_abstract(a, b, a_scale, b_scale, *, preferred_element_type):
_scaled_matmul_gpu_lowering,
platform="cuda",
)
# On ROCm, scaled_matmul is lowered by _scaled_matmul_rocm_lowering, which
# expands the primitive into a lax.scaled_dot (no vendor custom call).
mlir.register_lowering(
    _scaled_matmul_p,
    _scaled_matmul_rocm_lowering,
    platform="rocm",
)

# Wrapper primitive around scaled_matmul; multiple_results=True means its
# bind returns a list of outputs rather than a single value.
_scaled_matmul_p_wrapper = core.Primitive("scaled_matmul_wrapper")
_scaled_matmul_p_wrapper.multiple_results = True
Expand Down