Skip to content
2 changes: 2 additions & 0 deletions iris/ccl/all_gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ def persistent_all_gather(
target_rank,
heap_bases,
mask=combined_mask,
hint=(1, BLOCK_SIZE_N),
)


Expand Down Expand Up @@ -274,6 +275,7 @@ def persistent_all_gather_partitioned(
target_rank,
heap_bases,
mask=combined_mask,
hint=(1, BLOCK_SIZE_N),
)


Expand Down
14 changes: 12 additions & 2 deletions iris/ccl/all_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,7 @@ def persistent_all_reduce_spinlock(
dest_rank,
heap_bases,
mask=mask,
hint=(1, BLOCK_SIZE_N),
)

# Release lock for this tile at dest_rank
Expand Down Expand Up @@ -539,6 +540,7 @@ def persistent_all_reduce_ring(
next_rank,
heap_bases,
mask=mask,
hint=(1, BLOCK_SIZE_N),
)
tl.debug_barrier()
iris.atomic_xchg(
Expand Down Expand Up @@ -668,7 +670,7 @@ def persistent_all_reduce_two_shot(
remote_rank_idx = (start_rank_idx + i) % world_size
remote_rank = rank_start + remote_rank_idx * rank_stride
if remote_rank_idx != group_rank:
iris.store(out_ptr, reduced, iris_rank, remote_rank, heap_bases)
iris.store(out_ptr, reduced, iris_rank, remote_rank, heap_bases, hint=(1, BLOCK_SIZE_N))

# Slow path: MASKED (only boundary tiles land here)
# This path handles tiles at tensor boundaries where not all elements are valid.
Expand All @@ -691,7 +693,15 @@ def persistent_all_reduce_two_shot(
remote_rank_idx = (start_rank_idx + i) % world_size
remote_rank = rank_start + remote_rank_idx * rank_stride
if remote_rank_idx != group_rank:
iris.store(out_ptr, reduced, iris_rank, remote_rank, heap_bases, mask=mask)
iris.store(
out_ptr,
reduced,
iris_rank,
remote_rank,
heap_bases,
mask=mask,
hint=(1, BLOCK_SIZE_N),
)


def all_reduce(
Expand Down
2 changes: 2 additions & 0 deletions iris/ccl/all_to_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ def persistent_all_to_all(
iris_rank,
target_rank,
heap_bases,
hint=(1, BLOCK_SIZE_N),
)

# Slow path: MASKED (only boundary tiles land here)
Expand Down Expand Up @@ -183,6 +184,7 @@ def persistent_all_to_all(
target_rank,
heap_bases,
mask=mask,
hint=(1, BLOCK_SIZE_N),
)


Expand Down
8 changes: 7 additions & 1 deletion iris/x/all_gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,18 @@ def all_gather(
# Scatter along N dimension: write to [:, ctx.rank * N_local : (ctx.rank+1) * N_local]
dst_ptr, combined_mask = dst_view.offset_tile_ptr(tile, offset_n=ctx.rank * N_local, src_mask=None)

# Use iris.store to write to dest_rank's memory
# Use iris.store to write to dest_rank's memory.
# hint=(1, tile.block_n) asserts per-row contiguity only (BLOCK_N consecutive
# elements within each row). Using (tile.block_m, tile.block_n) would also
# assert cross-row contiguity, which is false when BLOCK_N < N (stride_m > BLOCK_N);
# that false hint makes getOrderFromContiguity choose dim-0 for vectorization,
# and the resulting scalar buffer_store_short writes go to wrong addresses.
iris.store(
dst_ptr,
tile.data,
ctx.rank, # from_rank (current rank)
dest_rank, # to_rank (destination rank)
ctx.heap_bases,
mask=combined_mask,
hint=(1, tile.block_n),
)
1 change: 1 addition & 0 deletions iris/x/all_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,7 @@ def all_reduce_two_shot(
dest_rank, # to_rank (destination rank)
ctx.heap_bases,
mask=mask,
hint=(1, tile.block_n),
)


Expand Down
30 changes: 18 additions & 12 deletions tests/ccl/test_all_gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,16 @@
],
)
@pytest.mark.parametrize(
"M, N",
"M, N, block_size_m, block_size_n",
[
(128, 64), # Small
(1024, 256), # Medium
(8192, 8192), # Large
(128, 64, 32, 64), # Small
(128, 128, 32, 32), # BLOCK_N < N/world_size (partial-width, multi-block per rank)
(256, 128, 32, 16), # Minimum BLOCK_N=16 (16-bit vectorization path)
(1024, 256, 32, 64), # Medium
(8192, 8192, 32, 64), # Large
],
)
def test_all_gather(dtype, M, N):
def test_all_gather(dtype, M, N, block_size_m, block_size_n):
"""Test all-gather functionality by comparing against PyTorch's implementation."""
# Ensure torch.distributed is initialized (should be done by test runner)
if not dist.is_initialized():
Expand Down Expand Up @@ -62,7 +64,7 @@ def test_all_gather(dtype, M, N):

# Run Iris all_gather
shmem.barrier()
config = Config()
config = Config(block_size_m=block_size_m, block_size_n=block_size_n)
shmem.ccl.all_gather(iris_output_tensor, iris_input_tensor, config=config)
torch.cuda.synchronize()

Expand Down Expand Up @@ -97,14 +99,16 @@ def test_all_gather(dtype, M, N):
],
)
@pytest.mark.parametrize(
"M, N",
"M, N, block_size_m, block_size_n",
[
(128, 64), # Small
(1024, 256), # Medium
(8192, 8192), # Large
(128, 64, 32, 64), # Small
(128, 128, 32, 32), # BLOCK_N < N/world_size (partial-width, multi-block per rank)
(256, 128, 32, 16), # Minimum BLOCK_N=16 (16-bit vectorization path)
(1024, 256, 32, 64), # Medium
(8192, 8192, 32, 64), # Large
],
)
def test_all_gather_partitioned(dtype, M, N):
def test_all_gather_partitioned(dtype, M, N, block_size_m, block_size_n):
"""Test all-gather with partitioned variant by comparing against PyTorch's implementation."""
# Ensure torch.distributed is initialized (should be done by test runner)
if not dist.is_initialized():
Expand Down Expand Up @@ -140,7 +144,9 @@ def test_all_gather_partitioned(dtype, M, N):
# COMM_SMS must be divisible by world_size for partitioned variant
comm_sms = 64 # Assuming world_size divides 64 (e.g., 2, 4, 8)
shmem.barrier()
config = Config(all_gather_variant="partitioned", comm_sms=comm_sms)
config = Config(
block_size_m=block_size_m, block_size_n=block_size_n, all_gather_variant="partitioned", comm_sms=comm_sms
)
shmem.ccl.all_gather(iris_output_tensor, iris_input_tensor, config=config)
torch.cuda.synchronize()

Expand Down
14 changes: 8 additions & 6 deletions tests/ccl/test_all_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,16 @@
],
)
@pytest.mark.parametrize(
"M, N",
"M, N, block_size_m, block_size_n",
[
(128, 64), # Small
(1024, 256), # Medium
(8192, 8192), # Large
(128, 64, 32, 64), # Small
(128, 128, 32, 32), # BLOCK_N < N/world_size (partial-width, multi-block per rank)
(256, 128, 32, 16), # Minimum BLOCK_N=16 (16-bit vectorization path)
(1024, 256, 32, 64), # Medium
(8192, 8192, 32, 64), # Large
],
)
def test_all_reduce(variant, dtype, M, N):
def test_all_reduce(variant, dtype, M, N, block_size_m, block_size_n):
"""Test all-reduce functionality by comparing against PyTorch's implementation."""
# Ensure torch.distributed is initialized (should be done by test runner)
if not dist.is_initialized():
Expand Down Expand Up @@ -70,7 +72,7 @@ def test_all_reduce(variant, dtype, M, N):

