Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
d44394c
wip back of sdma integration
dsidler Nov 6, 2025
c50e761
Apply Ruff auto-fixes
github-actions[bot] Nov 6, 2025
2f7bc5e
message passing example working
dsidler Nov 6, 2025
5e38fd6
Merge branch 'dev/dasidler/sdma' of https://github.com/ROCm/iris into…
dsidler Nov 6, 2025
759f662
Apply Ruff auto-fixes
github-actions[bot] Nov 6, 2025
ad7769d
update put example to use ce
dsidler Nov 7, 2025
b8862cc
update api calls
dsidler Nov 7, 2025
75c5626
update submodule
dsidler Nov 7, 2025
2b228ab
Merge branch 'dev/dasidler/sdma' of https://github.com/ROCm/iris into…
dsidler Nov 7, 2025
e3aef16
fix merge
dsidler Nov 7, 2025
df04547
Apply Ruff auto-fixes
github-actions[bot] Nov 7, 2025
c5e4735
wip fixed wrap into ring when placing
dsidler Dec 5, 2025
ea17dd6
Merge branch 'dev/dasidler/sdma' of https://github.com/ROCm/iris into…
dsidler Dec 5, 2025
5362318
to_rank 7 working
dsidler Dec 5, 2025
a6b1d40
Apply Ruff auto-fixes
github-actions[bot] Dec 10, 2025
224511f
Merge branch 'main' into dev/dasidler/sdma
dsidler Jan 14, 2026
400b5b7
use triton commit with fix
dsidler Jan 14, 2026
d06cb72
Apply Ruff auto-fixes
github-actions[bot] Jan 14, 2026
b2e358b
send to all ranks but always same stride
dsidler Jan 20, 2026
b245899
update submodule
dsidler Jan 20, 2026
0e7fbd6
Merge branch 'dev/dasidler/sdma' of https://github.com/ROCm/iris into…
dsidler Jan 20, 2026
1ee4c58
use 32B copy packets workaround
dsidler Jan 30, 2026
1c384c3
submodule update
dsidler Jan 30, 2026
0224866
use window command
dsidler Mar 4, 2026
40c228a
use new acquire function
dsidler Mar 5, 2026
34d4ffc
update submodule
dsidler Mar 5, 2026
c8d4b46
Apply Ruff auto-fixes
github-actions[bot] Mar 5, 2026
53f1a20
move padding code
dsidler Mar 5, 2026
099a84c
update submodule for nop packet
dsidler Mar 5, 2026
75b55b2
enable flat copy
dsidler Mar 5, 2026
17d0696
Merge branch 'dev/dasidler/sdma' of https://github.com/ROCm/iris into…
dsidler Mar 5, 2026
e5a38dd
Apply Ruff auto-fixes
github-actions[bot] Mar 5, 2026
02d08c9
Merge branch 'main' into dev/dasidler/sdma
dsidler Mar 5, 2026
0b6ff1a
clean up
dsidler Mar 5, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[submodule "ext/shader_sdma"]
path = ext/shader_sdma
url = https://github.com/AARInternal/shader_sdma.git
241 changes: 241 additions & 0 deletions examples/06_message_passing/message_passing_copy_engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,241 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.

import argparse

import torch
import torch.distributed as dist
import triton
import triton.language as tl
import random

from mpi4py import MPI

import iris


@triton.jit
def producer_kernel(
    source_buffer,  # tl.tensor: pointer to source data on the producer's heap
    target_buffer,  # tl.tensor: pointer to target data on the consumer's heap
    flag,  # tl.tensor: pointer to per-block completion flags
    buffer_size,  # int32: total number of elements
    producer_rank: tl.constexpr,
    consumer_rank: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    heap_bases_ptr: tl.tensor,  # tl.tensor: pointer to heap bases pointers
    copy_engine_handle_ptr,  # opaque handle for the SDMA copy engine (see shmem.get_copy_engine_handle)
):
    """Send one BLOCK_SIZE chunk of source_buffer to the consumer rank.

    Each program instance copies its chunk into the remote target buffer via
    the copy engine, then raises flag[pid] so the matching consumer program
    knows its chunk has arrived.
    """
    pid = tl.program_id(0)

    # Compute start index of this block
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)

    # Guard for out-of-bounds accesses in the (possibly partial) last block
    mask = offsets < buffer_size

    # Put chunk into remote buffer. USE_COPY_ENGINE=True routes the transfer
    # through the SDMA copy engine instead of shader loads/stores.
    iris.put(
        source_buffer + offsets,
        target_buffer + offsets,
        producer_rank,
        consumer_rank,
        heap_bases_ptr,
        copy_engine_handle_ptr,
        mask=mask,
        USE_COPY_ENGINE=True,
    )

    # Set flag to signal completion of this block's transfer.
    # NOTE(review): presumably signal_ce orders the flag write after the
    # copy-engine put so the consumer never observes the flag before the
    # data — confirm against the iris API docs.
    iris.signal_ce(flag + pid, producer_rank, consumer_rank, heap_bases_ptr, copy_engine_handle_ptr)


