Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 81 additions & 0 deletions tests/python/relax/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from tvm import relax
from tvm.script import relax as R
from tvm.script import tirx as T
from tvm.testing import env


def test_pipeline_compile():
Expand Down Expand Up @@ -149,3 +150,83 @@ def test_non_gpu_target_raises_error(target_name, pipeline_func):
target = tvm.target.Target(target_name)
with pytest.raises(ValueError, match="not yet supported"):
pipeline_func(target)


# An elementwise binary op with a scalar constant operand. `R.power(x, const)`
# legalizes to a single elementwise TIR PrimFunc, which the default GPU pipeline
# must schedule (bind to GPU threads). Without a thread binding the kernel
# access memory from the host and `VerifyMemory` rejects it at build time
# ("... is directly accessed by the host memory ... Did you forget to bind?").
@tvm.script.ir_module
class PowerModule:
@R.function
def main(x: R.Tensor((1, 2, 1, 1), dtype="float32")) -> R.Tensor((1, 2, 1, 1), dtype="float32"):
with R.dataflow():
y: R.Tensor((1, 2, 1, 1), dtype="float32") = R.power(x, R.const(2.0, "float32"))
R.output(y)
return y


def _has_thread_binding(func: tvm.tirx.PrimFunc) -> bool:
"""Whether the PrimFunc body contains a GPU thread-binding loop."""
found = False

def _visit(node):
nonlocal found
if isinstance(node, tvm.tirx.For) and node.kind == tvm.tirx.ForKind.THREAD_BINDING:
found = True

tvm.tirx.stmt_functor.post_order_visit(func.body, _visit)
return found


def test_default_cuda_pipeline_schedules_power():
"""The CUDA legalization pipeline thread-binds a legalized elementwise kernel.

Device-free (no GPU required): runs the CUDA `legalize_passes`, which end
right after DLight scheduling, so the only TIR PrimFunc left is the `power`
kernel itself (no later host-side shape helpers to confuse the check). The
kernel must carry a GPU thread binding, otherwise `VerifyMemory` would reject
it during a real build.
"""
target = tvm.target.Target(
"cuda -keys=cuda,gpu -arch=sm_86 -max_num_threads=1024 -thread_warp_size=32"
)
with target:
seq = tvm.transform.Sequential(relax.pipeline.legalize_passes(target))
mod = seq(PowerModule)

prim_funcs = [
(g_var, func)
for g_var, func in mod.functions_items()
if isinstance(func, tvm.tirx.PrimFunc) and "power" in g_var.name_hint
]
assert prim_funcs, "expected at least one power TIR PrimFunc after legalization"
for _, func in prim_funcs:
assert _has_thread_binding(func), (
"power PrimFunc left without a GPU thread binding (VerifyMemory would fail)"
)


@pytest.mark.gpu
@pytest.mark.skipif(not env.has_cuda(), reason="need cuda")
def test_power_cuda_build_and_run():
"""End-to-End build and run of an elementwise `R.power` kernel on CUDA.

Compiles through `tvm.compile`, which for a GPU target selects the
target-specific default pipeline (with DLight scheduling), then executes the
kernel and checks the result.
"""
dev = tvm.cuda(0)
target = tvm.target.Target.from_device(dev)

ex = tvm.compile(PowerModule, target=target)
vm = relax.VirtualMachine(ex, dev)

x_np = np.random.rand(1, 2, 1, 1).astype(np.float32)
out = vm["main"](tvm.runtime.tensor(x_np, dev))
tvm.testing.assert_allclose(out.numpy(), x_np**2, rtol=1e-6, atol=1e-6)


if __name__ == "__main__":
tvm.testing.main()
Loading