diff --git a/iris/_distributed_helpers.py b/iris/_distributed_helpers.py index 9b222375e..ac237cb5e 100644 --- a/iris/_distributed_helpers.py +++ b/iris/_distributed_helpers.py @@ -1,9 +1,12 @@ # SPDX-License-Identifier: MIT # Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + import torch import torch.distributed as dist import numpy as np +import triton +import triton.language as tl def _infer_device(): @@ -207,6 +210,76 @@ def distributed_broadcast_tensor(value_to_broadcast=None, root=0): return obj[0] +def extract_group_info(group, rank, num_ranks): + """ + Extract rank and stride information for a process group. + + Args: + group: ProcessGroup or None. If None, uses the provided rank/num_ranks + as the default (all-ranks) group. + rank: Global rank of the current process. + num_ranks: Total number of ranks in the default group. + + Returns: + Tuple of (rank_in_group, rank_global, world_size, rank_start, rank_stride): + - rank_in_group: Rank within the group (0-indexed) + - rank_global: Global rank of this process + - world_size: Number of ranks in the group + - rank_start: Starting global rank of the group + - rank_stride: Stride between consecutive ranks in the group + + Examples: + >>> # group=None: all ranks [0,1,2,3], current global rank is 2 + >>> extract_group_info(None, 2, 4) + (2, 2, 4, 0, 1) + + >>> # DP group: strided ranks [0,4,8,12], current global rank is 8 + >>> extract_group_info(dp_group, 8, 16) + (2, 8, 4, 0, 4) + """ + if group is None: + return rank, rank, num_ranks, 0, 1 + + if not dist.is_initialized(): + raise RuntimeError( + "torch.distributed must be initialized to use ProcessGroup. " + "Call torch.distributed.init_process_group() first." + ) + + group_ranks = dist.get_process_group_ranks(group) + world_size = len(group_ranks) + rank_global = rank + + if rank_global not in group_ranks: + raise RuntimeError( + f"Rank {rank_global} is not part of the specified process group. Group contains ranks: {group_ranks}" + ) + + rank_in_group = group_ranks.index(rank_global) + + if len(group_ranks) > 1: + strides = [group_ranks[i] - group_ranks[i - 1] for i in range(1, len(group_ranks))] + if not all(s == strides[0] for s in strides): + raise NotImplementedError( + f"Non-strided process groups are not yet supported. " + f"Group ranks: {group_ranks}. " + f"Please use groups with uniform stride (e.g., [0,1,2,3] or [0,4,8,12])." + ) + rank_start = group_ranks[0] + rank_stride = strides[0] + if rank_stride == 0: + raise ValueError( + f"Invalid process group: rank_stride is 0, indicating duplicate ranks. " + f"Group ranks: {group_ranks}. " + f"Each rank must appear exactly once in a process group." + ) + else: + rank_start = group_ranks[0] + rank_stride = 1 + + return rank_in_group, rank_global, world_size, rank_start, rank_stride + + def distributed_barrier(group=None): """ Synchronization barrier using PyTorch distributed. @@ -220,6 +293,98 @@ def distributed_barrier(group=None): dist.barrier(group=group) +@triton.jit +def _translate_ptr(ptr, from_rank, to_rank, heap_bases): + """Translate a pointer from one rank's address space to another's.""" + from_base = tl.load(heap_bases + from_rank) + to_base = tl.load(heap_bases + to_rank) + offset = tl.cast(ptr, tl.uint64) - from_base + translated_ptr = tl.cast(tl.cast(to_base, tl.pointer_type(tl.int8)) + offset, ptr.dtype) + return translated_ptr + + +@triton.jit +def _device_barrier_kernel( + flags_ptr, + iris_rank, + world_size: tl.constexpr, + rank_start, + rank_stride, + heap_bases, + MAX_SPINS: tl.constexpr = 1_000_000_000, +): + """ + Device-side barrier using atomic operations on the symmetric heap. + CUDA graph capturable. + + Stateless w.r.t. host-side epoch tracking: there is no CPU-side epoch + counter. Each rank's flag on the heap serves as its own epoch counter, + managed entirely by the GPU via atomic_add. A persistent per-group flags + tensor is cached in ``_device_barrier_state``. + + Launched with grid=(1,). A single CTA: + 1. Atomically increments its own flag (atomic_add, release) + 2. Serially polls each remote rank's flag for the same value (acquire) + """ + # Increment own flag and determine target + own_flag_ptr = flags_ptr + iris_rank + own_translated = _translate_ptr(own_flag_ptr, iris_rank, iris_rank, heap_bases) + old = tl.atomic_add(own_translated, 1, sem="release", scope="sys") + target = old + 1 + + # Poll each remote rank serially + for i in range(world_size): + remote_rank = rank_start + i * rank_stride + if remote_rank != iris_rank: + remote_flag_ptr = flags_ptr + remote_rank + remote_translated = _translate_ptr(remote_flag_ptr, iris_rank, remote_rank, heap_bases) + spin_count = 0 + while ( + tl.atomic_cas( + remote_translated, + target, + target, + sem="acquire", + scope="sys", + ) + < target + ): + spin_count += 1 + tl.device_assert(spin_count < MAX_SPINS, "device_barrier: timeout") + + +def distributed_device_barrier(flags, group, rank, num_ranks, heap_bases): + """ + Device-side barrier using atomic operations on the symmetric heap. + CUDA graph capturable. + + Unlike ``distributed_barrier`` which uses host-side ``torch.distributed.barrier()``, + this launches a single-CTA Triton kernel that synchronizes via + device-side atomics, making it safe to use during CUDA graph capture. + + Stateless w.r.t. host-side epoch tracking: each rank's flag on the + symmetric heap serves as its own epoch counter, managed entirely by + the GPU via atomic_add. A persistent per-group flags tensor is cached + in ``_device_barrier_state``. + + Args: + flags: int32 tensor on symmetric heap, one element per rank. + group: ProcessGroup or None. If None, uses all ranks. + rank: Global rank of this process. + num_ranks: Total number of ranks in the default group. + heap_bases: Tensor of heap base addresses for all ranks. + """ + _, rank_global, world_size, rank_start, rank_stride = extract_group_info(group, rank, num_ranks) + _device_barrier_kernel[(1,)]( + flags, + rank_global, + world_size, + rank_start, + rank_stride, + heap_bases, + ) + + def init_distributed(): """ Initialize PyTorch distributed and return communicator info. diff --git a/iris/ccl/utils.py b/iris/ccl/utils.py index 4f90ac09a..eeff2781a 100644 --- a/iris/ccl/utils.py +++ b/iris/ccl/utils.py @@ -9,6 +9,7 @@ from typing import Tuple import triton import triton.language as tl +from iris._distributed_helpers import extract_group_info as _extract_group_info @triton.jit() @@ -67,83 +68,11 @@ def extract_group_info(group, shmem) -> Tuple[int, int, int, int, int]: Returns: Tuple of (rank_in_group, rank_global, world_size, rank_start, rank_stride) - - rank_in_group: Rank within the group (0-indexed), used for tile assignment and comparisons - - rank_global: Global rank of this process, used for iris RMA operations (heap_bases indexing) + - rank_in_group: Rank within the group (0-indexed) + - rank_global: Global rank of this process - world_size: Number of ranks in the group - rank_start: Starting global rank of the group - rank_stride: Stride between consecutive ranks in the group - - Examples: - >>> # group=None: all ranks [0,1,2,3], current global rank is 2 - >>> extract_group_info(None, shmem) - (2, 2, 4, 0, 1) # rank_in_group=2, rank_global=2, world_size=4, start=0, stride=1 - - >>> # TP group: consecutive ranks [0,1,2,3], current global rank is 2 - >>> extract_group_info(tp_group, shmem) - (2, 2, 4, 0, 1) # rank_in_group=2, rank_global=2, world_size=4, start=0, stride=1 - - >>> # DP group: strided ranks [0,4,8,12], current global rank is 8 - >>> extract_group_info(dp_group, shmem) - (2, 8, 4, 0, 4) # rank_in_group=2, rank_global=8, world_size=4, start=0, stride=4 """ - if group is None: - # Use all ranks in shmem context - # When group is None, rank_in_group equals rank_global - rank_global = shmem.get_rank() - rank_in_group = rank_global - world_size = shmem.get_num_ranks() - rank_start = 0 - rank_stride = 1 - return rank_in_group, rank_global, world_size, rank_start, rank_stride - - # Extract from ProcessGroup - import torch.distributed as dist - - if not dist.is_initialized(): - raise RuntimeError( - "torch.distributed must be initialized to use ProcessGroup. " - "Call torch.distributed.init_process_group() first." - ) - - group_ranks = dist.get_process_group_ranks(group) - world_size = len(group_ranks) - rank_global = dist.get_rank() - - if rank_global not in group_ranks: - raise RuntimeError( - f"Current rank {rank_global} is not part of the specified process group. " - f"Group contains ranks: {group_ranks}" - ) - - rank_in_group = group_ranks.index(rank_global) - - # Detect stride pattern - if len(group_ranks) > 1: - # Check if all consecutive pairs have the same stride - strides = [group_ranks[i] - group_ranks[i - 1] for i in range(1, len(group_ranks))] - is_strided = all(s == strides[0] for s in strides) - - if is_strided: - rank_start = group_ranks[0] - rank_stride = strides[0] - - # Validate rank_stride is not zero (would indicate duplicate ranks) - if rank_stride == 0: - raise ValueError( - f"Invalid process group: rank_stride is 0, indicating duplicate ranks. " - f"Group ranks: {group_ranks}. " - f"Each rank must appear exactly once in a process group." - ) - else: - # Non-strided group - not supported yet - raise NotImplementedError( - f"Non-strided process groups are not yet supported. " - f"Group ranks: {group_ranks}. " - f"Please use groups with uniform stride (e.g., [0,1,2,3] or [0,4,8,12])." - ) - else: - # Single rank group - rank_start = group_ranks[0] - rank_stride = 1 - - return rank_in_group, rank_global, world_size, rank_start, rank_stride + + return _extract_group_info(group, shmem.get_rank(), shmem.get_num_ranks()) diff --git a/iris/iris.py b/iris/iris.py index f0effbb2d..28ee3681b 100644 --- a/iris/iris.py +++ b/iris/iris.py @@ -45,6 +45,7 @@ from iris._distributed_helpers import ( init_distributed, distributed_barrier, + distributed_device_barrier, distributed_broadcast_scalar, distributed_broadcast_tensor, ) @@ -55,6 +56,7 @@ ) from iris.symmetric_heap import SymmetricHeap import numpy as np +from typing import Any import torch import logging @@ -135,6 +137,9 @@ def __init__(self, heap_size=1 << 30, allocator_type="torch"): # Lazy initialization for ops interface self._ops = None + # Device-side barrier state, keyed by process group (None = all ranks). + self._device_barrier_state: dict[Any, torch.Tensor] = {} + # Initialize tracing self.tracing = Tracing(self) @@ -989,6 +994,36 @@ def barrier(self, stream=None, group=None): # Distributed barrier distributed_barrier(group=group) + def device_barrier(self, group=None): + """ + Device-side barrier that is CUDA graph capturable. + + Unlike ``barrier()`` which uses host-side ``torch.distributed.barrier()``, + this uses device-side atomic operations on the symmetric heap to synchronize + ranks. Stateless w.r.t. host-side epoch tracking: each rank's flag on + the heap serves as its own epoch counter, managed entirely by the GPU + via atomic_add. A persistent per-group flags tensor is cached in + ``_device_barrier_state``. + + Args: + group (ProcessGroup, optional): The process group to synchronize. + If None, uses all ranks in the shmem context. + + Example: + >>> ctx = iris.iris(1 << 20) + >>> ctx.device_barrier() # Synchronize all ranks on device + """ + if group not in self._device_barrier_state: + self._device_barrier_state[group] = self.zeros((self.num_ranks,), dtype=torch.int32) + + distributed_device_barrier( + self._device_barrier_state[group], + group, + self.cur_rank, + self.num_ranks, + self.get_heap_bases(), + ) + def get_device(self): """ Get the underlying device where the Iris symmetric heap resides. diff --git a/tests/unittests/test_barriers.py b/tests/unittests/test_barriers.py new file mode 100644 index 000000000..79f6e8351 --- /dev/null +++ b/tests/unittests/test_barriers.py @@ -0,0 +1,339 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +import gc +from typing import Literal + +import pytest +import torch +import triton +import triton.language as tl +import iris + + +BarrierType = Literal["host", "device"] +BARRIER_TYPES: list[BarrierType] = ["host", "device"] + + +def _call_barrier(shmem: iris.Iris, barrier_type: BarrierType) -> None: + if barrier_type == "host": + shmem.barrier() + else: + shmem.device_barrier() + + +@triton.jit +def _read_remote_kernel( + buf_ptr, + result_ptr, + cur_rank: tl.constexpr, + remote_rank: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + heap_bases: tl.tensor, +): + offsets = tl.arange(0, BLOCK_SIZE) + data = iris.load(buf_ptr + offsets, cur_rank, remote_rank, heap_bases) + tl.store(result_ptr + offsets, data) + + +@triton.jit +def _write_remote_kernel( + buf_ptr, + value, + cur_rank: tl.constexpr, + remote_rank: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + heap_bases: tl.tensor, +): + offsets = tl.arange(0, BLOCK_SIZE) + data = tl.full([BLOCK_SIZE], value, dtype=tl.float32) + iris.store(buf_ptr + offsets, data, cur_rank, remote_rank, heap_bases) + + +@pytest.mark.parametrize("n", [1, 10]) +@pytest.mark.parametrize("barrier_type", BARRIER_TYPES) +def test_barrier_basic(barrier_type, n): + shmem = iris.iris(1 << 20) + _call_barrier(shmem, barrier_type) + + try: + for _ in range(n): + _call_barrier(shmem, barrier_type) + finally: + _call_barrier(shmem, barrier_type) + del shmem + gc.collect() + + +@pytest.mark.parametrize("n", [1, 2, 5, 10]) +@pytest.mark.parametrize("barrier_type", BARRIER_TYPES) +def test_barrier_state_reuse(barrier_type, n): + """Verify device barrier reuses the same flags tensor across calls.""" + shmem = iris.iris(1 << 20) + _call_barrier(shmem, barrier_type) + + try: + shmem.device_barrier() + assert None in shmem._device_barrier_state + flags = shmem._device_barrier_state[None] + flags_ptr = flags.data_ptr() + + for _ in range(n): + shmem.device_barrier() + assert shmem._device_barrier_state[None].data_ptr() == flags_ptr + finally: + _call_barrier(shmem, barrier_type) + del shmem + gc.collect() + + +def _cross_rank_eager( + shmem, + barrier_type, + op, + num_barriers, + rounds, + N, + rank, + neighbor, + writer, + heap_bases, + buf, + result, +): + if op == "load": + for i in range(rounds): + buf.fill_(float(rank + i * 100)) + + for _ in range(num_barriers): + _call_barrier(shmem, barrier_type) + + _read_remote_kernel[(1,)]( + buf, + result, + rank, + neighbor, + N, + heap_bases, + ) + + for _ in range(num_barriers): + _call_barrier(shmem, barrier_type) + + expected_val = float(neighbor + i * 100) + expected = torch.full((N,), expected_val, dtype=torch.float32, device="cuda") + torch.testing.assert_close(result, expected, rtol=0, atol=0) + else: + for i in range(rounds): + buf.fill_(0.0) + + for _ in range(num_barriers): + _call_barrier(shmem, barrier_type) + + write_val = float(rank + i * 100) + _write_remote_kernel[(1,)]( + buf, + write_val, + rank, + neighbor, + N, + heap_bases, + ) + + for _ in range(num_barriers): + _call_barrier(shmem, barrier_type) + + expected_val = float(writer + i * 100) + expected = torch.full((N,), expected_val, dtype=torch.float32, device="cuda") + torch.testing.assert_close(buf, expected, rtol=0, atol=0) + + +def _cross_rank_graph( + shmem, + op, + num_barriers, + rounds, + N, + rank, + neighbor, + writer, + heap_bases, + buf, + result, +): + capture_stream = torch.cuda.Stream() + + if op == "load": + buf.fill_(float(rank)) + + # Warmup on capture stream. + with torch.cuda.stream(capture_stream): + for _ in range(num_barriers): + shmem.device_barrier() + _read_remote_kernel[(1,)]( + buf, + result, + rank, + neighbor, + N, + heap_bases, + ) + for _ in range(num_barriers): + shmem.device_barrier() + capture_stream.synchronize() + + # Capture. + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph, stream=capture_stream): + for _ in range(num_barriers): + shmem.device_barrier() + _read_remote_kernel[(1,)]( + buf, + result, + rank, + neighbor, + N, + heap_bases, + ) + for _ in range(num_barriers): + shmem.device_barrier() + + # Replay with fresh data. + for i in range(rounds): + val = float(rank + (i + 1) * 10) + with torch.cuda.stream(capture_stream): + buf.fill_(val) + shmem.device_barrier() + graph.replay() + capture_stream.synchronize() + + expected = torch.full( + (N,), + float(neighbor + (i + 1) * 10), + dtype=torch.float32, + device="cuda", + ) + torch.testing.assert_close(result, expected, rtol=0, atol=0) + else: + buf.fill_(0.0) + + # Warmup on capture stream. + with torch.cuda.stream(capture_stream): + for _ in range(num_barriers): + shmem.device_barrier() + _write_remote_kernel[(1,)]( + buf, + float(rank), + rank, + neighbor, + N, + heap_bases, + ) + for _ in range(num_barriers): + shmem.device_barrier() + capture_stream.synchronize() + + # Capture. + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph, stream=capture_stream): + for _ in range(num_barriers): + shmem.device_barrier() + _write_remote_kernel[(1,)]( + buf, + float(rank), + rank, + neighbor, + N, + heap_bases, + ) + for _ in range(num_barriers): + shmem.device_barrier() + + # Replay and verify. + for _ in range(rounds): + with torch.cuda.stream(capture_stream): + buf.fill_(0.0) + shmem.device_barrier() + graph.replay() + capture_stream.synchronize() + + with torch.cuda.stream(capture_stream): + shmem.device_barrier() + capture_stream.synchronize() + expected = torch.full((N,), float(writer), dtype=torch.float32, device="cuda") + torch.testing.assert_close(buf, expected, rtol=0, atol=0) + + +# Host barrier is not graph-capturable (uses NCCL which crashes with +# hipErrorStreamCaptureUnsupported on ROCm). Skip host+graph combos. +@pytest.mark.parametrize("N", [1, 64, 256, 1024]) +@pytest.mark.parametrize("num_barriers", [1, 2, 4]) +@pytest.mark.parametrize("mode", ["eager", "graph"]) +@pytest.mark.parametrize("op", ["load", "store", "both"]) +@pytest.mark.parametrize("barrier_type", BARRIER_TYPES) +def test_barrier_cross_rank(barrier_type, op, mode, num_barriers, N, rounds=3): + """Verify cross-rank data visibility after barrier. + + - op: load (iris.load from neighbor), store (iris.store to neighbor), or both + - mode: eager (direct calls) or graph (CUDA graph capture + replay) + - num_barriers: consecutive barriers to test idempotency + - N: number of elements (must be power of 2 for Triton BLOCK_SIZE) + - rounds: number of iterations with changing data (default 3) + + Each mode runs multiple rounds with changing data to stress correctness. + Graph mode captures barrier + kernel into a CUDA graph, then replays + with fresh data to verify correctness through the captured graph. + """ + if mode == "graph" and barrier_type == "host": + pytest.skip( + "Host barrier uses NCCL which is not graph-capturable on ROCm. See https://github.com/ROCm/HIP/issues/3876" + ) + + shmem = iris.iris(1 << 20) + _call_barrier(shmem, barrier_type) + rank = shmem.get_rank() + num_ranks = shmem.get_num_ranks() + heap_bases = shmem.get_heap_bases() + neighbor = (rank + 1) % num_ranks + writer = (rank - 1 + num_ranks) % num_ranks + + buf = shmem.zeros((N,), dtype=torch.float32) + result = shmem.zeros((N,), dtype=torch.float32) + + ops = ["load", "store"] if op == "both" else [op] + + try: + for single_op in ops: + if mode == "eager": + _cross_rank_eager( + shmem, + barrier_type, + single_op, + num_barriers, + rounds, + N, + rank, + neighbor, + writer, + heap_bases, + buf, + result, + ) + else: + _cross_rank_graph( + shmem, + single_op, + num_barriers, + rounds, + N, + rank, + neighbor, + writer, + heap_bases, + buf, + result, + ) + finally: + _call_barrier(shmem, barrier_type) + del shmem + gc.collect()