Fix x.all_gather hint: use (1, block_n) for correct vectorization with partial-width tiles#434
Conversation
…8-64-64-32 Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com>
|
@copilot Do not merge the current fix as-is. Removing the hint is not acceptable: without the hint, loads/stores become non-vectorized and we cannot ship that. Please investigate this issue more deeply:
Report back with what you find in the assembly and a concrete, reasonable solution that preserves vectorization. |
… for correct vectorization Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com>
Investigated in assembly — here's what was found: Assembly comparison (M=128, N=64, BLOCK_M=64, BLOCK_N=32, float16):
Why Why it accidentally worked for other shapes (e.g., BLOCK_M=BLOCK_N=128): Fix ( |
|
@copilot I merged this. Can you suggest tests to make sure all APIs are tested against this behavior? Maybe the existing tests are enough but we need some new shapes in the parametrized tests. |
iris.x.all_gatherproduced incorrect results (max diff 1.0) for shapes where the tile width is narrower than the tensor width (e.g. M=128, N=64, BLOCK_M=64, BLOCK_N=32).Root Cause
The
hint=(tile.block_m, tile.block_n)passed toiris.storeappliestl.max_contiguousto the translated pointer block. Triton'sgetOrderFromContiguity([BLOCK_M, BLOCK_N])picks the dimension with the highest contiguity for vectorization. WhenBLOCK_M > BLOCK_N(e.g. 64 > 32), dim-0 is selected — but for a row-major tensor withstride_m = N > BLOCK_N, rows are not adjacent in memory. The compiler emits scalarbuffer_store_short(1 float16/instruction) with addresses derived from the incorrect dim-0 assumption, writing to wrong locations.This also explains why other shapes were unaffected: when
BLOCK_M == BLOCK_N, both dims have equal contiguity sostable_sortpreserves reverse order and dim-1 is selected correctly.Assembly evidence (M=128, N=64, BLOCK_M=64, BLOCK_N=32, float16)
(BLOCK_M, BLOCK_N)=(64, 32)buffer_store_short×8(1, BLOCK_N)=(1, 32)buffer_store_dwordx4×1buffer_store_dwordx4×1Fix
Replace
hint=(tile.block_m, tile.block_n)withhint=(1, tile.block_n)iniris/x/all_gather.py.(1, tile.block_n)correctly asserts only per-row contiguity (BLOCK_N consecutive elements in dim-1).getOrderFromContiguity([1, BLOCK_N])always selects dim-1, producingbuffer_store_dwordx4— 8× wider than the broken hint, identical to the no-hint case, and always correct for any BLOCK_N/N relationship in row-major tensors.💡 You can make Copilot smarter by setting up custom instructions, customizing its development environment and configuring Model Context Protocol (MCP) servers. Learn more Copilot coding agent tips in the docs.