Skip to content

Conversation

@petrpan26
Copy link

@petrpan26 petrpan26 commented Nov 4, 2025

Todo:

  • Clean up code
  • Debug backward pass incorrectness

For sequence parallelism, There is one thing that I think we are doing inefficiently right now. We are doing accumulation of KV in a linear way and this in turn incur more latencies as more GPU are added (O(n) in this case). This have a lot of GPU idling between waiting KV accumulation. I'm suggesting adding a blelloch prefix scan algorithm to help reduce this linear steps to logarithmic instead.

When testing in 8xH100SXM, I saw a 2x speed up over other methods.

dp-size=2 (Data Parallel: 2, Sequence Parallel: 4)

Forward-only throughput:

fuse: 49.56M tokens/s (3.59x speedup) — fastest
blelloch: 39.76M tokens/s (2.88x speedup)
fuse_parallel: 35.14M tokens/s (2.55x speedup)
cache: 26.78M tokens/s (1.94x speedup)
naive: 13.80M tokens/s (baseline)

Forward+backward throughput:

blelloch: 16.81M tokens/s (3.81x speedup) — fastest
fuse: 16.69M tokens/s (3.78x speedup)
fuse_parallel: 11.87M tokens/s (2.69x speedup)
cache: 6.68M tokens/s (1.51x speedup)
naive: 4.41M tokens/s (baseline)

gpu = 8, dp = 2

Method Forward Diff Backward Diff (dq/dk/dv) Status
naive 0.0 0.016 / 0.047 / 0.067 ✅ Pass
cache 0.0 0.016 / 0.047 / 0.067 ✅ Pass
fuse 0.051 0.019 / 0.021 / 0.064 ✅ Pass
fuse_parallel 0.018 0.021 / 0.025 / 0.036 ✅ Pass
blelloch 0.018 0.021 / 0.025 / 0.036 ✅ Pass

gpu = 8, dp = 1

Method Forward Diff Backward Diff (dq/dk/dv) Status
naive 0.0 0.015 / 0.032 / 0.045 ✅ Pass
cache 0.0 0.015 / 0.032 / 0.045 ✅ Pass
fuse 0.048 0.016 / 0.021 / 0.064 ✅ Pass
fuse_parallel 0.006 0.016 / 0.025 / 0.036 ✅ Pass
blelloch 0.006 0.016 / 0.209 / 0.297 ✅ Pass

At gpu = 16 RTX 5090

Forward + Backward we see more profound effect at more gpu

DP-Size Blelloch Best Alternative (fuse) Speedup Advantage
dp-size=1 21.28M tokens/s 7.49M tokens/s 2.84x faster
dp-size=2 33.90M tokens/s 18.12M tokens/s 1.87x faster
dp-size=4 44.64M tokens/s 31.21M tokens/s 1.43x faster

Forward only

DP-Size Blelloch Best Alternative (fuse) Speedup Advantage
dp-size=1 47.34M tokens/s 18.53M tokens/s 2.55x faster
dp-size=2 83.01M tokens/s 51.17M tokens/s 1.62x faster
dp-size=4 125.50M tokens/s 100.22M tokens/s 1.25x faster

Key improvements:

  • O(log P) communication complexity (e.g., 128 GPUs: 128 steps → 14 steps)
  • Work-efficient tree-based algorithm
  • Supports non-power-of-2 GPU counts
  • Reuses KV/DKV buffers to avoid allocation overhead

Implementation details:

  1. BlellochScanner (lasp/utils/blelloch_ops.py):

    • Tree-based up-sweep and down-sweep communication
    • Correct sender/receiver logic using "right edge" of subtrees
    • Distance-based decay in down-sweep for proper accumulation
    • Support for reverse scan (suffix) for backward pass
    • Global rank conversion for multi-group data parallelism
  2. lasp_blelloch (lasp/lasp_blelloch.py):

    • Combines Blelloch scan with fused Triton kernels
    • Correct inclusive-to-exclusive conversion: λ^(-C) * (inclusive - local)
    • Buffer reuse pattern matching lasp_fuse_parallel
    • Forward: prefix scan, Backward: suffix scan
  3. Tests and benchmarks:

    • test_blelloch_correctness.py: Gradient correctness tests
    • test_non_power_of_two.py: Non-power-of-2 world sizes
    • benchmark_blelloch.py: Performance benchmarks
    • benchmark_all_methods.py: Comprehensive comparison

Tested with:

  • Single GPU and multi-GPU (4-8 GPUs)
  • Data parallelism (dp_size > 1) with sequence parallelism
  • Power-of-2 and non-power-of-2 world sizes
  • Forward and backward pass correctness

@petrpan26 petrpan26 force-pushed the feature/blelloch-clean branch from 134e5a6 to 75aca60 Compare November 4, 2025 09:34
This PR implements Blelloch parallel prefix scan to reduce inter-GPU
communication from O(P) sequential steps (ring) to O(log P) parallel
steps (tree-based).