# Run Iris all_reduce with specified variant
shmem.barrier()
config = Config(all_reduce_variant=variant)
config = Config(all_reduce_variant=variant, block_size_m=block_size_m, block_size_n=block_size_n)
if variant == "two_shot":
# Test both distribution modes for two_shot
config.all_reduce_distribution = 0 # striding
Expand Down
14 changes: 8 additions & 6 deletions tests/ccl/test_all_to_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,16 @@
],
)
@pytest.mark.parametrize(
"M, N",
"M, N, block_size_m, block_size_n",
[
(128, 64), # Small
(1024, 256), # Medium
(8192, 8192), # Large
(128, 64, 32, 64), # Small
(128, 128, 32, 32), # BLOCK_N < N/world_size (partial-width, multi-block per rank)
(256, 128, 32, 16), # Minimum BLOCK_N=16 (16-bit vectorization path)
(1024, 256, 32, 64), # Medium
(8192, 8192, 32, 64), # Large
],
)
def test_all_to_all(dtype, M, N):
def test_all_to_all(dtype, M, N, block_size_m, block_size_n):
"""Test all-to-all functionality by comparing against PyTorch's implementation."""
# Ensure torch.distributed is initialized (should be done by test runner)
if not dist.is_initialized():
Expand Down Expand Up @@ -74,7 +76,7 @@ def test_all_to_all(dtype, M, N):