@triton.jit
def consumer_kernel(
    buffer,  # tl.tensor: pointer to shared buffer (written remotely by the producer)
    flag,  # tl.tensor: sync flag per block
    buffer_size,  # int32: total number of elements
    consumer_rank: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    heap_bases_ptr: tl.tensor,  # tl.tensor: pointer to heap bases pointers
):
    """Wait for one BLOCK_SIZE chunk from the producer, then double it in place.

    Each program instance spin-waits on flag[pid], reads its chunk from the
    local symmetric-heap buffer, multiplies it by 2, and stores it back.
    """
    pid = tl.program_id(0)

    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Mask out the tail of the (possibly partial) last block
    mask = offsets < buffer_size

    # Spin-wait until the producer sets flag[pid] = 1. The CAS atomically
    # swaps 1 -> 0 with acquire semantics, so the data written before the
    # flag is visible once the loop exits.
    done = 0
    while done == 0:
        done = iris.atomic_cas(
            flag + pid, 1, 0, consumer_rank, consumer_rank, heap_bases_ptr, sem="acquire", scope="sys"
        )

    # Read from the target buffer (written by producer)
    values = tl.load(buffer + offsets, mask=mask)

    # Do something with values...
    # (Here you might write to output, do computation, etc.)
    values = values * 2

    # Store chunk back to the same buffer so the host can validate it
    tl.store(
        buffer + offsets,
        values,
        mask=mask,
    )

    # Optionally reset the flag for next iteration.
    # NOTE(review): the successful CAS above already swapped the flag back
    # to 0, so this store looks redundant — confirm before removing.
    tl.store(flag + pid, 0)


# Fix RNG seeds at import time — presumably so every rank generates an
# identical "random" source buffer, which lets the consumer validate the
# received data against its own locally generated copy (see _worker).
# TODO(review): confirm this cross-rank-determinism assumption.
torch.manual_seed(123)
random.seed(123)


def torch_dtype_from_str(datatype: str) -> torch.dtype:
    """Map a datatype name from the CLI to the corresponding torch dtype.

    Args:
        datatype: One of "fp16", "fp32", "int8", "bf16".

    Returns:
        The matching ``torch.dtype``.

    Raises:
        ValueError: If ``datatype`` is not a recognized name.
    """
    dtype_map = {
        "fp16": torch.float16,
        "fp32": torch.float32,
        "int8": torch.int8,
        "bf16": torch.bfloat16,
    }
    try:
        return dtype_map[datatype]
    except KeyError:
        # Raise instead of print()+exit(1): the annotated return type promises
        # a torch.dtype, and an exception lets callers (and tests) handle the
        # failure while still terminating a CLI run with a clear message.
        raise ValueError(f"Unknown datatype: {datatype}") from None


def parse_args():
    """Parse the example's command-line options and return them as a dict."""
    ap = argparse.ArgumentParser(
        description="Parse Message Passing configuration.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    ap.add_argument(
        "-t",
        "--datatype",
        type=str,
        default="fp32",
        choices=["fp16", "fp32", "int8", "bf16"],
        help="Datatype of computation",
    )
    ap.add_argument("-s", "--buffer_size", type=int, default=4096, help="Buffer Size")
    ap.add_argument("-b", "--block_size", type=int, default=512, help="Block Size")
    ap.add_argument("-p", "--heap_size", type=int, default=1 << 33, help="Iris heap size")
    return vars(ap.parse_args())


