Commit c5cd122

Author: Hoang Phan
Add Blelloch parallel prefix scan for LASP
This PR implements Blelloch parallel prefix scan to reduce inter-GPU communication from O(P) sequential steps (ring) to O(log P) parallel steps (tree-based).

Key improvements:
- O(log P) communication complexity (e.g., 128 GPUs: 128 steps → 14 steps)
- Work-efficient tree-based algorithm
- Supports non-power-of-2 GPU counts
- Reuses KV/DKV buffers to avoid allocation overhead

Implementation details:

1. **BlellochScanner** (lasp/utils/blelloch_ops.py):
   - Tree-based up-sweep and down-sweep communication
   - Correct sender/receiver logic using "right edge" of subtrees
   - Distance-based decay in down-sweep for proper accumulation
   - Support for reverse scan (suffix) for backward pass
   - Global rank conversion for multi-group data parallelism

2. **lasp_blelloch** (lasp/lasp_blelloch.py):
   - Combines Blelloch scan with fused Triton kernels
   - Correct inclusive-to-exclusive conversion: λ^(-C) * (inclusive - local)
   - Buffer reuse pattern matching lasp_fuse_parallel
   - Forward: prefix scan, Backward: suffix scan

3. **Tests and benchmarks**:
   - test_blelloch_correctness.py: Gradient correctness tests
   - test_non_power_of_two.py: Non-power-of-2 world sizes
   - benchmark_blelloch.py: Performance benchmarks
   - benchmark_all_methods.py: Comprehensive comparison

Tested with:
- Single GPU and multi-GPU (4-8 GPUs)
- Data parallelism (dp_size > 1) with sequence parallelism
- Power-of-2 and non-power-of-2 world sizes
- Forward and backward pass correctness
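To make the up-sweep/down-sweep structure and the step count concrete, here is a minimal single-process sketch of a Blelloch exclusive scan over the chunk recurrence `state_next = decay * state + local`, where `decay` stands in for λ^C. It is not the code in lasp/utils/blelloch_ops.py: the names `combine`, `blelloch_exclusive_scan`, `inclusive_to_exclusive`, and `sequential_exclusive_states` are hypothetical, ranks are simulated as list positions with no `torch.distributed` calls, values are scalars rather than KV tensors, the decay is carried inside each scan element rather than applied by distance, and non-power-of-2 sizes are handled by identity padding, which is an assumption rather than the commit's approach.

```python
from typing import List, Tuple

# Each scan element (a, b) represents the affine update  h -> a * h + b.
Elem = Tuple[float, float]
IDENTITY: Elem = (1.0, 0.0)


def combine(earlier: Elem, later: Elem) -> Elem:
    """Compose two updates; `earlier` is applied first (order matters)."""
    a1, b1 = earlier
    a2, b2 = later
    return (a1 * a2, a2 * b1 + b2)


def blelloch_exclusive_scan(items: List[Elem]) -> Tuple[List[Elem], int]:
    """Return (exclusive prefix per position, number of tree levels visited)."""
    n = len(items)
    size = 1
    while size < n:
        size *= 2
    # Assumption: pad with identities up to the next power of two.
    x = list(items) + [IDENTITY] * (size - n)
    levels = 0

    # Up-sweep: the right edge of each subtree accumulates the subtree total.
    stride = 1
    while stride < size:
        for right in range(2 * stride - 1, size, 2 * stride):
            x[right] = combine(x[right - stride], x[right])
        stride *= 2
        levels += 1

    # Down-sweep: push "everything to my left" back down the tree.
    x[size - 1] = IDENTITY
    stride = size // 2
    while stride >= 1:
        for right in range(2 * stride - 1, size, 2 * stride):
            left = right - stride
            left_total = x[left]
            x[left] = x[right]                        # inherit parent's prefix
            x[right] = combine(x[right], left_total)  # prefix plus left subtree
        stride //= 2
        levels += 1

    return x[:n], levels


def inclusive_to_exclusive(inclusive: float, local: float, decay: float) -> float:
    """Undo one recurrence step: decay^(-1) * (inclusive - local)."""
    return (inclusive - local) / decay


def sequential_exclusive_states(decay: float, local: List[float]) -> List[float]:
    """Ring-style reference: one sequential step per position."""
    state, out = 0.0, []
    for value in local:
        out.append(state)
        state = decay * state + value
    return out
```

Calling `blelloch_exclusive_scan([(decay, v) for v in local])` and taking the second component of each result gives the state each simulated rank would receive from all earlier ranks, starting from a zero state. The returned level count is 2·log2 of the padded size, which for 128 positions is the 14 steps quoted in the commit message.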
1 parent 13f320a commit c5cd122

10 files changed: +2113 -11 lines changed


lasp/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -2,5 +2,6 @@
 from .lasp_fuse import *
 from .lasp_fuse_parallel import *
 from .lasp_naive import *
+from .lasp_blelloch import *
 from .lightning_attention import *
 from .utils import *
