Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 165 additions & 0 deletions iris/_distributed_helpers.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.


import torch
import torch.distributed as dist
import numpy as np
import triton
import triton.language as tl


def _infer_device():
Expand Down Expand Up @@ -207,6 +210,76 @@ def distributed_broadcast_tensor(value_to_broadcast=None, root=0):
return obj[0]


def extract_group_info(group, rank, num_ranks):
"""
Extract rank and stride information for a process group.

Args:
group: ProcessGroup or None. If None, uses the provided rank/num_ranks
as the default (all-ranks) group.
rank: Global rank of the current process.
num_ranks: Total number of ranks in the default group.

Returns:
Tuple of (rank_in_group, rank_global, world_size, rank_start, rank_stride):
- rank_in_group: Rank within the group (0-indexed)
- rank_global: Global rank of this process
- world_size: Number of ranks in the group
- rank_start: Starting global rank of the group
- rank_stride: Stride between consecutive ranks in the group

Examples:
>>> # group=None: all ranks [0,1,2,3], current global rank is 2
>>> extract_group_info(None, 2, 4)
(2, 2, 4, 0, 1)

>>> # DP group: strided ranks [0,4,8,12], current global rank is 8
>>> extract_group_info(dp_group, 8, 16)
(2, 8, 4, 0, 4)
"""
if group is None:
return rank, rank, num_ranks, 0, 1

if not dist.is_initialized():
raise RuntimeError(
"torch.distributed must be initialized to use ProcessGroup. "
"Call torch.distributed.init_process_group() first."
)

group_ranks = dist.get_process_group_ranks(group)
world_size = len(group_ranks)
rank_global = rank

if rank_global not in group_ranks:
raise RuntimeError(
f"Rank {rank_global} is not part of the specified process group. Group contains ranks: {group_ranks}"
)

rank_in_group = group_ranks.index(rank_global)

if len(group_ranks) > 1:
strides = [group_ranks[i] - group_ranks[i - 1] for i in range(1, len(group_ranks))]
if not all(s == strides[0] for s in strides):
raise NotImplementedError(
f"Non-strided process groups are not yet supported. "
f"Group ranks: {group_ranks}. "
f"Please use groups with uniform stride (e.g., [0,1,2,3] or [0,4,8,12])."
)
rank_start = group_ranks[0]
rank_stride = strides[0]
if rank_stride == 0:
raise ValueError(
f"Invalid process group: rank_stride is 0, indicating duplicate ranks. "
f"Group ranks: {group_ranks}. "
f"Each rank must appear exactly once in a process group."
)
else:
rank_start = group_ranks[0]
rank_stride = 1

return rank_in_group, rank_global, world_size, rank_start, rank_stride


def distributed_barrier(group=None):
"""
Synchronization barrier using PyTorch distributed.
Expand All @@ -220,6 +293,98 @@ def distributed_barrier(group=None):
dist.barrier(group=group)


@triton.jit
def _translate_ptr(ptr, from_rank, to_rank, heap_bases):
"""Translate a pointer from one rank's address space to another's."""
from_base = tl.load(heap_bases + from_rank)
to_base = tl.load(heap_bases + to_rank)
offset = tl.cast(ptr, tl.uint64) - from_base
translated_ptr = tl.cast(tl.cast(to_base, tl.pointer_type(tl.int8)) + offset, ptr.dtype)
return translated_ptr
Comment thread
micmelesse marked this conversation as resolved.


@triton.jit
def _device_barrier_kernel(
flags_ptr,
iris_rank,
world_size: tl.constexpr,
rank_start,
rank_stride,
heap_bases,
MAX_SPINS: tl.constexpr = 1_000_000_000,
):
"""
Device-side barrier using atomic operations on the symmetric heap.
CUDA graph capturable.

Stateless w.r.t. host-side epoch tracking: there is no CPU-side epoch
counter. Each rank's flag on the heap serves as its own epoch counter,
managed entirely by the GPU via atomic_add. A persistent per-group flags
tensor is cached in ``_device_barrier_state``.

Launched with grid=(1,). A single CTA:
1. Atomically increments its own flag (atomic_add, release)
2. Serially polls each remote rank's flag for the same value (acquire)
"""
# Increment own flag and determine target
own_flag_ptr = flags_ptr + iris_rank
own_translated = _translate_ptr(own_flag_ptr, iris_rank, iris_rank, heap_bases)
old = tl.atomic_add(own_translated, 1, sem="release", scope="sys")
target = old + 1

# Poll each remote rank serially
for i in range(world_size):
remote_rank = rank_start + i * rank_stride
if remote_rank != iris_rank:
remote_flag_ptr = flags_ptr + remote_rank
remote_translated = _translate_ptr(remote_flag_ptr, iris_rank, remote_rank, heap_bases)
spin_count = 0
while (
tl.atomic_cas(
remote_translated,
target,
target,
sem="acquire",
scope="sys",
)
< target
):
spin_count += 1
tl.device_assert(spin_count < MAX_SPINS, "device_barrier: timeout")


def distributed_device_barrier(flags, group, rank, num_ranks, heap_bases):
"""
Device-side barrier using atomic operations on the symmetric heap.
CUDA graph capturable.

Unlike ``distributed_barrier`` which uses host-side ``torch.distributed.barrier()``,
this launches a single-CTA Triton kernel that synchronizes via
device-side atomics, making it safe to use during CUDA graph capture.

Stateless w.r.t. host-side epoch tracking: each rank's flag on the
symmetric heap serves as its own epoch counter, managed entirely by
the GPU via atomic_add. A persistent per-group flags tensor is cached
in ``_device_barrier_state``.

Args:
flags: int32 tensor on symmetric heap, one element per rank.
group: ProcessGroup or None. If None, uses all ranks.
rank: Global rank of this process.
num_ranks: Total number of ranks in the default group.
heap_bases: Tensor of heap base addresses for all ranks.
"""
_, rank_global, world_size, rank_start, rank_stride = extract_group_info(group, rank, num_ranks)
_device_barrier_kernel[(1,)](
flags,
rank_global,
world_size,
rank_start,
rank_stride,
heap_bases,
)


def init_distributed():
"""
Initialize PyTorch distributed and return communicator info.
Expand Down
81 changes: 5 additions & 76 deletions iris/ccl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import Tuple
import triton
import triton.language as tl
from iris._distributed_helpers import extract_group_info as _extract_group_info


@triton.jit()
Expand Down Expand Up @@ -67,83 +68,11 @@ def extract_group_info(group, shmem) -> Tuple[int, int, int, int, int]:

Returns:
Tuple of (rank_in_group, rank_global, world_size, rank_start, rank_stride)
- rank_in_group: Rank within the group (0-indexed), used for tile assignment and comparisons
- rank_global: Global rank of this process, used for iris RMA operations (heap_bases indexing)
- rank_in_group: Rank within the group (0-indexed)
- rank_global: Global rank of this process
- world_size: Number of ranks in the group
- rank_start: Starting global rank of the group
- rank_stride: Stride between consecutive ranks in the group

Examples:
>>> # group=None: all ranks [0,1,2,3], current global rank is 2
>>> extract_group_info(None, shmem)
(2, 2, 4, 0, 1) # rank_in_group=2, rank_global=2, world_size=4, start=0, stride=1

>>> # TP group: consecutive ranks [0,1,2,3], current global rank is 2
>>> extract_group_info(tp_group, shmem)
(2, 2, 4, 0, 1) # rank_in_group=2, rank_global=2, world_size=4, start=0, stride=1

