Add Blelloch parallel prefix scan for LASP #2
base: main
Conversation
Force-pushed from 134e5a6 to 75aca60.
This PR implements the Blelloch parallel prefix scan to reduce inter-GPU communication from O(P) sequential steps (ring) to O(log P) parallel steps (tree-based).

Key improvements:
- O(log P) communication complexity (e.g., 128 GPUs: 128 steps → 14 steps)
- Work-efficient tree-based algorithm
- Supports non-power-of-2 GPU counts
- Reuses KV/DKV buffers to avoid allocation overhead

Implementation details:
1. **BlellochScanner** (lasp/utils/blelloch_ops.py):
   - Tree-based up-sweep and down-sweep communication
   - Correct sender/receiver logic using the "right edge" of subtrees
   - Distance-based decay in the down-sweep for proper accumulation
   - Support for reverse (suffix) scan for the backward pass
   - Global rank conversion for multi-group data parallelism
2. **lasp_blelloch** (lasp/lasp_blelloch.py):
   - Combines the Blelloch scan with fused Triton kernels
   - Correct inclusive-to-exclusive conversion: λ^(-C) * (inclusive - local)
   - Buffer reuse pattern matching lasp_fuse_parallel
   - Forward: prefix scan; backward: suffix scan
3. **Tests and benchmarks**:
   - test_blelloch_correctness.py: gradient correctness tests
   - test_non_power_of_two.py: non-power-of-2 world sizes
   - benchmark_blelloch.py: performance benchmarks
   - benchmark_all_methods.py: comprehensive comparison

Tested with:
- Single GPU and multi-GPU (4-8 GPUs)
- Data parallelism (dp_size > 1) with sequence parallelism
- Power-of-2 and non-power-of-2 world sizes
- Forward and backward pass correctness
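For intuition, the up-sweep/down-sweep structure can be sketched as a minimal pure-Python simulation of the scan across P "ranks". This is a toy model of the classic work-efficient algorithm, not the PR's distributed implementation; it assumes a power-of-two P and an associative `op`:

```python
def blelloch_exclusive_scan(vals, op, identity):
    """Toy single-process Blelloch scan: exclusive prefix over len(vals) ranks."""
    a = list(vals)
    n = len(a)
    assert n > 0 and n & (n - 1) == 0, "this sketch assumes a power-of-two P"
    # Up-sweep: reduce into the right edge of each subtree, log2(n) rounds.
    d = 1
    while d < n:
        for i in range(0, n, 2 * d):
            a[i + 2 * d - 1] = op(a[i + d - 1], a[i + 2 * d - 1])
        d *= 2
    # Down-sweep: clear the root, then push prefixes down, log2(n) rounds.
    a[n - 1] = identity
    d = n // 2
    while d >= 1:
        for i in range(0, n, 2 * d):
            left = a[i + d - 1]
            a[i + d - 1] = a[i + 2 * d - 1]
            a[i + 2 * d - 1] = op(left, a[i + 2 * d - 1])
        d //= 2
    return a

# Exclusive prefix sums of [1..8]:
result = blelloch_exclusive_scan([1, 2, 3, 4, 5, 6, 7, 8], lambda x, y: x + y, 0)
```

Each round of either sweep is one step of pairwise communication, so the whole scan takes 2·log2(P) steps, which is where the 128 → 14 step figure for 128 GPUs comes from.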
Force-pushed from 75aca60 to c5cd122.
Hi petrpan26, nice work. Thanks for contributing to LASP! I will check the code change and do some tests in a few days.
Changed the Blelloch scan to compute the exclusive prefix directly instead of converting from the inclusive prefix, avoiding the division by lambda^n that overflows when lambda is small.

Implementation:
1. Compute the inclusive prefix using the standard up-sweep + down-sweep
2. Convert to exclusive via a simple rank shift: each rank i receives inclusive[i-1] from rank i-1; rank 0 gets zero

This matches the pattern used in lasp_naive, where the ring naturally produces the exclusive prefix, and avoids the numerical issue of computing 1/lambda^n, which overflows to infinity when s >= 1.0. Fixes NaN gradients in the backward pass.
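Both halves of this change can be illustrated with a toy example (the decay value 0.9 and sequence length 8192 below are illustrative numbers, not values from the code):

```python
def exclusive_from_inclusive_shift(inclusive, identity):
    # Rank shift: for any scan, exclusive[i] == inclusive[i-1], so rank i
    # simply takes rank i-1's inclusive result; rank 0 gets the identity.
    return [identity] + inclusive[:-1]

# Why the division-based conversion fails: lambda^(-n) explodes for long
# sequences. With lambda = 0.9 and n = 8192, 0.9**(-8192) ~ e^863, far
# beyond the float64 range (~e^709), so the scale factor overflows.
try:
    scale = 0.9 ** -8192
except OverflowError:
    scale = float("inf")

inclusive = [1, 3, 6, 10]  # inclusive prefix sums of [1, 2, 3, 4]
exclusive = exclusive_from_inclusive_shift(inclusive, 0)
```

The rank shift costs one extra neighbor exchange but never touches lambda^(-n), which is why it sidesteps the NaN gradients.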
Force-pushed from 6842b21 to ac2f03b.
@weigao266 Sounds good. I'm debugging why errors seem to accumulate at larger step counts, but I've added most of the results, which look correct, and a few benchmark files as well. Feel free to comment and let me know if there is anything I should change.
Root cause: in the suffix scan (backward pass), the rank shift was sending in the wrong direction. For a suffix scan, rank i should receive from rank i+1 (not i-1) and send to rank i-1 (not i+1).

The bug: scan_rank±1 was used for both prefix and suffix scans, which worked for the prefix case but was backwards for the suffix case due to the scan_rank reversal.

The fix:
- Separate logic for prefix vs. suffix scan in the rank shift
- Prefix: rank i receives from i-1, sends to i+1 (left to right)
- Suffix: rank i receives from i+1, sends to i-1 (right to left)
- Use the actual rank (not scan_rank) for the shift communication
- Add an actual_to_global_rank() helper to avoid scan_rank confusion

This should fix the 10x larger backward gradient errors (dk: 0.209, dv: 0.297) by ensuring the suffix scan produces correct exclusive values for each rank.
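The direction logic described above can be sketched as a small helper (a hypothetical function mirroring the commit message, not the PR's exact code):

```python
def shift_neighbors(rank, world_size, reverse):
    """Source/destination ranks for the inclusive-to-exclusive rank shift.

    Prefix scan (reverse=False): rank i receives from i-1 and sends to i+1.
    Suffix scan (reverse=True):  rank i receives from i+1 and sends to i-1.
    None means the rank is at the boundary and skips that operation
    (it uses the identity instead of receiving).
    """
    if reverse:  # suffix scan (backward pass): data flows right to left
        src = rank + 1 if rank < world_size - 1 else None
        dst = rank - 1 if rank > 0 else None
    else:        # prefix scan (forward pass): data flows left to right
        src = rank - 1 if rank > 0 else None
        dst = rank + 1 if rank < world_size - 1 else None
    return src, dst
```

The buggy version effectively used the prefix-scan branch for both directions, so in the suffix case every rank exchanged with the wrong neighbor.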
Force-pushed from 1834356 to 9881835.
Root cause: with 32+ GPUs, the rank shift was hanging because blocking send/recv created a sequential dependency chain. Each rank had to wait for the previous rank to send before it could send to the next rank, creating O(P) latency and potential deadlock.

The fix: use dist.irecv() and dist.isend() (non-blocking) instead of blocking send/recv. This allows all ranks to initiate their send/recv operations simultaneously, then wait for completion.

Benefits:
- Prevents deadlock with large GPU counts (the hang was observed at 32 GPUs)
- Allows parallel execution of send/recv operations
- Maintains O(1) latency for the rank shift step

This preserves the O(log P) overall complexity of the Blelloch scan.
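A toy latency model makes the difference concrete (a simulation for intuition only; real timings depend on the backend's buffering and rendezvous behavior):

```python
def rank_shift_completion_step(world_size, blocking):
    """Model the step at which the last rank's recv completes in the shift.

    Blocking model: a rank only posts its send after its own recv returns,
    so completions chain along the ranks: O(P) steps.
    Non-blocking model: every rank posts irecv/isend up front, so all
    pairwise exchanges can complete in the same step: O(1).
    """
    done = [0] * world_size  # step at which rank i's recv completes
    for i in range(1, world_size):
        done[i] = done[i - 1] + 1 if blocking else 1
    return max(done)
```

At 32 GPUs the blocking chain is 31 sequential steps (long enough to look like a hang, and a deadlock risk if any send blocks on an unposted recv), while the non-blocking version completes every exchange in one round.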
Force-pushed from 890279c to f84f1d4.
I tried running on smaller GPUs but couldn't, so I tuned the params a bit as well.
Todo:
For sequence parallelism, there is one thing I think we are doing inefficiently right now: we accumulate KV linearly, which incurs more latency as more GPUs are added (O(n) in this case), leaving many GPUs idle while waiting for the KV accumulation. I'm suggesting adding a Blelloch prefix scan algorithm to reduce this from a linear number of steps to a logarithmic one.
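A back-of-the-envelope step count shows the gap (idealized counts, assuming one pairwise exchange per round and counting only communication steps):

```python
import math

def ring_steps(P):
    # Linear KV accumulation: each rank forwards the running state once.
    return max(P - 1, 0)

def blelloch_steps(P):
    # One up-sweep plus one down-sweep, each ceil(log2(P)) rounds.
    return 2 * math.ceil(math.log2(P)) if P > 1 else 0

for P in (8, 32, 128):
    print(P, ring_steps(P), blelloch_steps(P))
```

The gap widens quickly: at 8 GPUs it is 7 vs. 6 steps, but at 128 GPUs it is 127 vs. 14.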
When testing on 8x H100 SXM, I saw a 2x speedup over the other methods.
dp-size=2 (Data Parallel: 2, Sequence Parallel: 4)
Forward-only throughput:
fuse: 49.56M tokens/s (3.59x speedup) — fastest
blelloch: 39.76M tokens/s (2.88x speedup)
fuse_parallel: 35.14M tokens/s (2.55x speedup)
cache: 26.78M tokens/s (1.94x speedup)
naive: 13.80M tokens/s (baseline)
Forward+backward throughput:
blelloch: 16.81M tokens/s (3.81x speedup) — fastest
fuse: 16.69M tokens/s (3.78x speedup)
fuse_parallel: 11.87M tokens/s (2.69x speedup)
cache: 6.68M tokens/s (1.51x speedup)
naive: 4.41M tokens/s (baseline)
[benchmark chart: gpu = 8, dp = 2]
[benchmark chart: gpu = 8, dp = 1]
At gpu = 16 (RTX 5090), the forward + backward speedup is more pronounced with more GPUs:
[benchmark chart: forward + backward]
[benchmark chart: forward only]