diff --git a/iris/ccl/all_gather.py b/iris/ccl/all_gather.py index 2a61fc3d..190c9607 100644 --- a/iris/ccl/all_gather.py +++ b/iris/ccl/all_gather.py @@ -137,6 +137,7 @@ def persistent_all_gather( target_rank, heap_bases, mask=combined_mask, + hint=(1, BLOCK_SIZE_N), ) @@ -274,6 +275,7 @@ def persistent_all_gather_partitioned( target_rank, heap_bases, mask=combined_mask, + hint=(1, BLOCK_SIZE_N), ) diff --git a/iris/ccl/all_reduce.py b/iris/ccl/all_reduce.py index a0d44521..8503907a 100644 --- a/iris/ccl/all_reduce.py +++ b/iris/ccl/all_reduce.py @@ -334,6 +334,7 @@ def persistent_all_reduce_spinlock( dest_rank, heap_bases, mask=mask, + hint=(1, BLOCK_SIZE_N), ) # Release lock for this tile at dest_rank @@ -539,6 +540,7 @@ def persistent_all_reduce_ring( next_rank, heap_bases, mask=mask, + hint=(1, BLOCK_SIZE_N), ) tl.debug_barrier() iris.atomic_xchg( @@ -668,7 +670,7 @@ def persistent_all_reduce_two_shot( remote_rank_idx = (start_rank_idx + i) % world_size remote_rank = rank_start + remote_rank_idx * rank_stride if remote_rank_idx != group_rank: - iris.store(out_ptr, reduced, iris_rank, remote_rank, heap_bases) + iris.store(out_ptr, reduced, iris_rank, remote_rank, heap_bases, hint=(1, BLOCK_SIZE_N)) # Slow path: MASKED (only boundary tiles land here) # This path handles tiles at tensor boundaries where not all elements are valid. @@ -691,7 +693,15 @@ def persistent_all_reduce_two_shot( remote_rank_idx = (start_rank_idx + i) % world_size remote_rank = rank_start + remote_rank_idx * rank_stride if remote_rank_idx != group_rank: - iris.store(out_ptr, reduced, iris_rank, remote_rank, heap_bases, mask=mask) + iris.store( + out_ptr, + reduced, + iris_rank, + remote_rank, + heap_bases, + mask=mask, + hint=(1, BLOCK_SIZE_N), + ) def all_reduce( diff --git a/iris/ccl/all_to_all.py b/iris/ccl/all_to_all.py index 9010ef06..9ff16a1b 100644 --- a/iris/ccl/all_to_all.py +++ b/iris/ccl/all_to_all.py @@ -144,6 +144,7 @@ def persistent_all_to_all( iris_rank, target_rank, heap_bases, + hint=(1, BLOCK_SIZE_N), ) # Slow path: MASKED (only boundary tiles land here) @@ -183,6 +184,7 @@ def persistent_all_to_all( target_rank, heap_bases, mask=mask, + hint=(1, BLOCK_SIZE_N), ) diff --git a/iris/x/all_gather.py b/iris/x/all_gather.py index a357ab7f..a8c84bde 100644 --- a/iris/x/all_gather.py +++ b/iris/x/all_gather.py @@ -64,7 +64,12 @@ def all_gather( # Scatter along N dimension: write to [:, ctx.rank * N_local : (ctx.rank+1) * N_local] dst_ptr, combined_mask = dst_view.offset_tile_ptr(tile, offset_n=ctx.rank * N_local, src_mask=None) - # Use iris.store to write to dest_rank's memory + # Use iris.store to write to dest_rank's memory. + # hint=(1, tile.block_n) asserts per-row contiguity only (BLOCK_N consecutive + # elements within each row). Using (tile.block_m, tile.block_n) would + # assert cross-row contiguity which is false when BLOCK_N < N (stride_m > BLOCK_N), + # causing getOrderFromContiguity to choose dim-0 for vectorization and emitting + # scalar buffer_store_short writes to wrong addresses. iris.store( dst_ptr, tile.data, @@ -72,4 +77,5 @@ def all_gather( dest_rank, # to_rank (destination rank) ctx.heap_bases, mask=combined_mask, + hint=(1, tile.block_n), ) diff --git a/iris/x/all_reduce.py b/iris/x/all_reduce.py index 8b7c8df8..901f5adb 100644 --- a/iris/x/all_reduce.py +++ b/iris/x/all_reduce.py @@ -313,6 +313,7 @@ def all_reduce_two_shot( dest_rank, # to_rank (destination rank) ctx.heap_bases, mask=mask, + hint=(1, tile.block_n), ) diff --git a/tests/ccl/test_all_gather.py b/tests/ccl/test_all_gather.py index f2c50e2f..7858ed18 100644 --- a/tests/ccl/test_all_gather.py +++ b/tests/ccl/test_all_gather.py @@ -21,14 +21,16 @@ ], ) @pytest.mark.parametrize( - "M, N", + "M, N, block_size_m, block_size_n", [ - (128, 64), # Small - (1024, 256), # Medium - (8192, 8192), # Large + (128, 64, 32, 64), # Small + (128, 128, 32, 32), # BLOCK_N < N/world_size (partial-width, multi-block per rank) + (256, 128, 32, 16), # Minimum BLOCK_N=16 (16-bit vectorization path) + (1024, 256, 32, 64), # Medium + (8192, 8192, 32, 64), # Large ], ) -def test_all_gather(dtype, M, N): +def test_all_gather(dtype, M, N, block_size_m, block_size_n): """Test all-gather functionality by comparing against PyTorch's implementation.""" # Ensure torch.distributed is initialized (should be done by test runner) if not dist.is_initialized(): @@ -62,7 +64,7 @@ def test_all_gather(dtype, M, N): # Run Iris all_gather shmem.barrier() - config = Config() + config = Config(block_size_m=block_size_m, block_size_n=block_size_n) shmem.ccl.all_gather(iris_output_tensor, iris_input_tensor, config=config) torch.cuda.synchronize() @@ -97,14 +99,16 @@ def test_all_gather(dtype, M, N): ], ) @pytest.mark.parametrize( - "M, N", + "M, N, block_size_m, block_size_n", [ - (128, 64), # Small - (1024, 256), # Medium - (8192, 8192), # Large + (128, 64, 32, 64), # Small + (128, 128, 32, 32), # BLOCK_N < N/world_size (partial-width, multi-block per rank) + (256, 128, 32, 16), # Minimum BLOCK_N=16 (16-bit vectorization path) + (1024, 256, 32, 64), # Medium + (8192, 8192, 32, 64), # Large ], ) -def test_all_gather_partitioned(dtype, M, N): +def test_all_gather_partitioned(dtype, M, N, block_size_m, block_size_n): """Test all-gather with partitioned variant by comparing against PyTorch's implementation.""" # Ensure torch.distributed is initialized (should be done by test runner) if not dist.is_initialized(): @@ -140,7 +144,9 @@ def test_all_gather_partitioned(dtype, M, N): # COMM_SMS must be divisible by world_size for partitioned variant comm_sms = 64 # Assuming world_size divides 64 (e.g., 2, 4, 8) shmem.barrier() - config = Config(all_gather_variant="partitioned", comm_sms=comm_sms) + config = Config( + block_size_m=block_size_m, block_size_n=block_size_n, all_gather_variant="partitioned", comm_sms=comm_sms + ) shmem.ccl.all_gather(iris_output_tensor, iris_input_tensor, config=config) torch.cuda.synchronize() diff --git a/tests/ccl/test_all_reduce.py b/tests/ccl/test_all_reduce.py index ffd55e9d..1862e0e3 100644 --- a/tests/ccl/test_all_reduce.py +++ b/tests/ccl/test_all_reduce.py @@ -32,14 +32,16 @@ ], ) @pytest.mark.parametrize( - "M, N", + "M, N, block_size_m, block_size_n", [ - (128, 64), # Small - (1024, 256), # Medium - (8192, 8192), # Large + (128, 64, 32, 64), # Small + (128, 128, 32, 32), # BLOCK_N < N/world_size (partial-width, multi-block per rank) + (256, 128, 32, 16), # Minimum BLOCK_N=16 (16-bit vectorization path) + (1024, 256, 32, 64), # Medium + (8192, 8192, 32, 64), # Large ], ) -def test_all_reduce(variant, dtype, M, N): +def test_all_reduce(variant, dtype, M, N, block_size_m, block_size_n): """Test all-reduce functionality by comparing against PyTorch's implementation.""" # Ensure torch.distributed is initialized (should be done by test runner) if not dist.is_initialized(): @@ -70,7 +72,7 @@ def test_all_reduce(variant, dtype, M, N): # Run Iris all_reduce with specified variant shmem.barrier() - config = Config(all_reduce_variant=variant) + config = Config(all_reduce_variant=variant, block_size_m=block_size_m, block_size_n=block_size_n) if variant == "two_shot": # Test both distribution modes for two_shot config.all_reduce_distribution = 0 # striding diff --git a/tests/ccl/test_all_to_all.py b/tests/ccl/test_all_to_all.py index 76478f5a..99e9bf19 100644 --- a/tests/ccl/test_all_to_all.py +++ b/tests/ccl/test_all_to_all.py @@ -21,14 +21,16 @@ ], ) @pytest.mark.parametrize( - "M, N", + "M, N, block_size_m, block_size_n", [ - (128, 64), # Small - (1024, 256), # Medium - (8192, 8192), # Large + (128, 64, 32, 64), # Small + (128, 128, 32, 32), # BLOCK_N < N/world_size (partial-width, multi-block per rank) + (256, 128, 32, 16), # Minimum BLOCK_N=16 (16-bit vectorization path) + (1024, 256, 32, 64), # Medium + (8192, 8192, 32, 64), # Large ], ) -def test_all_to_all(dtype, M, N): +def test_all_to_all(dtype, M, N, block_size_m, block_size_n): """Test all-to-all functionality by comparing against PyTorch's implementation.""" # Ensure torch.distributed is initialized (should be done by test runner) if not dist.is_initialized(): @@ -74,7 +76,7 @@ def test_all_to_all(dtype, M, N): # Run Iris all_to_all shmem.barrier() - config = Config() + config = Config(block_size_m=block_size_m, block_size_n=block_size_n) shmem.ccl.all_to_all(iris_output_concat, iris_input_concat, config=config) torch.cuda.synchronize() diff --git a/tests/x/test_all_gather.py b/tests/x/test_all_gather.py index 93dff4ad..17cb74f3 100644 --- a/tests/x/test_all_gather.py +++ b/tests/x/test_all_gather.py @@ -78,7 +78,9 @@ def x_all_gather_kernel( @pytest.mark.parametrize( "M, N, BLOCK_SIZE_M, BLOCK_SIZE_N", [ - (128, 64, 64, 32), # Small + (128, 64, 64, 32), # Small – BLOCK_N < N (partial-width, was the original failing case) + (128, 128, 64, 32), # Multiple N blocks per rank – BLOCK_N < N/world_size (2 tiles in N per rank) + (256, 128, 64, 16), # Very small BLOCK_N to stress 16-bit vectorization with partial-width tiles (1024, 256, 128, 128), # Medium (2048, 2048, 256, 256), # Large # TODO: Fix non-aligned dimension handling in all_gather for irregular tiling @@ -258,7 +260,9 @@ def x_all_gather_ctx_api_kernel( (torch.float32, 1e-5, 1e-5), ], ) -@pytest.mark.parametrize("M, N, BLOCK_SIZE_M, BLOCK_SIZE_N", [(256, 128, 64, 64)]) +@pytest.mark.parametrize( + "M, N, BLOCK_SIZE_M, BLOCK_SIZE_N", [(256, 128, 64, 64), (128, 64, 64, 32), (128, 128, 64, 32)] +) def test_all_gather_ctx_api(gather_dim, dtype, atol, rtol, M, N, BLOCK_SIZE_M, BLOCK_SIZE_N): """Test tile-level all-gather using direct function call (ctx methods removed).""" if not dist.is_initialized(): diff --git a/tests/x/test_all_reduce.py b/tests/x/test_all_reduce.py index 30549a50..6e8934c1 100644 --- a/tests/x/test_all_reduce.py +++ b/tests/x/test_all_reduce.py @@ -223,6 +223,8 @@ def x_all_reduce_spinlock_kernel( "M, N, BLOCK_SIZE_M, BLOCK_SIZE_N", [ (128, 64, 64, 32), # Small + (128, 128, 64, 32), # BLOCK_N < N/world_size (partial-width, multi-block per rank) + (256, 128, 64, 16), # Minimum BLOCK_N=16 (16-bit vectorization path) (1024, 256, 128, 128), # Medium (2048, 2048, 256, 256), # Large # (100, 100, 64, 64), # Non-aligned dimensions - DISABLED: other=0.0 not supported