
support allreduce #326

Merged
coderfeli merged 32 commits into main from all_reduce_new on Apr 8, 2026
Conversation

@yanboshao
Contributor

@yanboshao yanboshao commented Apr 1, 2026

Motivation

Add an allreduce kernel to FlyDSL, including a 1-stage kernel and a 2-stage kernel.
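For reference, the semantics being implemented: after an allreduce, every rank holds the elementwise sum of all ranks' input buffers. A minimal host-side sketch of that correctness contract (pure Python, no GPU; illustrative only, not the FlyDSL kernel):

```python
def allreduce_sum_reference(rank_buffers):
    """Host-side oracle for allreduce-sum semantics: every rank ends up
    with the elementwise sum of all ranks' buffers. The FlyDSL 1-stage
    and 2-stage kernels implement this across GPUs."""
    total = [sum(vals) for vals in zip(*rank_buffers)]
    return [list(total) for _ in rank_buffers]
```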

Technical Details

Test Plan

Test accuracy and performance on MI308 and MI355.

Test Result

Detailed Results on MI308

1-stage Kernel Path

ws=2: all shapes always take the 1-stage path. ws=4: shapes < 160 KB. ws=8: shapes < 80 KB.

World Size = 2

| Shape | dtype | FlyDSL (μs) | AIter (μs) | Speedup |
|---|---|---|---|---|
| 1,8 | f16 | 16.6 | 22.4 | **1.352x** |
| 1,8 | bf16 | 16.2 | 15.6 | 0.962x |
| 1,8 | f32 | 13.9 | 18.2 | **1.313x** |
| 1,4096 | f16 | 16.7 | 18.3 | 1.090x |
| 1,4096 | bf16 | 16.2 | 20.7 | **1.277x** |
| 1,4096 | f32 | 16.7 | 16.0 | 0.957x |
| 1,5120 | f16 | 17.6 | 15.9 | 0.903x |
| 1,5120 | bf16 | 16.3 | 21.4 | **1.309x** |
| 1,5120 | f32 | 14.7 | 15.9 | 1.084x |
| 1,40960 | f16 | 17.4 | 16.1 | 0.925x |
| 1,40960 | bf16 | 17.2 | 17.2 | 1.002x |
| 1,40960 | f32 | 13.2 | 17.8 | **1.351x** |
| 1,81920 | f16 | 13.2 | 19.6 | **1.481x** |
| 1,81920 | bf16 | 15.9 | 18.3 | 1.147x |
| 1,81920 | f32 | 18.1 | 20.4 | 1.127x |
| 8,8192 | f16 | 14.1 | 18.4 | **1.300x** |
| 8,8192 | bf16 | 12.9 | 16.8 | **1.297x** |
| 8,8192 | f32 | 15.8 | 19.3 | **1.224x** |
| 32,4096 | f16 | 17.0 | 21.7 | **1.277x** |
| 32,4096 | bf16 | 14.8 | 19.3 | **1.301x** |
| 32,4096 | f32 | 21.5 | 25.4 | 1.183x |
| 1,14336 | f16 | 15.8 | 16.8 | 1.060x |
| 1,14336 | bf16 | 17.1 | 16.8 | 0.985x |
| 4,7168 | f16 | 15.5 | 16.1 | 1.044x |
| 4,7168 | bf16 | 16.3 | 20.2 | **1.240x** |
| 16,5120 | f16 | 13.0 | 18.0 | **1.383x** |
| 16,5120 | bf16 | 16.0 | 19.1 | 1.193x |
| 128,8192 | f16 | 46.9 | 51.3 | 1.095x |
| 128,8192 | bf16 | 46.9 | 52.2 | 1.113x |
| 64,14336 | f16 | 41.9 | 46.9 | 1.120x |
| 64,14336 | bf16 | 42.1 | 48.0 | 1.140x |
| 256,8192 | f16 | 81.8 | 87.3 | 1.067x |
| 256,8192 | bf16 | 82.0 | 90.4 | 1.102x |
| 256,8192 | f32 | 153.6 | 158.4 | 1.032x |
| 16,28672 | f16 | 26.5 | 30.6 | 1.157x |
| 16,28672 | bf16 | 26.8 | 31.1 | 1.159x |

World Size = 4

| Shape | dtype | FlyDSL (μs) | AIter (μs) | Speedup |
|---|---|---|---|---|
| 1,8 | f16 | 17.5 | 21.3 | **1.219x** |
| 1,8 | bf16 | 20.2 | 18.8 | 0.930x |
| 1,8 | f32 | 16.5 | 18.1 | 1.095x |
| 1,4096 | f16 | 22.4 | 17.5 | 0.781x ⚠ |
| 1,4096 | bf16 | 22.4 | 18.3 | 0.814x |
| 1,4096 | f32 | 17.3 | 16.4 | 0.948x |
| 1,5120 | f16 | 17.4 | 17.9 | 1.027x |
| 1,5120 | bf16 | 17.2 | 17.8 | 1.038x |
| 1,5120 | f32 | 22.7 | 17.5 | 0.770x ⚠ |
| 1,40960 | f16 | 18.2 | 25.4 | **1.396x** |
| 1,40960 | bf16 | 22.0 | 19.9 | 0.902x |
| 1,14336 | f16 | 16.9 | 17.3 | 1.024x |
| 1,14336 | bf16 | 21.6 | 21.0 | 0.971x |
| 4,7168 | f16 | 20.2 | 17.0 | 0.838x |
| 4,7168 | bf16 | 15.8 | 18.1 | 1.145x |
| 8,8192 | f16 | 16.7 | 23.4 | **1.402x** |
| 8,8192 | bf16 | 18.0 | 17.9 | 0.994x |

World Size = 8

| Shape | dtype | FlyDSL (μs) | AIter (μs) | Speedup |
|---|---|---|---|---|
| 1,8 | f16 | 21.5 | 22.3 | 1.036x |
| 1,8 | bf16 | 20.3 | 19.2 | 0.946x |
| 1,8 | f32 | 23.9 | 18.0 | 0.753x ⚠ |
| 1,4096 | f16 | 23.1 | 18.0 | 0.781x ⚠ |
| 1,4096 | bf16 | 25.0 | 24.1 | 0.965x |
| 1,4096 | f32 | 22.0 | 22.9 | 1.042x |
| 1,5120 | f16 | 23.0 | 23.3 | 1.013x |
| 1,5120 | bf16 | 22.2 | 21.6 | 0.975x |
| 1,5120 | f32 | 20.1 | 19.3 | 0.961x |
| 1,14336 | f16 | 20.7 | 21.6 | 1.044x |
| 1,14336 | bf16 | 24.2 | 23.9 | 0.987x |
| 4,7168 | f16 | 19.9 | 34.4 | **1.730x** |
| 4,7168 | bf16 | 17.5 | 26.7 | **1.524x** |

2-stage Kernel Path

Triggered when: ws=4 and tensor ≥ 160 KB, or ws=8 and tensor ≥ 80 KB (but below 4 MB at ws=8).

World Size = 4

| Shape | dtype | FlyDSL (μs) | AIter (μs) | Speedup |
|---|---|---|---|---|
| 1,40960 | f32 | 17.5 | 19.5 | 1.113x |
| 1,81920 | f16 | 20.2 | 20.1 | 0.994x |
| 1,81920 | bf16 | 16.4 | 17.4 | 1.059x |
| 1,81920 | f32 | 19.7 | 18.0 | 0.914x |
| 8,8192 | f32 | 22.6 | 21.3 | 0.941x |
| 32,4096 | f16 | 17.1 | 22.6 | **1.323x** |
| 32,4096 | bf16 | 16.3 | 26.5 | **1.620x** |
| 32,4096 | f32 | 21.8 | 21.7 | 0.992x |
| 16,5120 | f16 | 24.1 | 20.0 | 0.829x |
| 16,5120 | bf16 | 18.2 | 18.0 | 0.988x |
| 128,8192 | f16 | 31.9 | 34.4 | 1.079x |
| 128,8192 | bf16 | 32.0 | 36.2 | 1.131x |
| 64,14336 | f16 | 28.9 | 33.3 | 1.154x |
| 64,14336 | bf16 | 28.7 | 34.3 | 1.195x |
| 256,8192 | f16 | 51.4 | 54.0 | 1.050x |
| 256,8192 | bf16 | 52.6 | 54.6 | 1.038x |
| 256,8192 | f32 | 87.5 | 91.4 | 1.045x |
| 16,28672 | f16 | 25.1 | 25.1 | 0.999x |
| 16,28672 | bf16 | 23.7 | 23.9 | 1.011x |

World Size = 8

| Shape | dtype | FlyDSL (μs) | AIter (μs) | Speedup |
|---|---|---|---|---|
| 1,40960 | f16 | 24.4 | 19.3 | 0.790x ⚠ |
| 1,40960 | bf16 | 22.0 | 19.3 | 0.878x |
| 1,40960 | f32 | 22.6 | 23.1 | 1.023x |
| 1,81920 | f16 | 20.7 | 22.5 | 1.087x |
| 1,81920 | bf16 | 20.9 | 22.7 | 1.089x |
| 1,81920 | f32 | 23.8 | 23.7 | 0.998x |
| 8,8192 | f16 | 22.4 | 24.1 | 1.076x |
| 8,8192 | bf16 | 23.6 | 23.6 | 0.998x |
| 8,8192 | f32 | 20.9 | 21.4 | 1.021x |
| 32,4096 | f16 | 21.7 | 23.1 | 1.066x |
| 32,4096 | bf16 | 17.4 | 20.2 | 1.160x |
| 32,4096 | f32 | 20.9 | 24.3 | 1.161x |
| 16,5120 | f16 | 22.4 | 19.2 | 0.858x |
| 16,5120 | bf16 | 22.0 | 19.3 | 0.879x |
| 128,8192 | f16 | 31.6 | 27.7 | 0.875x |
| 128,8192 | bf16 | 28.7 | 27.6 | 0.963x |
| 64,14336 | f16 | 27.4 | 25.9 | 0.943x |
| 64,14336 | bf16 | 28.6 | 25.8 | 0.901x |
| 256,8192 | f16 | 35.9 | 37.5 | 1.044x |
| 256,8192 | bf16 | 37.6 | 37.9 | 1.007x |
| 16,28672 | f16 | 24.0 | 22.2 | 0.924x |
| 16,28672 | bf16 | 21.0 | 28.9 | **1.374x** |

write-mode Kernel Path (World Size = 8 only)

Triggered when: ws=8 and tensor > 4 MB.

| Shape | dtype | FlyDSL (μs) | AIter (μs) | Speedup |
|---|---|---|---|---|
| 256,8192 | f32 | 55.3 | 57.0 | 1.030x |
| 512,4097 | f16 | 38.1 | 38.4 | 1.007x |
| 512,4097 | bf16 | 37.2 | 38.5 | 1.034x |
| 512,8192 | f16 | 55.9 | 56.9 | 1.018x |
| 512,8192 | bf16 | 55.7 | 57.1 | 1.026x |
| 256,14336 | f16 | 51.3 | 52.7 | 1.027x |
| 256,14336 | bf16 | 54.0 | 52.8 | 0.978x |
| 1024,14336 | f16 | 146.9 | 152.7 | 1.039x |
| 1024,14336 | bf16 | 149.6 | 153.1 | 1.024x |
| 2048,8192 | f16 | 164.8 | 175.1 | 1.063x |
| 2048,8192 | bf16 | 166.1 | 172.1 | 1.036x |
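Taken together, the 1-stage, 2-stage, and write-mode triggers above amount to a size/world-size dispatch rule. A hypothetical sketch of that rule (the function name is an assumption, not the actual FlyDSL code, and the boundary behavior at exactly 4 MB is not specified above):

```python
KB = 1024
MB = 1024 * KB

def select_kernel_path(num_bytes: int, world_size: int) -> str:
    """Illustrative dispatch mirroring the documented thresholds."""
    if world_size == 2:
        return "1-stage"  # ws=2: all shapes always take the 1-stage path
    if world_size == 4:
        # ws=4: 1-stage below 160 KB, otherwise 2-stage
        return "1-stage" if num_bytes < 160 * KB else "2-stage"
    if world_size == 8:
        # ws=8: 1-stage below 80 KB, 2-stage up to 4 MB, write-mode above
        if num_bytes < 80 * KB:
            return "1-stage"
        return "write-mode" if num_bytes > 4 * MB else "2-stage"
    raise ValueError(f"unsupported world size: {world_size}")
```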

Stress Test (Near 64 MB Limit)

| Shape | dtype | ws | Kernel Path | FlyDSL (μs) | AIter (μs) | Speedup |
|---|---|---|---|---|---|---|
| 2040,8192 | f32 | 2 | 1-stage | 1145.5 | 1168.3 | 1.020x |
| 2040,8192 | f32 | 4 | 2-stage | 598.9 | 620.8 | 1.037x |
| 2040,8192 | f32 | 8 | write-mode | 308.7 | 326.3 | 1.057x |
| 2048,8192 | f32 | 2 | 1-stage | 1150.3 | 1173.2 | 1.020x |
| 2048,8192 | f32 | 4 | 2-stage | 601.2 | 621.8 | 1.034x |
| 2048,8192 | f32 | 8 | write-mode | 310.0 | 326.8 | 1.054x |
| 4090,8192 | f16 | 2 | 1-stage | 1151.9 | 1177.0 | 1.022x |
| 4090,8192 | f16 | 4 | 2-stage | 604.1 | 624.8 | 1.034x |
| 4090,8192 | f16 | 8 | write-mode | 311.1 | 325.0 | 1.045x |
| 4096,8192 | f16 | 2 | 1-stage | 1152.9 | 1179.5 | 1.023x |
| 4096,8192 | f16 | 4 | 2-stage | 599.3 | 621.4 | 1.037x |
| 4096,8192 | f16 | 8 | write-mode | 312.3 | 324.6 | 1.039x |

Irregular Shapes

| Shape | dtype | ws | Kernel Path | FlyDSL (μs) | AIter (μs) | Speedup |
|---|---|---|---|---|---|---|
| 1,28672 | f16 | 2 | 1-stage | 12.1 | 16.4 | **1.360x** |
| 1,28672 | bf16 | 2 | 1-stage | 16.4 | 16.2 | 0.990x |
| 1,28672 | f16 | 4 | 1-stage | 17.2 | 17.4 | 1.015x |
| 1,28672 | bf16 | 4 | 1-stage | 15.8 | 18.2 | 1.155x |
| 1,28672 | f16 | 8 | 1-stage | 28.8 | 18.7 | 0.648x ⚠ |
| 1,28672 | bf16 | 8 | 1-stage | 20.5 | 22.2 | 1.084x |
| 3,7168 | f16 | 2 | 1-stage | 12.9 | 16.3 | **1.265x** |
| 3,7168 | bf16 | 2 | 1-stage | 17.0 | 19.6 | 1.149x |
| 3,7168 | f16 | 4 | 1-stage | 20.6 | 17.4 | 0.847x |
| 3,7168 | bf16 | 4 | 1-stage | 16.2 | 22.3 | **1.373x** |
| 3,7168 | f16 | 8 | 1-stage | 21.5 | 22.2 | 1.030x |
| 3,7168 | bf16 | 8 | 1-stage | 20.7 | 18.5 | 0.893x |
| 5,5120 | f16 | 2 | 1-stage | 14.2 | 17.1 | **1.203x** |
| 5,5120 | bf16 | 2 | 1-stage | 18.0 | 16.0 | 0.888x |
| 5,5120 | f16 | 4 | 1-stage | 17.2 | 18.3 | 1.060x |
| 5,5120 | bf16 | 4 | 1-stage | 17.8 | 19.4 | 1.089x |
| 5,5120 | f16 | 8 | 1-stage | 21.4 | 17.8 | 0.832x |
| 5,5120 | bf16 | 8 | 1-stage | 23.2 | 19.9 | 0.860x |
| 7,14336 | f16 | 2 | 1-stage | 14.0 | 18.5 | **1.328x** |
| 7,14336 | bf16 | 2 | 1-stage | 16.5 | 18.2 | 1.104x |
| 7,14336 | f16 | 4 | 2-stage | 20.1 | 21.4 | 1.065x |
| 7,14336 | bf16 | 4 | 2-stage | 17.5 | 17.9 | 1.025x |
| 7,14336 | f16 | 8 | 2-stage | 22.2 | 21.7 | 0.975x |
| 7,14336 | bf16 | 8 | 2-stage | 24.6 | 23.3 | 0.949x |
| 13,4096 | f16 | 2 | 1-stage | 13.3 | 16.5 | **1.241x** |
| 13,4096 | bf16 | 2 | 1-stage | 15.8 | 19.3 | **1.222x** |
| 13,4096 | f16 | 4 | 1-stage | 19.4 | 19.6 | 1.006x |
| 13,4096 | bf16 | 4 | 1-stage | 20.7 | 18.1 | 0.874x |
| 13,4096 | f16 | 8 | 2-stage | 22.7 | 19.4 | 0.856x |
| 13,4096 | bf16 | 8 | 2-stage | 21.0 | 22.4 | 1.066x |
| 17,7168 | f16 | 2 | 1-stage | 14.2 | 19.2 | **1.346x** |
| 17,7168 | bf16 | 2 | 1-stage | 14.7 | 18.8 | **1.277x** |
| 17,7168 | f16 | 4 | 2-stage | 16.3 | 21.5 | **1.320x** |
| 17,7168 | bf16 | 4 | 2-stage | 23.3 | 18.1 | 0.779x ⚠ |
| 17,7168 | f16 | 8 | 2-stage | 19.4 | 22.2 | 1.142x |
| 17,7168 | bf16 | 8 | 2-stage | 22.4 | 21.0 | 0.938x |

Detailed Results on MI355

1-stage Kernel Path

ws=2: all shapes always take the 1-stage path. ws=4: shapes < 160 KB. ws=8: shapes < 80 KB.

World Size = 2

| Shape | dtype | FlyDSL (μs) | AIter (μs) | Speedup |
|---|---|---|---|---|
| 1,8 | f16 | 16.6 | 22.4 | **1.352x** |
| 1,8 | bf16 | 16.2 | 15.6 | 0.962x |
| 1,8 | f32 | 13.9 | 18.2 | **1.313x** |
| 1,4096 | f16 | 16.7 | 18.3 | 1.090x |
| 1,4096 | bf16 | 16.2 | 20.7 | **1.277x** |
| 1,4096 | f32 | 16.7 | 16.0 | 0.957x |
| 1,5120 | f16 | 17.6 | 15.9 | 0.903x |
| 1,5120 | bf16 | 16.3 | 21.4 | **1.309x** |
| 1,5120 | f32 | 14.7 | 15.9 | 1.084x |
| 1,40960 | f16 | 17.4 | 16.1 | 0.925x |
| 1,40960 | bf16 | 17.2 | 17.2 | 1.002x |
| 1,40960 | f32 | 13.2 | 17.8 | **1.351x** |
| 1,81920 | f16 | 13.2 | 19.6 | **1.481x** |
| 1,81920 | bf16 | 15.9 | 18.3 | 1.147x |
| 1,81920 | f32 | 18.1 | 20.4 | 1.127x |
| 8,8192 | f16 | 14.1 | 18.4 | **1.300x** |
| 8,8192 | bf16 | 12.9 | 16.8 | **1.297x** |
| 8,8192 | f32 | 15.8 | 19.3 | **1.224x** |
| 32,4096 | f16 | 17.0 | 21.7 | **1.277x** |
| 32,4096 | bf16 | 14.8 | 19.3 | **1.301x** |
| 32,4096 | f32 | 21.5 | 25.4 | 1.183x |
| 1,14336 | f16 | 15.8 | 16.8 | 1.060x |
| 1,14336 | bf16 | 17.1 | 16.8 | 0.985x |
| 4,7168 | f16 | 15.5 | 16.1 | 1.044x |
| 4,7168 | bf16 | 16.3 | 20.2 | **1.240x** |
| 16,5120 | f16 | 13.0 | 18.0 | **1.383x** |
| 16,5120 | bf16 | 16.0 | 19.1 | 1.193x |
| 128,8192 | f16 | 46.9 | 51.3 | 1.095x |
| 128,8192 | bf16 | 46.9 | 52.2 | 1.113x |
| 64,14336 | f16 | 41.9 | 46.9 | 1.120x |
| 64,14336 | bf16 | 42.1 | 48.0 | 1.140x |
| 256,8192 | f16 | 81.8 | 87.3 | 1.067x |
| 256,8192 | bf16 | 82.0 | 90.4 | 1.102x |
| 256,8192 | f32 | 153.6 | 158.4 | 1.032x |
| 16,28672 | f16 | 26.5 | 30.6 | 1.157x |
| 16,28672 | bf16 | 26.8 | 31.1 | 1.159x |

World Size = 4

| Shape | dtype | FlyDSL (μs) | AIter (μs) | Speedup |
|---|---|---|---|---|
| 1,8 | f16 | 17.5 | 21.3 | **1.219x** |
| 1,8 | bf16 | 20.2 | 18.8 | 0.930x |
| 1,8 | f32 | 16.5 | 18.1 | 1.095x |
| 1,4096 | f16 | 22.4 | 17.5 | 0.781x ⚠ |
| 1,4096 | bf16 | 22.4 | 18.3 | 0.814x |
| 1,4096 | f32 | 17.3 | 16.4 | 0.948x |
| 1,5120 | f16 | 17.4 | 17.9 | 1.027x |
| 1,5120 | bf16 | 17.2 | 17.8 | 1.038x |
| 1,5120 | f32 | 22.7 | 17.5 | 0.770x ⚠ |
| 1,40960 | f16 | 18.2 | 25.4 | **1.396x** |
| 1,40960 | bf16 | 22.0 | 19.9 | 0.902x |
| 1,14336 | f16 | 16.9 | 17.3 | 1.024x |
| 1,14336 | bf16 | 21.6 | 21.0 | 0.971x |
| 4,7168 | f16 | 20.2 | 17.0 | 0.838x |
| 4,7168 | bf16 | 15.8 | 18.1 | 1.145x |
| 8,8192 | f16 | 16.7 | 23.4 | **1.402x** |
| 8,8192 | bf16 | 18.0 | 17.9 | 0.994x |

World Size = 8

| Shape | dtype | FlyDSL (μs) | AIter (μs) | Speedup |
|---|---|---|---|---|
| 1,8 | f16 | 21.5 | 22.3 | 1.036x |
| 1,8 | bf16 | 20.3 | 19.2 | 0.946x |
| 1,8 | f32 | 23.9 | 18.0 | 0.753x ⚠ |
| 1,4096 | f16 | 23.1 | 18.0 | 0.781x ⚠ |
| 1,4096 | bf16 | 25.0 | 24.1 | 0.965x |
| 1,4096 | f32 | 22.0 | 22.9 | 1.042x |
| 1,5120 | f16 | 23.0 | 23.3 | 1.013x |
| 1,5120 | bf16 | 22.2 | 21.6 | 0.975x |
| 1,5120 | f32 | 20.1 | 19.3 | 0.961x |
| 1,14336 | f16 | 20.7 | 21.6 | 1.044x |
| 1,14336 | bf16 | 24.2 | 23.9 | 0.987x |
| 4,7168 | f16 | 19.9 | 34.4 | **1.730x** |
| 4,7168 | bf16 | 17.5 | 26.7 | **1.524x** |

2-stage Kernel Path

Triggered when: ws=4 and tensor ≥ 160 KB, or ws=8 and tensor ≥ 80 KB (but below 4 MB at ws=8).

World Size = 4

| Shape | dtype | FlyDSL (μs) | AIter (μs) | Speedup |
|---|---|---|---|---|
| 1,40960 | f32 | 17.5 | 19.5 | 1.113x |
| 1,81920 | f16 | 20.2 | 20.1 | 0.994x |
| 1,81920 | bf16 | 16.4 | 17.4 | 1.059x |
| 1,81920 | f32 | 19.7 | 18.0 | 0.914x |
| 8,8192 | f32 | 22.6 | 21.3 | 0.941x |
| 32,4096 | f16 | 17.1 | 22.6 | **1.323x** |
| 32,4096 | bf16 | 16.3 | 26.5 | **1.620x** |
| 32,4096 | f32 | 21.8 | 21.7 | 0.992x |
| 16,5120 | f16 | 24.1 | 20.0 | 0.829x |
| 16,5120 | bf16 | 18.2 | 18.0 | 0.988x |
| 128,8192 | f16 | 31.9 | 34.4 | 1.079x |
| 128,8192 | bf16 | 32.0 | 36.2 | 1.131x |
| 64,14336 | f16 | 28.9 | 33.3 | 1.154x |
| 64,14336 | bf16 | 28.7 | 34.3 | 1.195x |
| 256,8192 | f16 | 51.4 | 54.0 | 1.050x |
| 256,8192 | bf16 | 52.6 | 54.6 | 1.038x |
| 256,8192 | f32 | 87.5 | 91.4 | 1.045x |
| 16,28672 | f16 | 25.1 | 25.1 | 0.999x |
| 16,28672 | bf16 | 23.7 | 23.9 | 1.011x |

World Size = 8

| Shape | dtype | FlyDSL (μs) | AIter (μs) | Speedup |
|---|---|---|---|---|
| 1,40960 | f16 | 24.4 | 19.3 | 0.790x ⚠ |
| 1,40960 | bf16 | 22.0 | 19.3 | 0.878x |
| 1,40960 | f32 | 22.6 | 23.1 | 1.023x |
| 1,81920 | f16 | 20.7 | 22.5 | 1.087x |
| 1,81920 | bf16 | 20.9 | 22.7 | 1.089x |
| 1,81920 | f32 | 23.8 | 23.7 | 0.998x |
| 8,8192 | f16 | 22.4 | 24.1 | 1.076x |
| 8,8192 | bf16 | 23.6 | 23.6 | 0.998x |
| 8,8192 | f32 | 20.9 | 21.4 | 1.021x |
| 32,4096 | f16 | 21.7 | 23.1 | 1.066x |
| 32,4096 | bf16 | 17.4 | 20.2 | 1.160x |
| 32,4096 | f32 | 20.9 | 24.3 | 1.161x |
| 16,5120 | f16 | 22.4 | 19.2 | 0.858x |
| 16,5120 | bf16 | 22.0 | 19.3 | 0.879x |
| 128,8192 | f16 | 31.6 | 27.7 | 0.875x |
| 128,8192 | bf16 | 28.7 | 27.6 | 0.963x |
| 64,14336 | f16 | 27.4 | 25.9 | 0.943x |
| 64,14336 | bf16 | 28.6 | 25.8 | 0.901x |
| 256,8192 | f16 | 35.9 | 37.5 | 1.044x |
| 256,8192 | bf16 | 37.6 | 37.9 | 1.007x |
| 16,28672 | f16 | 24.0 | 22.2 | 0.924x |
| 16,28672 | bf16 | 21.0 | 28.9 | **1.374x** |
| 256,8192 | f32 | 55.3 | 57.0 | 1.030x |
| 512,4097 | f16 | 38.1 | 38.4 | 1.007x |
| 512,4097 | bf16 | 37.2 | 38.5 | 1.034x |
| 512,8192 | f16 | 55.9 | 56.9 | 1.018x |
| 512,8192 | bf16 | 55.7 | 57.1 | 1.026x |
| 256,14336 | f16 | 51.3 | 52.7 | 1.027x |
| 256,14336 | bf16 | 54.0 | 52.8 | 0.978x |
| 1024,14336 | f16 | 146.9 | 152.7 | 1.039x |
| 1024,14336 | bf16 | 149.6 | 153.1 | 1.024x |
| 2048,8192 | f16 | 164.8 | 175.1 | 1.063x |
| 2048,8192 | bf16 | 166.1 | 172.1 | 1.036x |

Stress Test (Near 64 MB Limit)

| Shape | dtype | ws | Kernel Path | FlyDSL (μs) | AIter (μs) | Speedup |
|---|---|---|---|---|---|---|
| 2040,8192 | f32 | 2 | 1-stage | 1145.5 | 1168.3 | 1.020x |
| 2040,8192 | f32 | 4 | 2-stage | 598.9 | 620.8 | 1.037x |
| 2040,8192 | f32 | 8 | 2-stage | 308.7 | 326.3 | 1.057x |
| 2048,8192 | f32 | 2 | 1-stage | 1150.3 | 1173.2 | 1.020x |
| 2048,8192 | f32 | 4 | 2-stage | 601.2 | 621.8 | 1.034x |
| 2048,8192 | f32 | 8 | 2-stage | 310.0 | 326.8 | 1.054x |
| 4090,8192 | f16 | 2 | 1-stage | 1151.9 | 1177.0 | 1.022x |
| 4090,8192 | f16 | 4 | 2-stage | 604.1 | 624.8 | 1.034x |
| 4090,8192 | f16 | 8 | 2-stage | 311.1 | 325.0 | 1.045x |
| 4096,8192 | f16 | 2 | 1-stage | 1152.9 | 1179.5 | 1.023x |
| 4096,8192 | f16 | 4 | 2-stage | 599.3 | 621.4 | 1.037x |
| 4096,8192 | f16 | 8 | 2-stage | 312.3 | 324.6 | 1.039x |

Irregular Shapes

| Shape | dtype | ws | Kernel Path | FlyDSL (μs) | AIter (μs) | Speedup |
|---|---|---|---|---|---|---|
| 1,28672 | f16 | 2 | 1-stage | 12.1 | 16.4 | **1.360x** |
| 1,28672 | bf16 | 2 | 1-stage | 16.4 | 16.2 | 0.990x |
| 1,28672 | f16 | 4 | 1-stage | 17.2 | 17.4 | 1.015x |
| 1,28672 | bf16 | 4 | 1-stage | 15.8 | 18.2 | 1.155x |
| 1,28672 | f16 | 8 | 1-stage | 28.8 | 18.7 | 0.648x ⚠ |
| 1,28672 | bf16 | 8 | 1-stage | 20.5 | 22.2 | 1.084x |
| 3,7168 | f16 | 2 | 1-stage | 12.9 | 16.3 | **1.265x** |
| 3,7168 | bf16 | 2 | 1-stage | 17.0 | 19.6 | 1.149x |
| 3,7168 | f16 | 4 | 1-stage | 20.6 | 17.4 | 0.847x |
| 3,7168 | bf16 | 4 | 1-stage | 16.2 | 22.3 | **1.373x** |
| 3,7168 | f16 | 8 | 1-stage | 21.5 | 22.2 | 1.030x |
| 3,7168 | bf16 | 8 | 1-stage | 20.7 | 18.5 | 0.893x |
| 5,5120 | f16 | 2 | 1-stage | 14.2 | 17.1 | **1.203x** |
| 5,5120 | bf16 | 2 | 1-stage | 18.0 | 16.0 | 0.888x |
| 5,5120 | f16 | 4 | 1-stage | 17.2 | 18.3 | 1.060x |
| 5,5120 | bf16 | 4 | 1-stage | 17.8 | 19.4 | 1.089x |
| 5,5120 | f16 | 8 | 1-stage | 21.4 | 17.8 | 0.832x |
| 5,5120 | bf16 | 8 | 1-stage | 23.2 | 19.9 | 0.860x |
| 7,14336 | f16 | 2 | 1-stage | 14.0 | 18.5 | **1.328x** |
| 7,14336 | bf16 | 2 | 1-stage | 16.5 | 18.2 | 1.104x |
| 7,14336 | f16 | 4 | 2-stage | 20.1 | 21.4 | 1.065x |
| 7,14336 | bf16 | 4 | 2-stage | 17.5 | 17.9 | 1.025x |
| 7,14336 | f16 | 8 | 2-stage | 22.2 | 21.7 | 0.975x |
| 7,14336 | bf16 | 8 | 2-stage | 24.6 | 23.3 | 0.949x |
| 13,4096 | f16 | 2 | 1-stage | 13.3 | 16.5 | **1.241x** |
| 13,4096 | bf16 | 2 | 1-stage | 15.8 | 19.3 | **1.222x** |
| 13,4096 | f16 | 4 | 1-stage | 19.4 | 19.6 | 1.006x |
| 13,4096 | bf16 | 4 | 1-stage | 20.7 | 18.1 | 0.874x |
| 13,4096 | f16 | 8 | 2-stage | 22.7 | 19.4 | 0.856x |
| 13,4096 | bf16 | 8 | 2-stage | 21.0 | 22.4 | 1.066x |
| 17,7168 | f16 | 2 | 1-stage | 14.2 | 19.2 | **1.346x** |
| 17,7168 | bf16 | 2 | 1-stage | 14.7 | 18.8 | **1.277x** |
| 17,7168 | f16 | 4 | 2-stage | 16.3 | 21.5 | **1.320x** |
| 17,7168 | bf16 | 4 | 2-stage | 23.3 | 18.1 | 0.779x ⚠ |
| 17,7168 | f16 | 8 | 2-stage | 19.4 | 22.2 | 1.142x |
| 17,7168 | bf16 | 8 | 2-stage | 22.4 | 21.0 | 0.938x |

Speedup = aiter_avg_time_us / flydsl_avg_time_us. **Bold** indicates speedup ≥ 1.2x; ⚠ indicates a regression (speedup < 0.8x).
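The legend's metric and markers can be expressed as a small helper (illustrative only; the function name is not from the PR):

```python
def classify_speedup(aiter_avg_time_us: float, flydsl_avg_time_us: float):
    """Compute the tables' speedup metric and its marker:
    'bold' for speedup >= 1.2x, 'regression' (the warning sign)
    below 0.8x, and no marker otherwise."""
    speedup = aiter_avg_time_us / flydsl_avg_time_us
    if speedup >= 1.2:
        marker = "bold"
    elif speedup < 0.8:
        marker = "regression"
    else:
        marker = ""
    return round(speedup, 3), marker
```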

Submission Checklist

@yanboshao yanboshao changed the title All reduce new support allreduce Apr 1, 2026
@coderfeli
Collaborator

@yanboshao Issues to address:

  1. `mgpuLaunchClusterKernel` deleted from FlyRocmRuntimeWrappers.cpp. This removes the cluster launch API entirely, which is unrelated to allreduce and could break other users or future kernels. Should this deletion be in a separate PR, or is there a justification?
  2. Hard dependency on aiter in `FlyDSLAllreduce.__init__` (line 227):

         import aiter as aiter_ops
         self.meta = aiter_ops.allocate_meta_buffer(...)

     Can we run without aiter, or add a skip when aiter is unavailable?
  3. `_is_weak_contiguous` silently swallows all exceptions (lines 27-33):

         except Exception:
             return False

     This could mask real bugs. At least log a warning.
  4. The write-mode kernel uses `store_v4i32` (cached store), not `store_v4i32_nt`, for remote writes (line 1586 in the kernel). The docstring says "matching aiter's `__builtin_nontemporal_store`", but the code uses the regular cached store. This is inconsistent.
  5. `_signal_end_sync` uses `need_wbl2=True` for write-mode (lines 1592-1593). If the data stores in write-mode are supposed to be nontemporal (bypassing L2), then `need_wbl2` should be `False` (as the docstring says: "For nt data stores, no wbl2 is needed"). These two points are contradictory.

stream, params, extra));
}

extern "C" void mgpuLaunchClusterKernel(hipFunction_t function,
Collaborator

why remove all of these?

Contributor Author

fixed

@@ -0,0 +1,830 @@
"""Custom all-reduce kernel + Python-facing shim.
Collaborator

add license header.

Contributor Author

fixed

@@ -0,0 +1,820 @@
"""FlyDSL all-reduce kernels using signal protocol for multi-GPU communication.
Collaborator

license header

Contributor Author

fixed

Comment thread on kernels/custom_all_reduce_kernel.py (Outdated)
self_sg_i64 = _unwrap_value(self_sg)
sg_ptrs_i64 = _unwrap_value(sg_ptrs)
in_ptrs_i64 = _unwrap_value(in_ptrs)
out_ptr_i64 = _unwrap_value(out_ptr)
Collaborator

why use so many ir.* wrap, unwrap? try to use native types in numeric.py

Contributor Author

fixed

Comment thread on kernels/custom_all_reduce_kernel.py (Outdated)
Each warp loads data from one rank into shared memory, then warp 0
reduces across all warps and writes the result to global memory.
"""
from flydsl._mlir.dialects import arith, memref, scf, vector
Collaborator

use fly.memref and arith?

Contributor Author

fixed

p = bfor.arguments[0]
cond = arith.CmpIOp(arith.CmpIPredicate.ult, p,
ea.constant(num_packs, type=i32)).result
scf.ConditionOp(cond, [p, bfor.arguments[1]])
Collaborator

directly use > < and arith.cond?

Contributor Author

fixed

@yanboshao yanboshao force-pushed the all_reduce_new branch 2 times, most recently from 158a605 to 37cf182, on April 7, 2026 02:58
@coderfeli
Collaborator

@yanboshao still has conflicts with main

self._call_state_cache[cache_key] = state
except Exception:
pass

Collaborator

seems already in main. try using main directly

Contributor Author

fixed

yanboshao and others added 16 commits April 8, 2026 14:23
- Add stream_ptr param to run_{1stage,2stage,2stage_ptr} wrappers and launch via async deps.
- Pass torch current stream in FlyDSL custom_all_reduce to avoid per-launch stream create/sync/destroy.
- Keep AIter CustomAllreduce import compatible across package layouts.

Co-authored-by: Cursor <cursoragent@cursor.com>
Extend run_2stage_ptr ABI with inp_ptr_override to avoid per-call H2D pointer updates.
Cache grid_x and add optional out reuse / validation controls to reduce host overhead.

Made-with: Cursor
@coderfeli coderfeli merged commit 8801996 into main Apr 8, 2026
9 checks passed
@coderfeli coderfeli deleted the all_reduce_new branch April 8, 2026 14:09