# Run Iris all_to_all
shmem.barrier()
config = Config()
config = Config(block_size_m=block_size_m, block_size_n=block_size_n)
shmem.ccl.all_to_all(iris_output_concat, iris_input_concat, config=config)
torch.cuda.synchronize()

Expand Down
8 changes: 6 additions & 2 deletions tests/x/test_all_gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,9 @@ def x_all_gather_kernel(
@pytest.mark.parametrize(
"M, N, BLOCK_SIZE_M, BLOCK_SIZE_N",
[
(128, 64, 64, 32), # Small
(128, 64, 64, 32), # Small – BLOCK_N < N (partial-width, was the original failing case)
(128, 128, 64, 32), # Multiple N blocks per rank – BLOCK_N < N/world_size (2 tiles in N per rank)
(256, 128, 64, 16), # Very small BLOCK_N to stress 16-bit vectorization with partial-width tiles
(1024, 256, 128, 128), # Medium
(2048, 2048, 256, 256), # Large
# TODO: Fix non-aligned dimension handling in all_gather for irregular tiling
Expand Down Expand Up @@ -258,7 +260,9 @@ def x_all_gather_ctx_api_kernel(
(torch.float32, 1e-5, 1e-5),
],
)
@pytest.mark.parametrize("M, N, BLOCK_SIZE_M, BLOCK_SIZE_N", [(256, 128, 64, 64)])
@pytest.mark.parametrize(
"M, N, BLOCK_SIZE_M, BLOCK_SIZE_N", [(256, 128, 64, 64), (128, 64, 64, 32), (128, 128, 64, 32)]
)
def test_all_gather_ctx_api(gather_dim, dtype, atol, rtol, M, N, BLOCK_SIZE_M, BLOCK_SIZE_N):
"""Test tile-level all-gather using direct function call (ctx methods removed)."""
if not dist.is_initialized():
Expand Down
2 changes: 2 additions & 0 deletions tests/x/test_all_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,8 @@ def x_all_reduce_spinlock_kernel(
"M, N, BLOCK_SIZE_M, BLOCK_SIZE_N",
[
(128, 64, 64, 32), # Small
(128, 128, 64, 32), # BLOCK_N < N/world_size (partial-width, multi-block per rank)
(256, 128, 64, 16), # Minimum BLOCK_N=16 (16-bit vectorization path)
(1024, 256, 128, 128), # Medium
(2048, 2048, 256, 256), # Large
# (100, 100, 64, 64), # Non-aligned dimensions - DISABLED: other=0.0 not supported
Expand Down
Loading