def _worker(local_rank: int, world_size: int, init_url: str, args: dict):
    """Worker function for PyTorch distributed execution.

    Initializes torch.distributed and the iris symmetric heap, launches the
    producer kernel on rank 0 and the consumer kernel on rank 1, then (on the
    consumer) validates that the received data was doubled in place.

    Args:
        local_rank: This process's rank, also used as the CUDA device index.
        world_size: Total number of processes (must be exactly 2).
        init_url: torch.distributed rendezvous URL.
        args: Parsed CLI options (heap_size, datatype, buffer_size, block_size).
    """
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    dist.init_process_group(
        backend=backend,
        init_method=init_url,
        world_size=world_size,
        rank=local_rank,
        device_id=torch.device(f"cuda:{local_rank}"),
    )

    # Main benchmark logic
    shmem = iris.iris(args["heap_size"])
    dtype = torch_dtype_from_str(args["datatype"])
    cur_rank = shmem.get_rank()
    world_size = shmem.get_num_ranks()

    # Allocate source and destination buffers on the symmetric heap.
    # NOTE(review): allocation order must match on all ranks so heap offsets
    # line up — presumably an iris symmetric-heap requirement; confirm.
    destination_buffer = shmem.zeros(args["buffer_size"], device="cuda", dtype=dtype)
    if dtype.is_floating_point:
        source_buffer = shmem.randn(args["buffer_size"], device="cuda", dtype=dtype)
    else:
        ii = torch.iinfo(dtype)
        source_buffer = shmem.randint(ii.min, ii.max, (args["buffer_size"],), device="cuda", dtype=dtype)

    if world_size != 2:
        raise ValueError("This example requires exactly two processes.")

    producer_rank = 0
    consumer_rank = 1

    n_elements = source_buffer.numel()
    # One program per BLOCK_SIZE chunk of the buffer
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    num_blocks = triton.cdiv(n_elements, args["block_size"])

    # Allocate per-block synchronization flags on the symmetric heap
    flags = shmem.zeros((num_blocks,), device="cuda", dtype=torch.int32)

    if cur_rank == producer_rank:
        shmem.info(f"Rank {cur_rank} is sending data to rank {consumer_rank}.")
        kk = producer_kernel[grid](
            source_buffer,
            destination_buffer,
            flags,
            n_elements,
            producer_rank,
            consumer_rank,
            args["block_size"],
            shmem.get_heap_bases(),
            shmem.get_copy_engine_handle(consumer_rank),
        )
    else:
        shmem.info(f"Rank {cur_rank} is receiving data from rank {producer_rank}.")
        kk = consumer_kernel[grid](
            destination_buffer, flags, n_elements, consumer_rank, args["block_size"], shmem.get_heap_bases()
        )
    # Barrier so the consumer's in-place update is complete before validation
    shmem.barrier()
    shmem.info(f"Rank {cur_rank} has finished sending/receiving data.")
    shmem.info("Validating output...")

    success = True
    if cur_rank == consumer_rank:
        # The consumer compares against its OWN source_buffer * 2. This only
        # works because the module-level seeding makes both ranks generate
        # identical random source data — TODO(review): confirm.
        expected = source_buffer * 2
        diff_mask = ~torch.isclose(destination_buffer, expected, atol=1)
        breaking_indices = torch.nonzero(diff_mask, as_tuple=False)

        if not torch.allclose(destination_buffer, expected, atol=1):
            max_diff = (destination_buffer - expected).abs().max().item()
            shmem.info(f"Max absolute difference: {max_diff}")
            # Report only the first mismatch (loop breaks after one entry)
            for idx in breaking_indices:
                idx = tuple(idx.tolist())
                computed_val = destination_buffer[idx]
                expected_val = expected[idx]
                shmem.info(f"Mismatch at index {idx}: C={computed_val}, expected={expected_val}")
                success = False
                break

        if success:
            shmem.info("Validation successful.")
        else:
            shmem.info(f"Validation failed with {len(breaking_indices)} errors / {destination_buffer.numel()}")

    shmem.barrier()

    dist.barrier()
    dist.destroy_process_group()


def main():
    """Entry point: bootstrap MPI, pin the CUDA device, and run the worker."""
    config = parse_args()

    world = MPI.COMM_WORLD  # Communicator spanning all launched processes
    my_rank = world.Get_rank()  # Rank of this process within the communicator
    total_ranks = world.Get_size()  # Total number of processes
    # TODO local_rank
    torch.cuda.set_device(my_rank)

    # Synchronize all processes before entering the worker
    world.barrier()

    _worker(my_rank, total_ranks, "tcp://127.0.0.1:29500", config)


if __name__ == "__main__":
main()
44 changes: 37 additions & 7 deletions examples/06_message_passing/message_passing_put.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ def producer_kernel(
consumer_rank: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
heap_bases_ptr: tl.tensor, # tl.tensor: pointer to heap bases pointers
copy_engine_handle_ptr,
USE_COPY_ENGINE: tl.constexpr,
):
pid = tl.program_id(0)

