Commit 97762cc

[cp] set up load balancing testbed
ghstack-source-id: bbeed6b
Pull Request resolved: #120
1 parent af82ef0 commit 97762cc

3 files changed: +211 -0 lines changed

attn_gym/load_balance/__init__.py

Lines changed: 3 additions & 0 deletions

from attn_gym.load_balance.load_balancer import load_balance_algo

__all__ = ["load_balance_algo"]
attn_gym/load_balance/load_balancer.py

Lines changed: 13 additions & 0 deletions
from typing import List

import torch


__all__ = ["load_balance_algo"]


def load_balance_algo(S: int, size: int, block_size: int) -> torch.Tensor:
    # Placeholder policy: split the S // block_size mask blocks into `size`
    # contiguous, equally sized groups, one per context-parallel rank.
    total_num_blk = S // block_size
    assert S % (size * total_num_blk) == 0
    local_num_blk = total_num_blk // size
    return torch.arange(total_num_blk, device="cuda").view(size, local_num_blk)
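As a quick illustration of what this placeholder returns, here is a minimal sketch; the sizes are made up for the example, and a CUDA device is assumed since the helper allocates the assignment on "cuda":

from attn_gym.load_balance import load_balance_algo

# Hypothetical sizes: a 1024-token sequence, 4 context-parallel ranks, and
# 128-token mask blocks -> 8 blocks in total, 2 per rank.
assignment = load_balance_algo(S=1024, size=4, block_size=128)
print(assignment)
# tensor([[0, 1],
#         [2, 3],
#         [4, 5],
#         [6, 7]], device='cuda:0')

Each row holds the global block indices owned by one rank; with torch.arange this is just a contiguous split, so no sparsity-aware balancing happens yet.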

examples/distributed_benchmark.py

Lines changed: 195 additions & 0 deletions
from functools import lru_cache
from typing import Optional, List

import os
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, DTensor, DeviceMesh, Partial, Replicate, Shard


from torch.nn.attention.flex_attention import (
    _DEFAULT_SPARSE_BLOCK_SIZE,
    create_block_mask,
    flex_attention,
    _mask_mod_signature,
)

from attn_gym.masks.document_mask import length_to_offsets
from attn_gym.masks import (
    causal_mask,
    generate_doc_mask_mod,
)
from attn_gym.load_balance import load_balance_algo


def get_device_type() -> str:
    return "cuda"


@lru_cache
def create_block_mask_cached(score_mod, B, H, M, N, device="cuda"):
    block_mask = create_block_mask(score_mod, B, H, M, N, device=device)
    return block_mask


# TODO: re-write it into a wrapper???
# Wrap a mask_mod defined over global q indices so it can be evaluated with
# this rank's local q indices, using the block assignment produced by the
# load balancer.
def rewrite_mask_mod_for_cp(
    mask_mod: _mask_mod_signature,
    rank: int,
    block_size: int,
    load_balancer_output: torch.Tensor,
) -> _mask_mod_signature:
    def local_q_idx_to_q_idx(local_q_idx) -> int:
        # calculate local block_idx and block_offset
        local_blk_idx, local_blk_offset = (
            local_q_idx // block_size, local_q_idx % block_size
        )
        current_rank_blk_list = load_balancer_output[rank]
        blk_idx = current_rank_blk_list[local_blk_idx]
        return blk_idx * block_size + local_blk_offset

    return lambda b, h, q_idx, kv_idx: mask_mod(
        b, h, local_q_idx_to_q_idx(q_idx), kv_idx
    )
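# Worked example for the remapping above (illustrative numbers only): with
# block_size=128 and a placeholder assignment of
# tensor([[0, 1], [2, 3], [4, 5], [6, 7]]), rank 1 owns global blocks 2 and 3.
# Its local q index 130 is local block 1, offset 2, which maps to global
# block 3, offset 2, i.e. q_idx = 3 * 128 + 2 = 386; so the rewritten mask_mod
# evaluates the original mask at position 386 rather than 130.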
def run_document_masking(device_mesh, max_seq_len, num_docs):
    # initialize the document lengths
    import random

    random.seed(0)
    torch.cuda.manual_seed(0)

    def generate_random_lengths(total_length, num_documents):
        # Initialize all lengths to 1 to ensure each document has at least one token
        lengths = [1] * num_documents
        remaining_length = total_length - num_documents

        # Randomly distribute the remaining length
        for _ in range(remaining_length):
            index = random.randint(0, num_documents - 1)
            lengths[index] += 1

        return lengths

    lengths = generate_random_lengths(max_seq_len, num_docs)
    offsets = length_to_offsets(lengths, torch.device(f'cuda:{torch.cuda.current_device():d}'))  # TODO: replace with a device mesh call
    document_causal_mask = generate_doc_mask_mod(causal_mask, offsets)
    test_mask_with_load_balance(device_mesh, mask_mod=document_causal_mask, S=max_seq_len)


def test_mask_with_load_balance(
    device_mesh: DeviceMesh,
    mask_mod: Optional[_mask_mod_signature] = None,
    B: int = 16,
    H: int = 16,
    S: int = 8192,
    D: int = 64,
    skip_correctness: bool = False,
    print_mask: bool = True,
    device: str = "cuda",
):
    data_type = torch.float16

    # create block mask
    block_mask = create_block_mask_cached(mask_mod, 1, 1, S, S, device=device)
    block_size = _DEFAULT_SPARSE_BLOCK_SIZE  # TODO: get block size from block mask

    # input initialization
    qkv = [
        torch.rand(
            (B, H, S, D),
            device=device_mesh.device_type,
            dtype=data_type,
            requires_grad=True,
        )
        for _ in range(3)
    ]

    # TODO: input sharding with load-balancing
    # sparsity_info = get_sparsity_info_from_block_mask(block_mask)
    # load_balancer_output = load_balance_algo(sparsity_info)
    cp_mesh_size = device_mesh.size()
    load_balancer_output = load_balance_algo(S, cp_mesh_size, block_size)

    # shard q along the sequence dim across the CP mesh; replicate k and v
    seq_dim = 2
    qkv_dist = [
        distribute_tensor(
            t.detach().clone().requires_grad_(), device_mesh, [
                Shard(seq_dim) if i == 0 else Replicate()
            ]
        )
        for (i, t) in enumerate(qkv)
    ]

    q_local, k_full, v_full = (dt.to_local() for dt in qkv_dist)

    # rewrite `block_mask` so each rank attends with its local queries
    # (S // cp_mesh_size rows) against the full KV sequence
    mask_mod: _mask_mod_signature = block_mask.mask_mod
    cp_rank = device_mesh.get_local_rank()
    cp_mask_mod = rewrite_mask_mod_for_cp(
        mask_mod, cp_rank, block_size, load_balancer_output
    )
    cp_block_mask = create_block_mask_cached(
        cp_mask_mod, B=1, H=1, M=S // cp_mesh_size, N=S, device=device
    )

    # Compile the flex_attention function
    compiled_flex_attention = torch.compile(flex_attention, dynamic=False)

    # TODO: this doesn't address the return_lse=True case
    cp_out = compiled_flex_attention(
        q_local,
        k_full,
        v_full,
        score_mod=None,
        block_mask=cp_block_mask,
    )
    assert isinstance(cp_out, torch.Tensor)

    # unshard
    cp_out_dist = DTensor.from_local(cp_out, device_mesh, [Shard(seq_dim)])
    full_cp_out_dist = cp_out_dist.full_tensor()
    # rearrange: put each output block back at its original position in the
    # sequence (a no-op for the contiguous placeholder assignment)
    blk_idx_to_origin = load_balancer_output.view(-1)
    num_chunks = blk_idx_to_origin.numel()
    blk_list_rearranged = [None] * num_chunks
    blk_list = torch.chunk(full_cp_out_dist, num_chunks, dim=seq_dim)
    assert len(blk_list) == num_chunks
    for blk_idx, blk in enumerate(blk_list):
        blk_list_rearranged[blk_idx_to_origin[blk_idx].item()] = blk

    full_cp_out_dist = torch.cat(blk_list_rearranged, dim=seq_dim)

    # local flex attention
    expect_out = flex_attention(*qkv, block_mask=block_mask)
    torch.testing.assert_close(full_cp_out_dist, expect_out, atol=1e-1, rtol=1e-2)


def load_balancing_example(world_size: int, rank: int) -> None:
    device_type = get_device_type()
    device_handle = getattr(torch, device_type, None)
    assert device_handle is not None, f"Unsupported device type: {device_type}"
    num_devices_per_host = device_handle.device_count()
    device_handle.set_device(rank % num_devices_per_host)
    torch._dynamo.config.cache_size_limit = 1000

    # init device mesh
    device_mesh = init_device_mesh(device_type=device_type, mesh_shape=(world_size,))

    run_document_masking(device_mesh, max_seq_len=4096, num_docs=12)


if __name__ == "__main__":
    # this script is launched via torchrun which automatically manages ProcessGroup
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    # assert world_size == 4  # our example uses 4 worker ranks

    try:
        load_balancing_example(world_size, rank)
    finally:
        dist.barrier()
        dist.destroy_process_group()
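To try the benchmark, launch the script with torchrun so that RANK and WORLD_SIZE are set, for example: torchrun --nproc_per_node=4 examples/distributed_benchmark.py. The world size of 4 is only an assumption matching the commented-out assert; any world size that evenly divides the number of mask blocks should work with the placeholder balancer.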