>>> # DP group: strided ranks [0,4,8,12], current global rank is 8
>>> extract_group_info(dp_group, shmem)
(2, 8, 4, 0, 4) # rank_in_group=2, rank_global=8, world_size=4, start=0, stride=4
"""
if group is None:
# Use all ranks in shmem context
# When group is None, rank_in_group equals rank_global
rank_global = shmem.get_rank()
rank_in_group = rank_global
world_size = shmem.get_num_ranks()
rank_start = 0
rank_stride = 1
return rank_in_group, rank_global, world_size, rank_start, rank_stride

# Extract from ProcessGroup
import torch.distributed as dist

if not dist.is_initialized():
raise RuntimeError(
"torch.distributed must be initialized to use ProcessGroup. "
"Call torch.distributed.init_process_group() first."
)

group_ranks = dist.get_process_group_ranks(group)
world_size = len(group_ranks)
rank_global = dist.get_rank()

if rank_global not in group_ranks:
raise RuntimeError(
f"Current rank {rank_global} is not part of the specified process group. "
f"Group contains ranks: {group_ranks}"
)

rank_in_group = group_ranks.index(rank_global)

# Detect stride pattern
if len(group_ranks) > 1:
# Check if all consecutive pairs have the same stride
strides = [group_ranks[i] - group_ranks[i - 1] for i in range(1, len(group_ranks))]
is_strided = all(s == strides[0] for s in strides)

if is_strided:
rank_start = group_ranks[0]
rank_stride = strides[0]

# Validate rank_stride is not zero (would indicate duplicate ranks)
if rank_stride == 0:
raise ValueError(
f"Invalid process group: rank_stride is 0, indicating duplicate ranks. "
f"Group ranks: {group_ranks}. "
f"Each rank must appear exactly once in a process group."
)
else:
# Non-strided group - not supported yet
raise NotImplementedError(
f"Non-strided process groups are not yet supported. "
f"Group ranks: {group_ranks}. "
f"Please use groups with uniform stride (e.g., [0,1,2,3] or [0,4,8,12])."
)
else:
# Single rank group
rank_start = group_ranks[0]
rank_stride = 1

return rank_in_group, rank_global, world_size, rank_start, rank_stride

return _extract_group_info(group, shmem.get_rank(), shmem.get_num_ranks())
35 changes: 35 additions & 0 deletions iris/iris.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from iris._distributed_helpers import (
init_distributed,
distributed_barrier,
distributed_device_barrier,
distributed_broadcast_scalar,
distributed_broadcast_tensor,
)
Expand All @@ -55,6 +56,7 @@
)
from iris.symmetric_heap import SymmetricHeap
import numpy as np
from typing import Any
import torch
import logging

Expand Down Expand Up @@ -135,6 +137,9 @@ def __init__(self, heap_size=1 << 30, allocator_type="torch"):
# Lazy initialization for ops interface
self._ops = None

# Device-side barrier state, keyed by process group (None = all ranks).
self._device_barrier_state: dict[Any, torch.Tensor] = {}

# Initialize tracing
self.tracing = Tracing(self)

Expand Down Expand Up @@ -989,6 +994,36 @@ def barrier(self, stream=None, group=None):
# Distributed barrier
distributed_barrier(group=group)

def device_barrier(self, group=None):
"""
Device-side barrier that is CUDA graph capturable.

Unlike ``barrier()`` which uses host-side ``torch.distributed.barrier()``,
this uses device-side atomic operations on the symmetric heap to synchronize
ranks. Stateless w.r.t. host-side epoch tracking: each rank's flag on
the heap serves as its own epoch counter, managed entirely by the GPU
via atomic_add. A persistent per-group flags tensor is cached in
``_device_barrier_state``.

Args:
group (ProcessGroup, optional): The process group to synchronize.
If None, uses all ranks in the shmem context.

Example:
>>> ctx = iris.iris(1 << 20)
>>> ctx.device_barrier() # Synchronize all ranks on device
"""
if group not in self._device_barrier_state:
self._device_barrier_state[group] = self.zeros((self.num_ranks,), dtype=torch.int32)

distributed_device_barrier(
self._device_barrier_state[group],
group,
self.cur_rank,
self.num_ranks,
self.get_heap_bases(),
)

def get_device(self):
"""
Get the underlying device where the Iris symmetric heap resides.
Expand Down
Loading
Loading