Key improvements:
- O(log P) communication complexity (e.g., 128 GPUs: 128 steps → 14 steps)
- Work-efficient tree-based algorithm
- Supports non-power-of-2 GPU counts
- Reuses KV/DKV buffers to avoid allocation overhead

Implementation details:

1. **BlellochScanner** (lasp/utils/blelloch_ops.py):
   - Tree-based up-sweep and down-sweep communication
   - Correct sender/receiver logic using "right edge" of subtrees
   - Distance-based decay in down-sweep for proper accumulation
   - Support for reverse scan (suffix) for backward pass
   - Global rank conversion for multi-group data parallelism

2. **lasp_blelloch** (lasp/lasp_blelloch.py):
   - Combines Blelloch scan with fused Triton kernels
   - Correct inclusive-to-exclusive conversion: λ^(-C) * (inclusive - local)
   - Buffer reuse pattern matching lasp_fuse_parallel
   - Forward: prefix scan, Backward: suffix scan

3. **Tests and benchmarks**:
   - test_blelloch_correctness.py: Gradient correctness tests
   - test_non_power_of_two.py: Non-power-of-2 world sizes
   - benchmark_blelloch.py: Performance benchmarks
   - benchmark_all_methods.py: Comprehensive comparison

Tested with:
- Single GPU and multi-GPU (4-8 GPUs)
- Data parallelism (dp_size > 1) with sequence parallelism
- Power-of-2 and non-power-of-2 world sizes
- Forward and backward pass correctness
@petrpan26 petrpan26 force-pushed the feature/blelloch-clean branch from 75aca60 to c5cd122 Compare November 4, 2025 09:38
@weigao266
Copy link
Collaborator

Hi petrpan26,

Nice work. Thanks for contributing LASP! I will check the code change and do some tests in a few days.

Changed Blelloch scan to compute exclusive prefix directly instead of
converting from inclusive, avoiding division by lambda^n which causes
overflow when lambda is small.

Implementation:
1. Compute inclusive prefix using standard up-sweep + down-sweep
2. Convert to exclusive via simple rank shift: each rank i receives
   inclusive[i-1] from rank i-1, rank 0 gets zero

This matches the pattern used in lasp_naive where the ring naturally
produces exclusive prefix, avoiding the numerical issues of computing
1/lambda^n which overflows to infinity when s >= 1.0.

Fixes NaN gradients in backward pass.
@petrpan26 petrpan26 force-pushed the feature/blelloch-clean branch from 6842b21 to ac2f03b Compare November 4, 2025 09:52
@petrpan26
Copy link
Author

@weigao266 Sounds good I'm debugging why for large steps i think errors are accumulating but I added most of the result in and it should looks correct and add a few benchmarks file as well. Feel free to comment and let me know if I can change anything

Root cause: In suffix scan (backward pass), the rank shift was sending
in the wrong direction. For suffix scan, rank i should receive from
rank i+1 (not i-1) and send to rank i-1 (not i+1).

The bug: Used scan_rank±1 for both prefix and suffix, which worked for
prefix but was backwards for suffix due to the scan_rank reversal.

The fix:
- Separate logic for prefix vs suffix scan in rank shift
- Prefix: rank i receives from i-1, sends to i+1 (left to right)
- Suffix: rank i receives from i+1, sends to i-1 (right to left)
- Use actual rank (not scan_rank) for the shift communication
- Add actual_to_global_rank() helper to avoid scan_rank confusion

This should fix the 10x larger backward gradient errors (dk: 0.209,
dv: 0.297) by ensuring the suffix scan produces correct exclusive
values for each rank.
@petrpan26 petrpan26 force-pushed the feature/blelloch-clean branch from 1834356 to 9881835 Compare November 4, 2025 14:51
Hoang Phan added 2 commits November 4, 2025 09:53
Root cause: With 32+ GPUs, the rank shift was hanging because blocking
send/recv created a sequential dependency chain. Each rank had to wait
for the previous rank to send before it could send to the next rank,
creating O(P) latency and potential deadlock.

The fix: Use dist.irecv() and dist.isend() (non-blocking) instead of
blocking send/recv. This allows all ranks to initiate their send/recv
operations simultaneously, then wait for completion.

Benefits:
- Prevents deadlock with large GPU counts (tested hang at 32 GPUs)
- Allows parallel execution of send/recv operations
- Maintains O(1) latency for the rank shift step

This preserves the O(log P) overall complexity of Blelloch scan.
@petrpan26 petrpan26 force-pushed the feature/blelloch-clean branch from 890279c to f84f1d4 Compare November 4, 2025 16:40
@petrpan26
Copy link
Author

I was try running on smaller GPU but i cant so i tune the params a bit as well

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants