
Commit 18ccbec

committed
Passes to split the all_gather prologue and the reduce_scatter epilogue from the FSDP graph
stack-info: PR: #201, branch: IvanKobzarev/stack/9
1 parent ea2a7d6 commit 18ccbec
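
The two passes added here compose: split_fsdp_prefetch cuts the graph right after the all_gather wait_tensors, and split_fsdp_reduce_scatters_epilogue cuts it just before the reduce_scatters, giving prologue / main / epilogue subgraphs. A minimal sketch of the wiring, mirroring tests/test_pipeline_passes.py below; the names g, gm, and inputs are assumed to come from an AutoParallel trace as in that test:

from torch.fx import GraphModule
from autoparallel.pipeline.passes import (
    split_fsdp_prefetch,
    split_fsdp_reduce_scatters_epilogue,
)

# g: torch.fx.Graph produced by AutoParallel; gm: its owning GraphModule; inputs: real tensors
g_pro, g_main = split_fsdp_prefetch(g)                       # cut after the all_gather waits
g_main, g_epi = split_fsdp_reduce_scatters_epilogue(g_main)  # reduce_scatters land in the epilogue

gm_pro = GraphModule(gm, g_pro)
gm_main = GraphModule(gm, g_main)
gm_epi = GraphModule(gm, g_epi)

out = gm_epi(*gm_main(*gm_pro(*inputs)))  # same chaining the test exercises alongside gm(*inputs)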

File tree

2 files changed: +247, -0 lines changed


autoparallel/pipeline/passes.py

Lines changed: 147 additions & 0 deletions
@@ -0,0 +1,147 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import dataclasses
from contextlib import contextmanager
from functools import partial
from typing import Any

import torch
import torch.fx.node
import torch.utils._pytree as pytree
from torch._functorch._aot_autograd.descriptors import AOTOutput
from torch._functorch.partitioners import _extract_graph_with_inputs_outputs
from torch._inductor.fx_passes.bucketing import (
    is_all_gather_into_tensor,
    is_reduce_scatter_tensor,
)


@contextmanager
def exclude_from_fx_side_effectful(exclude_vals: set[Any]):
    # Temporarily drop the given callables from FX's side-effectful set so graph
    # extraction / DCE can treat them as ordinary, removable nodes; restore on exit.
    original_val = torch.fx.node._side_effectful_functions.copy()
    try:
        torch.fx.node._side_effectful_functions -= exclude_vals
        yield
    finally:
        torch.fx.node._side_effectful_functions.clear()
        torch.fx.node._side_effectful_functions.update(original_val)


exclude_wait_from_fx_side_effectful = partial(
    exclude_from_fx_side_effectful,
    {
        torch.ops._c10d_functional.wait_tensor,
        torch.ops._c10d_functional.wait_tensor.default,
    },
)


@dataclasses.dataclass(frozen=True)
class PrefetchOutput(AOTOutput):
    pass


@dataclasses.dataclass(frozen=True)
class EpilogueInput(AOTOutput):
    pass


def split_fsdp_prefetch(
    g: torch.fx.Graph, stop_at_all_gather: bool = True
) -> tuple[torch.fx.Graph, torch.fx.Graph]:
    g_ins = g.find_nodes(op="placeholder")
    prefetch_g_outs_map = []

    for g_in in g_ins:
        # Walk the straight-line (single-user, single-input) chain hanging off each
        # placeholder; stop after the all_gather's wait_tensor when one is found.
        n = g_in
        has_ag = False
        while True:
            if len(n.users) != 1:
                break
            user = next(iter(n.users))
            if len(user.all_input_nodes) > 1:
                break
            n = user
            if stop_at_all_gather and is_all_gather_into_tensor(n):
                has_ag = True
                w_n = next(iter(n.users))
                n = w_n
                break
        if stop_at_all_gather and not has_ag:
            prefetch_g_outs_map.append(g_in)
        else:
            prefetch_g_outs_map.append(n)

    prefetch_g_outs = prefetch_g_outs_map
    prefetch_g_outs_descs: list[AOTOutput] = [
        PrefetchOutput() for _ in range(len(prefetch_g_outs))
    ]
    g_outs = pytree.arg_tree_leaves(*(n.args for n in g.find_nodes(op="output")))
    g_outs_descs = pytree.arg_tree_leaves(
        next(iter(g.find_nodes(op="output"))).meta.get("desc", [None] * len(g_outs))
    )
    with exclude_wait_from_fx_side_effectful():
        prefetch_g = _extract_graph_with_inputs_outputs(
            g,
            g_ins,
            prefetch_g_outs,
            prefetch_g_outs_descs,
        )

        main_g = _extract_graph_with_inputs_outputs(
            g,
            prefetch_g_outs,
            g_outs,
            g_outs_descs,
        )
    return prefetch_g, main_g


def split_fsdp_reduce_scatters_epilogue(
    g: torch.fx.Graph,
) -> tuple[torch.fx.Graph, torch.fx.Graph]:
    g_ins = g.find_nodes(op="placeholder")
    g_outs = pytree.arg_tree_leaves(*(n.args for n in g.find_nodes(op="output")))
    g_outs_descs = pytree.arg_tree_leaves(
        next(iter(g.find_nodes(op="output"))).meta.get("desc", [None] * len(g_outs))
    )

    g_outs_map = []
    for g_out in g_outs:
        # Walk upwards through the straight-line chain feeding each output; when a
        # reduce_scatter is crossed, cut just below it so it lands in the epilogue.
        n = g_out
        has_rs = False
        while n is not None:
            if len(n.all_input_nodes) != 1:
                break
            n_in = n.all_input_nodes[0]
            if len(n_in.users) > 1:
                break
            prev_n = n
            n = n_in
            if is_reduce_scatter_tensor(prev_n):
                has_rs = True
                break
        if has_rs:
            g_outs_map.append(n)
        else:
            g_outs_map.append(g_out)

    epi_g_ins = [n for n in g_outs_map if n is not None]
    epi_g_ins_descs: list[AOTOutput] = [EpilogueInput() for _ in range(len(epi_g_ins))]
    main_g = _extract_graph_with_inputs_outputs(
        g,
        g_ins,
        epi_g_ins,
        epi_g_ins_descs,
    )
    epi_g = _extract_graph_with_inputs_outputs(
        g,
        epi_g_ins,
        g_outs,
        g_outs_descs,
    )

    return main_g, epi_g
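
For intuition, a self-contained sketch of the kind of straight-line chain that split_fsdp_prefetch walks from each placeholder; the node names, dtype, group size, and group name below are made up for illustration, not taken from this diff. The walk stops after the all_gather's wait_tensor, which becomes a prologue output:

import torch
from torch.fx import Graph

g = Graph()
param = g.placeholder("sharded_param")
# single-user, single-input chain off the placeholder (e.g. a dtype cast)
casted = g.call_function(torch.ops.aten._to_copy.default, (param,), {"dtype": torch.bfloat16})
# the collective the pass stops at, followed by its wait
gathered = g.call_function(
    torch.ops._c10d_functional.all_gather_into_tensor.default,
    (casted, 8, "0"),  # (shard, group_size, group_name) -- illustrative values
)
ready = g.call_function(torch.ops._c10d_functional.wait_tensor.default, (gathered,))
g.output((ready,))
print(g)  # the prologue boundary for this input sits right after the wait_tensor node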

tests/test_pipeline_passes.py

Lines changed: 100 additions & 0 deletions
@@ -0,0 +1,100 @@
from unittest.mock import patch

import pytest
import torch
from torch import nn
from torch.fx import GraphModule
from torch.testing._internal.distributed.fake_pg import FakeStore

from autoparallel.api import AutoParallel
from autoparallel.pipeline.passes import (
    split_fsdp_prefetch,
    split_fsdp_reduce_scatters_epilogue,
)


@pytest.fixture(scope="module", autouse=True)
def init_pg():
    world_size = 256
    fake_store = FakeStore()
    if torch.distributed.is_initialized():
        return
    torch.distributed.init_process_group(
        "fake", store=fake_store, rank=0, world_size=world_size
    )


@pytest.fixture(scope="module")
def device_mesh_2d():
    world_size = torch.distributed.get_world_size()
    mesh = torch.distributed.device_mesh.init_device_mesh(
        "cuda",
        (world_size // 8, 8),
        mesh_dim_names=(
            "dp",
            "tp",
        ),
    )
    return mesh


class FFN(nn.Module):
    def __init__(self, dim1, dim2):
        super().__init__()
        bias = False
        self.linear1 = nn.Linear(dim1, dim2, bias=bias)
        self.linear2 = nn.Linear(dim2, dim1, bias=bias)

    def forward(self, x, y):
        return y + 2, self.linear2(self.linear1(x)), y + 2


def _make_model_and_input_fn(mesh, device="cuda"):
    bs = 2048 * mesh.shape[0]
    dim1 = 1024
    dim2 = 4096

    def model_fn():
        return FFN(dim1, dim2)

    def input_fn():
        return torch.randn(bs, dim1).to(device), torch.randn(bs, 1).to(device)

    return model_fn, input_fn


@patch("torch.cuda.device_count", lambda: 8)
@patch("torch.cuda.get_device_name", lambda device: "H100")
def test_fsdp_split_passes(device_mesh_2d):
    low_mem = 0
    high_mem = None
    model_fn, input_fn = _make_model_and_input_fn(device_mesh_2d)
    with torch.device("meta"):
        model = model_fn()

    with AutoParallel(model, input_fn, device_mesh_2d) as autop:
        autop.add_parameter_memory_constraint(low=low_mem, high=high_mem)
        sharding_placement = autop.optimize_placement()
        autop.apply_placement(sharding_placement)
        gm = autop.parallel_gm
        g = gm.graph

    def gen_g_inputs(g):
        phs = g.find_nodes(op="placeholder")
        ret = []
        for ph in phs:
            ft = ph.meta["val"]
            t = torch.randn(ft.shape, dtype=ft.dtype, device=ft.device)
            ret.append(t)
        return ret

    inputs = gen_g_inputs(g)
    g_pro, g_main = split_fsdp_prefetch(g)
    g_main, g_epi = split_fsdp_reduce_scatters_epilogue(g_main)

    gm_pro = GraphModule(gm, g_pro)
    gm_main = GraphModule(gm, g_main)
    gm_epi = GraphModule(gm, g_epi)

    gm(*inputs)
    gm_epi(*gm_main(*gm_pro(*inputs)))
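
A hypothetical follow-up, not part of this diff, that can help when debugging the split: tally where the functional collectives ended up in each subgraph. The snippet assumes it runs inside the test above, where g_pro, g_main, and g_epi are in scope:

def collective_counts(graph):
    # count call_function nodes by op target; placeholder/output targets are strings and never match
    ag = torch.ops._c10d_functional.all_gather_into_tensor.default
    rs = torch.ops._c10d_functional.reduce_scatter_tensor.default
    return {
        "all_gather": sum(n.target is ag for n in graph.nodes),
        "reduce_scatter": sum(n.target is rs for n in graph.nodes),
    }

for name, graph in (("pro", g_pro), ("main", g_main), ("epi", g_epi)):
    print(name, collective_counts(graph))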
