Skip to content

Port iris.x.all_to_all to Gluon with sweep benchmark, scatter plot, and assembly report#417

Draft
Copilot wants to merge 3 commits intomainfrom
copilot/port-all-to-gluon
Draft

Port iris.x.all_to_all to Gluon with sweep benchmark, scatter plot, and assembly report#417
Copilot wants to merge 3 commits intomainfrom
copilot/port-all-to-gluon

Conversation

Copy link
Contributor

Copilot AI commented Mar 4, 2026

Ports iris.x.all_to_all to a Gluon (@gluon.jit) backend and adds infrastructure to validate correctness, benchmark at scale, and compare generated assembly between the two backends.

New primitive — iris.x.all_to_all_gluon

  • iris/x/all_to_all.py: adds all_to_all_gluon, a @gluon.jit tile-level all-to-all using IrisDeviceCtx
  • Exported from iris.x when Gluon is available (_GLUON_ALL_TO_ALL_AVAILABLE guard)
  • Key differences vs. Triton variant:
Aspect Triton Gluon
Rank loop dynamic range(first, last+1) range(world_size) — fully unrolled (constexpr)
Tile processing 2-D vectorised load/store row-by-row with gl.BlockedLayout([1],[64],[4],[0])
Remote read iris.load(ptr, cur_rank, src_rank, heap_bases, mask) ctx.load(ptr, src_rank, mask)
Local/remote branch runtime predicate resolved at compile time
# In a @gluon.jit kernel:
iris.x.all_to_all_gluon(
    IrisDeviceCtx, context_tensor,
    src_ptr, dst_ptr, M, N,
    stride_src_m, stride_src_n, stride_dst_m, stride_dst_n,
    pid_m, pid_n, N_per_rank, cur_rank, world_size,
    BLOCK_SIZE_M, BLOCK_SIZE_N,
)

Validation

  • tests/x/test_all_to_all_gluon.py: correctness against dist.all_to_all across 5 shapes × 3 dtypes (fp16/fp32/bf16)

Benchmark & analysis tooling

  • benchmark/ccl/all_to_all/benchmark_x.py — extended with:

    • --sweep: iterates 14 (M, N_per_rank) configs from 128×64 to 16384×256
    • --dump_asm: compiles both kernels and writes AMDGCN ISA to .asm files (reports VGPR count and spill count per backend)
    • JSON output stores a list of per-size results for downstream plotting
  • benchmark/ccl/all_to_all/plot_x_all_to_all.py — new scatter plot script:

    • Blue circles (Triton) vs. orange squares (Gluon), X = total bytes/rank (log₂), Y = GB/s
    • Point labels with (M×N) dimensions; also prints a plain-text table
  • docs/reference/gluon/all_to_all_report.md — technical report covering algorithm differences, assembly analysis methodology (VGPR, spills, vectorisation width, branch reduction), and how to reproduce results

Original prompt

This section details on the original issue you should resolve

<issue_title>Gluon All2All</issue_title>
<issue_description>Port iris.x all to all from triton to gluon. Prepare a detailed report comparing performance and generated assembly

For reference, checkout the Triton source code (pip show triton)</issue_description>

Comments on the Issue (you are @copilot in this section)


💬 We'd love your input! Share your thoughts on Copilot coding agent in our 2 minute survey.

…hmark

Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com>
Copilot AI changed the title [WIP] Port iris.x all to all from triton to gluon Port iris.x.all_to_all to Gluon Mar 4, 2026
Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com>
Copilot AI changed the title Port iris.x.all_to_all to Gluon Port iris.x.all_to_all to Gluon with sweep benchmark, scatter plot, and assembly report Mar 4, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Gluon All2All

2 participants