Skip to content

Commit f80260d

Browse files
mawad-amdclaude
andcommitted
Add CUDA graph capture unit test for gluon all-reduce
Part A: single capture with multiple replays — catches barrier flag bugs and pointer table corruption. Part B: piecewise capture with 3 different tensor sizes sharing one workspace — catches data_ptr reuse bugs across captures (the vLLM pattern that crashed on 2nd decode step). Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
1 parent 97fb8e2 commit f80260d

1 file changed

Lines changed: 208 additions & 0 deletions

File tree

Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
# SPDX-License-Identifier: MIT
2+
# Copyright (c) 2026 Advanced Micro Devices, Inc. All rights reserved.
3+
4+
"""Minimal CUDA graph capture test for iris gluon all-reduce.
5+
6+
Isolates graph capture / replay without vLLM or aiter.
7+
Run with: torchrun --nproc_per_node=N python tests/test_graph_capture_allreduce.py
8+
9+
Part A — single capture:
10+
1. Eager correctness (baseline)
11+
2. Graph capture succeeds (no non-capturable ops)
12+
3. Single replay correctness
13+
4. Double replay correctness (2nd decode step crash repro)
14+
5. Replay with new input data (pointer table validity)
15+
16+
Part B — piecewise capture (vLLM pattern):
17+
6. Three separate graphs with different tensor sizes, shared workspace
18+
7. Interleaved replay of all three graphs
19+
8. Catches data_ptr reuse bugs across captures
20+
"""
21+
22+
import os
23+
import sys
24+
import torch
25+
import torch.distributed as dist
26+
27+
28+
def check(name, actual, expected_val, shape, rank):
29+
expected = torch.full(shape, expected_val, device="cuda", dtype=torch.float32)
30+
if torch.allclose(actual.float(), expected, rtol=1e-2, atol=1e-2):
31+
if rank == 0:
32+
print(f"PASS: {name}")
33+
return True
34+
else:
35+
print(f"FAIL: {name} rank={rank} got={actual.view(-1)[0].item():.4f} expected={expected_val:.4f}")
36+
return False
37+
38+
39+
def main():
40+
local_rank = int(os.environ.get("LOCAL_RANK", 0))
41+
torch.cuda.set_device(local_rank)
42+
dist.init_process_group(backend="nccl")
43+
44+
world_size = dist.get_world_size()
45+
rank = dist.get_rank()
46+
47+
import iris
48+
from iris.ccl.config import Config
49+
50+
ctx = iris.iris(heap_size=2 ** 30)
51+
cfg = Config(use_gluon=True)
52+
53+
dtype = torch.bfloat16
54+
passed = 0
55+
total = 0
56+
57+
# =========================================
58+
# Part A: single capture, multiple replays
59+
# =========================================
60+
shape = (2, 8192)
61+
62+
# Test 1: eager correctness
63+
total += 1
64+
inp = ctx.empty(shape, dtype=dtype)
65+
inp.fill_(rank + 1.0)
66+
out = ctx.empty(shape, dtype=dtype)
67+
68+
ws = ctx.ccl.all_reduce(out, inp, config=cfg)
69+
torch.cuda.synchronize()
70+
71+
expected = sum(r + 1.0 for r in range(world_size))
72+
if check("eager correctness", out, expected, shape, rank):
73+
passed += 1
74+
75+
# Test 2-5: graph capture + replay
76+
graph_out = ctx.empty(shape, dtype=dtype)
77+
78+
stream = torch.cuda.Stream()
79+
torch.cuda.synchronize()
80+
dist.barrier()
81+
82+
# warmup in capture stream
83+
with torch.cuda.stream(stream):
84+
ws = ctx.ccl.all_reduce(graph_out, inp, config=cfg, workspace=ws)
85+
torch.cuda.synchronize()
86+
dist.barrier()
87+
88+
# capture
89+
graph = torch.cuda.CUDAGraph()
90+
with torch.cuda.stream(stream):
91+
with torch.cuda.graph(graph, stream=stream):
92+
ws = ctx.ccl.all_reduce(graph_out, inp, config=cfg, workspace=ws)
93+
94+
total += 1
95+
if rank == 0:
96+
print("PASS: graph capture succeeded")
97+
passed += 1
98+
99+
# Test 3: single replay
100+
total += 1
101+
inp.fill_(rank + 1.0)
102+
graph.replay()
103+
torch.cuda.synchronize()
104+
if check("single replay", graph_out, expected, shape, rank):
105+
passed += 1
106+
107+
# Test 4: double replay
108+
total += 1
109+
inp.fill_(rank + 1.0)
110+
graph.replay()
111+
graph.replay()
112+
torch.cuda.synchronize()
113+
if check("double replay", graph_out, expected, shape, rank):
114+
passed += 1
115+
116+
# Test 5: replay with new data
117+
total += 1
118+
inp.fill_((rank + 1.0) * 2)
119+
graph.replay()
120+
torch.cuda.synchronize()
121+
expected2 = sum((r + 1.0) * 2 for r in range(world_size))
122+
if check("replay new data", graph_out, expected2, shape, rank):
123+
passed += 1
124+
125+
# =========================================
126+
# Part B: piecewise capture (vLLM pattern)
127+
# 3 graphs with different sizes, shared workspace
128+
# =========================================
129+
if rank == 0:
130+
print("\n--- Part B: piecewise capture ---")
131+
132+
shapes = [(1, 8192), (4, 8192), (2, 8192)]
133+
graphs = []
134+
inputs = []
135+
outputs = []
136+
ws_piece = None
137+
138+
for i, s in enumerate(shapes):
139+
inp_i = ctx.empty(s, dtype=dtype)
140+
out_i = ctx.empty(s, dtype=dtype)
141+
inp_i.fill_(rank + 1.0)
142+
inputs.append(inp_i)
143+
outputs.append(out_i)
144+
145+
# warmup
146+
st = torch.cuda.Stream()
147+
with torch.cuda.stream(st):
148+
ws_piece = ctx.ccl.all_reduce(out_i, inp_i, config=cfg, workspace=ws_piece)
149+
torch.cuda.synchronize()
150+
dist.barrier()
151+
152+
for i, s in enumerate(shapes):
153+
g = torch.cuda.CUDAGraph()
154+
st = torch.cuda.Stream()
155+
with torch.cuda.stream(st):
156+
with torch.cuda.graph(g, stream=st):
157+
ws_piece = ctx.ccl.all_reduce(outputs[i], inputs[i], config=cfg, workspace=ws_piece)
158+
graphs.append(g)
159+
160+
total += 1
161+
if rank == 0:
162+
print(f"PASS: piecewise capture ({len(shapes)} graphs)")
163+
passed += 1
164+
165+
# Test 7: replay each graph
166+
for i, (g, s) in enumerate(zip(graphs, shapes)):
167+
total += 1
168+
inputs[i].fill_(rank + 1.0)
169+
g.replay()
170+
torch.cuda.synchronize()
171+
if check(f"piecewise replay graph[{i}] shape={s}", outputs[i], expected, s, rank):
172+
passed += 1
173+
174+
# Test 8: interleaved replay (catches cross-capture corruption)
175+
total += 1
176+
for inp_i in inputs:
177+
inp_i.fill_(rank + 1.0)
178+
graphs[2].replay()
179+
graphs[0].replay()
180+
graphs[1].replay()
181+
torch.cuda.synchronize()
182+
all_ok = all(
183+
torch.allclose(outputs[i].float(), torch.full(shapes[i], expected, device="cuda"), rtol=1e-2, atol=1e-2)
184+
for i in range(len(shapes))
185+
)
186+
if all_ok:
187+
if rank == 0:
188+
print("PASS: interleaved replay correctness")
189+
passed += 1
190+
else:
191+
for i in range(len(shapes)):
192+
if not torch.allclose(outputs[i].float(), torch.full(shapes[i], expected, device="cuda"), rtol=1e-2, atol=1e-2):
193+
print(f"FAIL: interleaved replay graph[{i}] rank={rank} got={outputs[i].view(-1)[0].item():.4f}")
194+
195+
# Summary
196+
if rank == 0:
197+
print(f"\n{passed}/{total} tests passed")
198+
if passed == total:
199+
print("ALL TESTS PASSED")
200+
else:
201+
print("SOME TESTS FAILED")
202+
sys.exit(0 if passed == total else 1)
203+
204+
dist.destroy_process_group()
205+
206+
207+
if __name__ == "__main__":
208+
main()

0 commit comments

Comments
 (0)