Expand All @@ -34,10 +36,30 @@ def producer_kernel(
mask = offsets < buffer_size

# Put chunk into remote buffer
iris.put(source_buffer + offsets, target_buffer + offsets, producer_rank, consumer_rank, heap_bases_ptr, mask=mask)
iris.put(
source_buffer + offsets,
target_buffer + offsets,
producer_rank,
consumer_rank,
heap_bases_ptr,
copy_engine_handle_ptr,
mask=mask,
USE_COPY_ENGINE=USE_COPY_ENGINE,
)

# Set flag to signal completion
iris.atomic_cas(flag + pid, 0, 1, producer_rank, consumer_rank, heap_bases_ptr, sem="release", scope="sys")
# iris.atomic_cas(flag + pid, 0, 1, producer_rank, consumer_rank, heap_bases_ptr, copy_engine_handle_ptr, sem="release", scope="sys")
iris.atomic_add(
flag + pid,
1,
producer_rank,
consumer_rank,
heap_bases_ptr,
sem="release",
scope="sys",
copy_engine_ctx=copy_engine_handle_ptr,
USE_COPY_ENGINE=USE_COPY_ENGINE,
)


@triton.jit
Expand Down Expand Up @@ -113,9 +135,11 @@ def parse_args():
)
parser.add_argument("-s", "--buffer_size", type=int, default=4096, help="Buffer Size")
parser.add_argument("-b", "--block_size", type=int, default=512, help="Block Size")

parser.add_argument("-p", "--heap_size", type=int, default=1 << 33, help="Iris heap size")
parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes")
parser.add_argument(
"-c", "--use_copy_engine", action="store_true", help="Use copy engine for device-to-device copies"
)

return vars(parser.parse_args())

Expand All @@ -138,12 +162,12 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict):
world_size = shmem.get_num_ranks()

# Allocate source and destination buffers on the symmetric heap
source_buffer = shmem.zeros(args["buffer_size"], device="cuda", dtype=dtype)
destination_buffer = shmem.zeros(args["buffer_size"], device="cuda", dtype=dtype)
if dtype.is_floating_point:
destination_buffer = shmem.randn(args["buffer_size"], device="cuda", dtype=dtype)
source_buffer = shmem.randn(args["buffer_size"], device="cuda", dtype=dtype)
else:
ii = torch.iinfo(dtype)
destination_buffer = shmem.randint(ii.min, ii.max, (args["buffer_size"],), device="cuda", dtype=dtype)
source_buffer = shmem.randint(ii.min, ii.max, (args["buffer_size"],), device="cuda", dtype=dtype)

if world_size != 2:
raise ValueError("This example requires exactly two processes.")
Expand All @@ -158,6 +182,10 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict):
# Allocate flags on the symmetric heap
flags = shmem.zeros((num_blocks,), device="cuda", dtype=torch.int32)

# Get copy engine context
# copy_engine_ctx = shmem.get_copy_engine_handle(consumer_rank) if args["use_copy_engine"] and cur_rank == producer_rank else None
copy_engine_ctx = shmem.get_copy_engine_ctx()

if cur_rank == producer_rank:
shmem.info(f"Rank {cur_rank} is sending data to rank {consumer_rank}.")
kk = producer_kernel[grid](
Expand All @@ -169,6 +197,8 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict):
consumer_rank,
args["block_size"],
shmem.get_heap_bases(),
copy_engine_ctx,
USE_COPY_ENGINE=args["use_copy_engine"],
)
else:
shmem.info(f"Rank {cur_rank} is receiving data from rank {producer_rank}.")
Expand Down Expand Up @@ -199,7 +229,7 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict):
if success:
shmem.info("Validation successful.")
else:
shmem.info("Validation failed.")
shmem.info(f"Validation failed with {len(breaking_indices)} errors / {destination_buffer.numel()}")

shmem.barrier()

Expand Down
12 changes: 11 additions & 1 deletion examples/10_gemm_all_scatter_wg_specialization/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ def parse_args():
)
parser.add_argument("--num_stages", type=int, default=2, help="Number of stages")
parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes")
parser.add_argument(
"-c", "--use_copy_engine", action="store_true", help="Use copy engine for device-to-device copies"
)

return vars(parser.parse_args())

Expand Down Expand Up @@ -133,6 +136,11 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict):
total_tiles = total_blocks_M * total_blocks_N

locks = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int8)
comm_sms = args["num_sms"] - args["gemm_sms"]
flags = shmem.zeros((comm_sms, world_size), device="cuda", dtype=torch.uint32)

# Get copy engine context
copy_engine_ctx = shmem.get_copy_engine_ctx()

bias = None

Expand Down Expand Up @@ -175,6 +183,7 @@ def run_experiment():
global_C,
bias,
locks,
flags,
rank,
world_size,
args["gemm_sms"],
Expand All @@ -187,6 +196,8 @@ def run_experiment():
shmem.get_heap_bases(),
"gfx942",
args["trace_tiles"],
args["use_copy_engine"],
copy_engine_ctx,
timestamps.mm_begin_timestamp,
timestamps.mm_end_timestamp,
)
Expand Down Expand Up @@ -224,7 +235,6 @@ def run_experiment():

# Wait for all to finish validation
shmem.barrier()
shmem.info("Validating local C...")

json_writer.add_field("success", success)

Expand Down
Loading
Loading