From 2080f137599271c7390b6e6db4ad0177718642b9 Mon Sep 17 00:00:00 2001 From: Michael Date: Tue, 10 Mar 2026 14:35:05 -0400 Subject: [PATCH 1/2] add device barrier --- iris/_distributed_helpers.py | 163 ++++++++++++++ iris/ccl/utils.py | 81 +------ iris/iris.py | 33 +++ tests/unittests/test_barriers.py | 375 +++++++++++++++++++++++++++++++ 4 files changed, 576 insertions(+), 76 deletions(-) create mode 100644 tests/unittests/test_barriers.py diff --git a/iris/_distributed_helpers.py b/iris/_distributed_helpers.py index 9b222375e..7900763d9 100644 --- a/iris/_distributed_helpers.py +++ b/iris/_distributed_helpers.py @@ -1,9 +1,12 @@ # SPDX-License-Identifier: MIT # Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + import torch import torch.distributed as dist import numpy as np +import triton +import triton.language as tl def _infer_device(): @@ -207,6 +210,77 @@ def distributed_broadcast_tensor(value_to_broadcast=None, root=0): return obj[0] +def extract_group_info(group, rank, num_ranks): + """ + Extract rank and stride information for a process group. + + Args: + group: ProcessGroup or None. If None, uses the provided rank/num_ranks + as the default (all-ranks) group. + rank: Global rank of the current process. + num_ranks: Total number of ranks in the default group. + + Returns: + Tuple of (rank_in_group, rank_global, world_size, rank_start, rank_stride): + - rank_in_group: Rank within the group (0-indexed) + - rank_global: Global rank of this process + - world_size: Number of ranks in the group + - rank_start: Starting global rank of the group + - rank_stride: Stride between consecutive ranks in the group + + Examples: + >>> # group=None: all ranks [0,1,2,3], current global rank is 2 + >>> extract_group_info(None, 2, 4) + (2, 2, 4, 0, 1) + + >>> # DP group: strided ranks [0,4,8,12], current global rank is 8 + >>> extract_group_info(dp_group, 8, 16) + (2, 8, 4, 0, 4) + """ + if group is None: + return rank, rank, num_ranks, 0, 1 + + if not dist.is_initialized(): + raise RuntimeError( + "torch.distributed must be initialized to use ProcessGroup. " + "Call torch.distributed.init_process_group() first." + ) + + group_ranks = dist.get_process_group_ranks(group) + world_size = len(group_ranks) + rank_global = dist.get_rank() + + if rank_global not in group_ranks: + raise RuntimeError( + f"Current rank {rank_global} is not part of the specified process group. " + f"Group contains ranks: {group_ranks}" + ) + + rank_in_group = group_ranks.index(rank_global) + + if len(group_ranks) > 1: + strides = [group_ranks[i] - group_ranks[i - 1] for i in range(1, len(group_ranks))] + if not all(s == strides[0] for s in strides): + raise NotImplementedError( + f"Non-strided process groups are not yet supported. " + f"Group ranks: {group_ranks}. " + f"Please use groups with uniform stride (e.g., [0,1,2,3] or [0,4,8,12])." + ) + rank_start = group_ranks[0] + rank_stride = strides[0] + if rank_stride == 0: + raise ValueError( + f"Invalid process group: rank_stride is 0, indicating duplicate ranks. " + f"Group ranks: {group_ranks}. " + f"Each rank must appear exactly once in a process group." + ) + else: + rank_start = group_ranks[0] + rank_stride = 1 + + return rank_in_group, rank_global, world_size, rank_start, rank_stride + + def distributed_barrier(group=None): """ Synchronization barrier using PyTorch distributed. @@ -220,6 +294,95 @@ def distributed_barrier(group=None): dist.barrier(group=group) +@triton.jit +def _translate_ptr(ptr, from_rank, to_rank, heap_bases): + """Translate a pointer from one rank's address space to another's.""" + from_base = tl.load(heap_bases + from_rank) + to_base = tl.load(heap_bases + to_rank) + offset = tl.cast(ptr, tl.uint64) - from_base + translated_ptr = tl.cast(tl.cast(to_base, tl.pointer_type(tl.int8)) + offset, ptr.dtype) + return translated_ptr + + +@triton.jit +def _device_barrier_kernel( + flags_ptr, + iris_rank, + world_size: tl.constexpr, + rank_start, + rank_stride, + heap_bases, + MAX_SPINS: tl.constexpr = 1_000_000_000, +): + """ + Stateless device-side barrier using atomic operations on the symmetric heap. + + Launched with grid=(1,). A single CTA: + 1. Atomically increments its own flag (atomic_add, release) + 2. Serially polls each remote rank's flag for the same value (acquire) + + No CPU-side epoch tracking. Each rank's flag IS the epoch, managed + entirely on the GPU via atomic_add. This makes the barrier safe for + CUDA graph capture: during recording the kernel is just recorded, + during replay all ranks increment together. + """ + # Increment own flag and determine target + own_flag_ptr = flags_ptr + iris_rank + own_translated = _translate_ptr(own_flag_ptr, iris_rank, iris_rank, heap_bases) + old = tl.atomic_add(own_translated, 1, sem="release", scope="sys") + target = old + 1 + + # Poll each remote rank serially + for i in range(world_size): + remote_rank = rank_start + i * rank_stride + if remote_rank != iris_rank: + remote_flag_ptr = flags_ptr + remote_rank + remote_translated = _translate_ptr(remote_flag_ptr, iris_rank, remote_rank, heap_bases) + spin_count = 0 + while ( + tl.atomic_cas( + remote_translated, + target, + target, + sem="acquire", + scope="sys", + ) + < target + ): + spin_count += 1 + tl.device_assert(spin_count < MAX_SPINS, "device_barrier: timeout") + + +def distributed_device_barrier(flags, group, rank, num_ranks, heap_bases): + """ + Stateless device-side barrier using atomic operations on the symmetric heap. + + Unlike ``distributed_barrier`` which uses host-side ``torch.distributed.barrier()``, + this launches a single-CTA Triton kernel that synchronizes via + device-side atomics, making it safe to use during CUDA graph capture. + + No CPU-side epoch tracking is needed. Each rank's flag on the symmetric + heap serves as its own epoch counter, managed entirely by the GPU via + atomic_add. + + Args: + flags: int32 tensor on symmetric heap, one element per rank. + group: ProcessGroup or None. If None, uses all ranks. + rank: Global rank of this process. + num_ranks: Total number of ranks in the default group. + heap_bases: Tensor of heap base addresses for all ranks. + """ + _, rank_global, world_size, rank_start, rank_stride = extract_group_info(group, rank, num_ranks) + _device_barrier_kernel[(1,)]( + flags, + rank_global, + world_size, + rank_start, + rank_stride, + heap_bases, + ) + + def init_distributed(): """ Initialize PyTorch distributed and return communicator info. diff --git a/iris/ccl/utils.py b/iris/ccl/utils.py index 4f90ac09a..eeff2781a 100644 --- a/iris/ccl/utils.py +++ b/iris/ccl/utils.py @@ -9,6 +9,7 @@ from typing import Tuple import triton import triton.language as tl +from iris._distributed_helpers import extract_group_info as _extract_group_info @triton.jit() @@ -67,83 +68,11 @@ def extract_group_info(group, shmem) -> Tuple[int, int, int, int, int]: Returns: Tuple of (rank_in_group, rank_global, world_size, rank_start, rank_stride) - - rank_in_group: Rank within the group (0-indexed), used for tile assignment and comparisons - - rank_global: Global rank of this process, used for iris RMA operations (heap_bases indexing) + - rank_in_group: Rank within the group (0-indexed) + - rank_global: Global rank of this process - world_size: Number of ranks in the group - rank_start: Starting global rank of the group - rank_stride: Stride between consecutive ranks in the group - - Examples: - >>> # group=None: all ranks [0,1,2,3], current global rank is 2 - >>> extract_group_info(None, shmem) - (2, 2, 4, 0, 1) # rank_in_group=2, rank_global=2, world_size=4, start=0, stride=1 - - >>> # TP group: consecutive ranks [0,1,2,3], current global rank is 2 - >>> extract_group_info(tp_group, shmem) - (2, 2, 4, 0, 1) # rank_in_group=2, rank_global=2, world_size=4, start=0, stride=1 - - >>> # DP group: strided ranks [0,4,8,12], current global rank is 8 - >>> extract_group_info(dp_group, shmem) - (2, 8, 4, 0, 4) # rank_in_group=2, rank_global=8, world_size=4, start=0, stride=4 """ - if group is None: - # Use all ranks in shmem context - # When group is None, rank_in_group equals rank_global - rank_global = shmem.get_rank() - rank_in_group = rank_global - world_size = shmem.get_num_ranks() - rank_start = 0 - rank_stride = 1 - return rank_in_group, rank_global, world_size, rank_start, rank_stride - - # Extract from ProcessGroup - import torch.distributed as dist - - if not dist.is_initialized(): - raise RuntimeError( - "torch.distributed must be initialized to use ProcessGroup. " - "Call torch.distributed.init_process_group() first." - ) - - group_ranks = dist.get_process_group_ranks(group) - world_size = len(group_ranks) - rank_global = dist.get_rank() - - if rank_global not in group_ranks: - raise RuntimeError( - f"Current rank {rank_global} is not part of the specified process group. " - f"Group contains ranks: {group_ranks}" - ) - - rank_in_group = group_ranks.index(rank_global) - - # Detect stride pattern - if len(group_ranks) > 1: - # Check if all consecutive pairs have the same stride - strides = [group_ranks[i] - group_ranks[i - 1] for i in range(1, len(group_ranks))] - is_strided = all(s == strides[0] for s in strides) - - if is_strided: - rank_start = group_ranks[0] - rank_stride = strides[0] - - # Validate rank_stride is not zero (would indicate duplicate ranks) - if rank_stride == 0: - raise ValueError( - f"Invalid process group: rank_stride is 0, indicating duplicate ranks. " - f"Group ranks: {group_ranks}. " - f"Each rank must appear exactly once in a process group." - ) - else: - # Non-strided group - not supported yet - raise NotImplementedError( - f"Non-strided process groups are not yet supported. " - f"Group ranks: {group_ranks}. " - f"Please use groups with uniform stride (e.g., [0,1,2,3] or [0,4,8,12])." - ) - else: - # Single rank group - rank_start = group_ranks[0] - rank_stride = 1 - - return rank_in_group, rank_global, world_size, rank_start, rank_stride + + return _extract_group_info(group, shmem.get_rank(), shmem.get_num_ranks()) diff --git a/iris/iris.py b/iris/iris.py index f0effbb2d..1c6e325aa 100644 --- a/iris/iris.py +++ b/iris/iris.py @@ -45,6 +45,7 @@ from iris._distributed_helpers import ( init_distributed, distributed_barrier, + distributed_device_barrier, distributed_broadcast_scalar, distributed_broadcast_tensor, ) @@ -55,6 +56,7 @@ ) from iris.symmetric_heap import SymmetricHeap import numpy as np +from typing import Any import torch import logging @@ -135,6 +137,9 @@ def __init__(self, heap_size=1 << 30, allocator_type="torch"): # Lazy initialization for ops interface self._ops = None + # Device-side barrier state, keyed by process group (None = all ranks). + self._device_barrier_state: dict[Any, torch.Tensor] = {} + # Initialize tracing self.tracing = Tracing(self) @@ -989,6 +994,34 @@ def barrier(self, stream=None, group=None): # Distributed barrier distributed_barrier(group=group) + def device_barrier(self, group=None): + """ + Stateless device-side barrier that is CUDA graph capturable. + + Unlike ``barrier()`` which uses host-side ``torch.distributed.barrier()``, + this uses device-side atomic operations on the symmetric heap to synchronize + ranks. No CPU-side epoch tracking -- each rank's flag on the heap serves + as its own epoch counter, managed entirely by the GPU via atomic_add. + + Args: + group (ProcessGroup, optional): The process group to synchronize. + If None, uses all ranks in the shmem context. + + Example: + >>> ctx = iris.iris(1 << 20) + >>> ctx.device_barrier() # Synchronize all ranks on device + """ + if group not in self._device_barrier_state: + self._device_barrier_state[group] = self.zeros((self.num_ranks,), dtype=torch.int32) + + distributed_device_barrier( + self._device_barrier_state[group], + group, + self.cur_rank, + self.num_ranks, + self.get_heap_bases(), + ) + def get_device(self): """ Get the underlying device where the Iris symmetric heap resides. diff --git a/tests/unittests/test_barriers.py b/tests/unittests/test_barriers.py new file mode 100644 index 000000000..1c9e5efaf --- /dev/null +++ b/tests/unittests/test_barriers.py @@ -0,0 +1,375 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +import gc +from typing import Literal + +import pytest +import torch +import triton +import triton.language as tl +import iris +from iris._distributed_helpers import _device_barrier_kernel, extract_group_info + + +BarrierType = Literal["host", "device"] +BARRIER_TYPES: list[BarrierType] = ["host", "device"] + + +def _call_barrier(shmem: iris.Iris, barrier_type: BarrierType) -> None: + if barrier_type == "host": + shmem.barrier() + else: + shmem.device_barrier() + + +@triton.jit +def _read_remote_kernel( + buf_ptr, + result_ptr, + cur_rank: tl.constexpr, + remote_rank: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + heap_bases: tl.tensor, +): + offsets = tl.arange(0, BLOCK_SIZE) + data = iris.load(buf_ptr + offsets, cur_rank, remote_rank, heap_bases) + tl.store(result_ptr + offsets, data) + + +@triton.jit +def _write_remote_kernel( + buf_ptr, + value, + cur_rank: tl.constexpr, + remote_rank: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + heap_bases: tl.tensor, +): + offsets = tl.arange(0, BLOCK_SIZE) + data = tl.full([BLOCK_SIZE], value, dtype=tl.float32) + iris.store(buf_ptr + offsets, data, cur_rank, remote_rank, heap_bases) + + +@pytest.mark.parametrize("n", [1, 10]) +@pytest.mark.parametrize("barrier_type", BARRIER_TYPES) +def test_barrier_basic(barrier_type, n): + shmem = iris.iris(1 << 20) + shmem.barrier() + + try: + for _ in range(n): + _call_barrier(shmem, barrier_type) + finally: + shmem.barrier() + del shmem + gc.collect() + + +@pytest.mark.parametrize("n", [1, 2, 5, 10]) +def test_barrier_state_reuse(n): + """Verify device barrier reuses the same flags tensor across calls.""" + shmem = iris.iris(1 << 20) + shmem.barrier() + + try: + shmem.device_barrier() + assert None in shmem._device_barrier_state + flags = shmem._device_barrier_state[None] + flags_ptr = flags.data_ptr() + + for _ in range(n): + shmem.device_barrier() + assert shmem._device_barrier_state[None].data_ptr() == flags_ptr + finally: + shmem.barrier() + del shmem + gc.collect() + + +def _cross_rank_eager( + shmem, + barrier_type, + op, + num_barriers, + rounds, + N, + rank, + neighbor, + writer, + heap_bases, + buf, + result, +): + if op == "load": + for i in range(rounds): + buf.fill_(float(rank + i * 100)) + + for _ in range(num_barriers): + _call_barrier(shmem, barrier_type) + + _read_remote_kernel[(1,)]( + buf, + result, + rank, + neighbor, + N, + heap_bases, + ) + + for _ in range(num_barriers): + _call_barrier(shmem, barrier_type) + + expected_val = float(neighbor + i * 100) + expected = torch.full((N,), expected_val, dtype=torch.float32, device="cuda") + torch.testing.assert_close(result, expected, rtol=0, atol=0) + else: + for i in range(rounds): + buf.fill_(0.0) + + for _ in range(num_barriers): + _call_barrier(shmem, barrier_type) + + write_val = float(rank + i * 100) + _write_remote_kernel[(1,)]( + buf, + write_val, + rank, + neighbor, + N, + heap_bases, + ) + + for _ in range(num_barriers): + _call_barrier(shmem, barrier_type) + + expected_val = float(writer + i * 100) + expected = torch.full((N,), expected_val, dtype=torch.float32, device="cuda") + torch.testing.assert_close(buf, expected, rtol=0, atol=0) + + +def _cross_rank_graph( + shmem, + op, + num_barriers, + rounds, + N, + rank, + neighbor, + writer, + heap_bases, + buf, + result, +): + stream = torch.cuda.Stream() + + if op == "load": + buf.fill_(float(rank)) + + # Warmup on capture stream. + with torch.cuda.stream(stream): + for _ in range(num_barriers): + shmem.device_barrier() + _read_remote_kernel[(1,)]( + buf, + result, + rank, + neighbor, + N, + heap_bases, + ) + for _ in range(num_barriers): + shmem.device_barrier() + stream.synchronize() + + # Capture. + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph, stream=stream): + for _ in range(num_barriers): + shmem.device_barrier() + _read_remote_kernel[(1,)]( + buf, + result, + rank, + neighbor, + N, + heap_bases, + ) + for _ in range(num_barriers): + shmem.device_barrier() + + # Replay with fresh data. + for i in range(rounds): + val = float(rank + (i + 1) * 10) + buf.fill_(val) + shmem.device_barrier() + + graph.replay() + stream.synchronize() + + expected = torch.full( + (N,), + float(neighbor + (i + 1) * 10), + dtype=torch.float32, + device="cuda", + ) + torch.testing.assert_close(result, expected, rtol=0, atol=0) + else: + buf.fill_(0.0) + + # Warmup on capture stream. + with torch.cuda.stream(stream): + for _ in range(num_barriers): + shmem.device_barrier() + _write_remote_kernel[(1,)]( + buf, + float(rank), + rank, + neighbor, + N, + heap_bases, + ) + for _ in range(num_barriers): + shmem.device_barrier() + stream.synchronize() + + # Capture. + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph, stream=stream): + for _ in range(num_barriers): + shmem.device_barrier() + _write_remote_kernel[(1,)]( + buf, + float(rank), + rank, + neighbor, + N, + heap_bases, + ) + for _ in range(num_barriers): + shmem.device_barrier() + + # Replay and verify. + for _ in range(rounds): + buf.fill_(0.0) + shmem.device_barrier() + + graph.replay() + stream.synchronize() + + shmem.device_barrier() + expected = torch.full((N,), float(writer), dtype=torch.float32, device="cuda") + torch.testing.assert_close(buf, expected, rtol=0, atol=0) + + +# Host barrier is not graph-capturable (uses NCCL which crashes with +# hipErrorStreamCaptureUnsupported on ROCm). Skip host+graph combos. +@pytest.mark.parametrize("N", [1, 64, 256, 1024]) +@pytest.mark.parametrize("num_barriers", [1, 2, 4]) +@pytest.mark.parametrize("mode", ["eager", "graph"]) +@pytest.mark.parametrize("op", ["load", "store", "both"]) +@pytest.mark.parametrize("barrier_type", BARRIER_TYPES) +def test_barrier_cross_rank(barrier_type, op, mode, num_barriers, N, rounds=3): + """Verify cross-rank data visibility after barrier. + + - op: load (iris.load from neighbor), store (iris.store to neighbor), or both + - mode: eager (direct calls) or graph (CUDA graph capture + replay) + - num_barriers: consecutive barriers to test idempotency + - N: number of elements (must be power of 2 for Triton BLOCK_SIZE) + - rounds: number of iterations with changing data (default 3) + + Each mode runs multiple rounds with changing data to stress correctness. + Graph mode captures barrier + kernel into a CUDA graph, then replays + with fresh data to verify correctness through the captured graph. + """ + if mode == "graph" and barrier_type == "host": + pytest.skip( + "Host barrier uses NCCL which is not graph-capturable on ROCm. See https://github.com/ROCm/HIP/issues/3876" + ) + + shmem = iris.iris(1 << 20) + shmem.barrier() + rank = shmem.get_rank() + num_ranks = shmem.get_num_ranks() + heap_bases = shmem.get_heap_bases() + neighbor = (rank + 1) % num_ranks + writer = (rank - 1 + num_ranks) % num_ranks + + buf = shmem.zeros((N,), dtype=torch.float32) + result = shmem.zeros((N,), dtype=torch.float32) + + ops = ["load", "store"] if op == "both" else [op] + + try: + for single_op in ops: + if mode == "eager": + _cross_rank_eager( + shmem, + barrier_type, + single_op, + num_barriers, + rounds, + N, + rank, + neighbor, + writer, + heap_bases, + buf, + result, + ) + else: + _cross_rank_graph( + shmem, + single_op, + num_barriers, + rounds, + N, + rank, + neighbor, + writer, + heap_bases, + buf, + result, + ) + finally: + shmem.barrier() + del shmem + gc.collect() + + +def test_barrier_timeout_assert(): + """Verify device_barrier asserts on timeout instead of hanging forever. + + Only rank 0 calls the barrier kernel. Other ranks skip it, so rank 0 + spins waiting for them and hits the MAX_SPINS assert. + """ + shmem = iris.iris(1 << 20) + rank = shmem.get_rank() + num_ranks = shmem.get_num_ranks() + heap_bases = shmem.get_heap_bases() + + if num_ranks < 2: + pytest.skip("Need at least 2 ranks") + + shmem.barrier() + + flags = shmem._device_barrier_state.setdefault(None, shmem.zeros((num_ranks,), dtype=torch.int32)) + + try: + if rank == 0: + _, rank_global, world_size, rank_start, rank_stride = extract_group_info(None, rank, num_ranks) + _device_barrier_kernel[(1,)]( + flags, + rank_global, + world_size, + rank_start, + rank_stride, + heap_bases, + MAX_SPINS=1000, + ) + with pytest.raises(RuntimeError, match="device-side assert"): + torch.cuda.synchronize() + finally: + shmem.barrier() + del shmem + gc.collect() From 4d1cc6f44e8129f75c6aeea153820b1843a87569 Mon Sep 17 00:00:00 2001 From: Michael Date: Wed, 11 Mar 2026 10:27:46 -0400 Subject: [PATCH 2/2] address copilot feedback --- iris/_distributed_helpers.py | 28 +++++----- iris/iris.py | 8 +-- tests/unittests/test_barriers.py | 92 ++++++++++---------------------- 3 files changed, 48 insertions(+), 80 deletions(-) diff --git a/iris/_distributed_helpers.py b/iris/_distributed_helpers.py index 7900763d9..ac237cb5e 100644 --- a/iris/_distributed_helpers.py +++ b/iris/_distributed_helpers.py @@ -248,12 +248,11 @@ def extract_group_info(group, rank, num_ranks): group_ranks = dist.get_process_group_ranks(group) world_size = len(group_ranks) - rank_global = dist.get_rank() + rank_global = rank if rank_global not in group_ranks: raise RuntimeError( - f"Current rank {rank_global} is not part of the specified process group. " - f"Group contains ranks: {group_ranks}" + f"Rank {rank_global} is not part of the specified process group. Group contains ranks: {group_ranks}" ) rank_in_group = group_ranks.index(rank_global) @@ -315,16 +314,17 @@ def _device_barrier_kernel( MAX_SPINS: tl.constexpr = 1_000_000_000, ): """ - Stateless device-side barrier using atomic operations on the symmetric heap. + Device-side barrier using atomic operations on the symmetric heap. + CUDA graph capturable. + + Stateless w.r.t. host-side epoch tracking: there is no CPU-side epoch + counter. Each rank's flag on the heap serves as its own epoch counter, + managed entirely by the GPU via atomic_add. A persistent per-group flags + tensor is cached in ``_device_barrier_state``. Launched with grid=(1,). A single CTA: 1. Atomically increments its own flag (atomic_add, release) 2. Serially polls each remote rank's flag for the same value (acquire) - - No CPU-side epoch tracking. Each rank's flag IS the epoch, managed - entirely on the GPU via atomic_add. This makes the barrier safe for - CUDA graph capture: during recording the kernel is just recorded, - during replay all ranks increment together. """ # Increment own flag and determine target own_flag_ptr = flags_ptr + iris_rank @@ -355,15 +355,17 @@ def _device_barrier_kernel( def distributed_device_barrier(flags, group, rank, num_ranks, heap_bases): """ - Stateless device-side barrier using atomic operations on the symmetric heap. + Device-side barrier using atomic operations on the symmetric heap. + CUDA graph capturable. Unlike ``distributed_barrier`` which uses host-side ``torch.distributed.barrier()``, this launches a single-CTA Triton kernel that synchronizes via device-side atomics, making it safe to use during CUDA graph capture. - No CPU-side epoch tracking is needed. Each rank's flag on the symmetric - heap serves as its own epoch counter, managed entirely by the GPU via - atomic_add. + Stateless w.r.t. host-side epoch tracking: each rank's flag on the + symmetric heap serves as its own epoch counter, managed entirely by + the GPU via atomic_add. A persistent per-group flags tensor is cached + in ``_device_barrier_state``. Args: flags: int32 tensor on symmetric heap, one element per rank. diff --git a/iris/iris.py b/iris/iris.py index 1c6e325aa..28ee3681b 100644 --- a/iris/iris.py +++ b/iris/iris.py @@ -996,12 +996,14 @@ def barrier(self, stream=None, group=None): def device_barrier(self, group=None): """ - Stateless device-side barrier that is CUDA graph capturable. + Device-side barrier that is CUDA graph capturable. Unlike ``barrier()`` which uses host-side ``torch.distributed.barrier()``, this uses device-side atomic operations on the symmetric heap to synchronize - ranks. No CPU-side epoch tracking -- each rank's flag on the heap serves - as its own epoch counter, managed entirely by the GPU via atomic_add. + ranks. Stateless w.r.t. host-side epoch tracking: each rank's flag on + the heap serves as its own epoch counter, managed entirely by the GPU + via atomic_add. A persistent per-group flags tensor is cached in + ``_device_barrier_state``. Args: group (ProcessGroup, optional): The process group to synchronize. diff --git a/tests/unittests/test_barriers.py b/tests/unittests/test_barriers.py index 1c9e5efaf..79f6e8351 100644 --- a/tests/unittests/test_barriers.py +++ b/tests/unittests/test_barriers.py @@ -9,7 +9,6 @@ import triton import triton.language as tl import iris -from iris._distributed_helpers import _device_barrier_kernel, extract_group_info BarrierType = Literal["host", "device"] @@ -55,22 +54,23 @@ def _write_remote_kernel( @pytest.mark.parametrize("barrier_type", BARRIER_TYPES) def test_barrier_basic(barrier_type, n): shmem = iris.iris(1 << 20) - shmem.barrier() + _call_barrier(shmem, barrier_type) try: for _ in range(n): _call_barrier(shmem, barrier_type) finally: - shmem.barrier() + _call_barrier(shmem, barrier_type) del shmem gc.collect() @pytest.mark.parametrize("n", [1, 2, 5, 10]) -def test_barrier_state_reuse(n): +@pytest.mark.parametrize("barrier_type", BARRIER_TYPES) +def test_barrier_state_reuse(barrier_type, n): """Verify device barrier reuses the same flags tensor across calls.""" shmem = iris.iris(1 << 20) - shmem.barrier() + _call_barrier(shmem, barrier_type) try: shmem.device_barrier() @@ -82,7 +82,7 @@ def test_barrier_state_reuse(n): shmem.device_barrier() assert shmem._device_barrier_state[None].data_ptr() == flags_ptr finally: - shmem.barrier() + _call_barrier(shmem, barrier_type) del shmem gc.collect() @@ -161,13 +161,13 @@ def _cross_rank_graph( buf, result, ): - stream = torch.cuda.Stream() + capture_stream = torch.cuda.Stream() if op == "load": buf.fill_(float(rank)) # Warmup on capture stream. - with torch.cuda.stream(stream): + with torch.cuda.stream(capture_stream): for _ in range(num_barriers): shmem.device_barrier() _read_remote_kernel[(1,)]( @@ -180,11 +180,11 @@ def _cross_rank_graph( ) for _ in range(num_barriers): shmem.device_barrier() - stream.synchronize() + capture_stream.synchronize() # Capture. graph = torch.cuda.CUDAGraph() - with torch.cuda.graph(graph, stream=stream): + with torch.cuda.graph(graph, stream=capture_stream): for _ in range(num_barriers): shmem.device_barrier() _read_remote_kernel[(1,)]( @@ -201,11 +201,11 @@ def _cross_rank_graph( # Replay with fresh data. for i in range(rounds): val = float(rank + (i + 1) * 10) - buf.fill_(val) - shmem.device_barrier() - - graph.replay() - stream.synchronize() + with torch.cuda.stream(capture_stream): + buf.fill_(val) + shmem.device_barrier() + graph.replay() + capture_stream.synchronize() expected = torch.full( (N,), @@ -218,7 +218,7 @@ def _cross_rank_graph( buf.fill_(0.0) # Warmup on capture stream. - with torch.cuda.stream(stream): + with torch.cuda.stream(capture_stream): for _ in range(num_barriers): shmem.device_barrier() _write_remote_kernel[(1,)]( @@ -231,11 +231,11 @@ def _cross_rank_graph( ) for _ in range(num_barriers): shmem.device_barrier() - stream.synchronize() + capture_stream.synchronize() # Capture. graph = torch.cuda.CUDAGraph() - with torch.cuda.graph(graph, stream=stream): + with torch.cuda.graph(graph, stream=capture_stream): for _ in range(num_barriers): shmem.device_barrier() _write_remote_kernel[(1,)]( @@ -251,13 +251,15 @@ def _cross_rank_graph( # Replay and verify. for _ in range(rounds): - buf.fill_(0.0) - shmem.device_barrier() - - graph.replay() - stream.synchronize() + with torch.cuda.stream(capture_stream): + buf.fill_(0.0) + shmem.device_barrier() + graph.replay() + capture_stream.synchronize() - shmem.device_barrier() + with torch.cuda.stream(capture_stream): + shmem.device_barrier() + capture_stream.synchronize() expected = torch.full((N,), float(writer), dtype=torch.float32, device="cuda") torch.testing.assert_close(buf, expected, rtol=0, atol=0) @@ -288,7 +290,7 @@ def test_barrier_cross_rank(barrier_type, op, mode, num_barriers, N, rounds=3): ) shmem = iris.iris(1 << 20) - shmem.barrier() + _call_barrier(shmem, barrier_type) rank = shmem.get_rank() num_ranks = shmem.get_num_ranks() heap_bases = shmem.get_heap_bases() @@ -332,44 +334,6 @@ def test_barrier_cross_rank(barrier_type, op, mode, num_barriers, N, rounds=3): result, ) finally: - shmem.barrier() - del shmem - gc.collect() - - -def test_barrier_timeout_assert(): - """Verify device_barrier asserts on timeout instead of hanging forever. - - Only rank 0 calls the barrier kernel. Other ranks skip it, so rank 0 - spins waiting for them and hits the MAX_SPINS assert. - """ - shmem = iris.iris(1 << 20) - rank = shmem.get_rank() - num_ranks = shmem.get_num_ranks() - heap_bases = shmem.get_heap_bases() - - if num_ranks < 2: - pytest.skip("Need at least 2 ranks") - - shmem.barrier() - - flags = shmem._device_barrier_state.setdefault(None, shmem.zeros((num_ranks,), dtype=torch.int32)) - - try: - if rank == 0: - _, rank_global, world_size, rank_start, rank_stride = extract_group_info(None, rank, num_ranks) - _device_barrier_kernel[(1,)]( - flags, - rank_global, - world_size, - rank_start, - rank_stride, - heap_bases, - MAX_SPINS=1000, - ) - with pytest.raises(RuntimeError, match="device-side assert"): - torch.cuda.synchronize() - finally: - shmem.barrier() + _call_barrier(shmem, barrier_type) del shmem gc.collect()