This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 7f0d6bb

weifengpy authored and facebook-github-bot committed
add unit tests for FSDP2 + torch.compile(transformer block) (#321)
Summary: TorchTitan complains about FSDP2 + float8 + torch.compile(transformer block): there is a mismatch in the float8 scale, so the dynamo guard assertion `torch._C._dynamo.guards.assert_size_stride(new_inputs[3], (), ())` fails.

* In the 1st iteration, we calculate the float8 scale through `cast_to_float8_e4m3_dynamic` ([code](https://github.com/pytorch-labs/float8_experimental/blob/main/float8_experimental/fsdp_utils.py#L172)). The scale is a scalar tensor, e.g. `tensor(4674.8633)`.
* In the 2nd iteration, we calculate the float8 scale through `precompute_float8_dynamic_scale`, but the scale is NOT a scalar tensor, e.g. `tensor([[4674.8633]])`.
* This PR calls `.squeeze()` to make sure scales are always scalar tensors, so the dynamo guard assertion always holds true.

Also added a unit test so we can catch the issue at PR time.

TODO: add fp8 + torch.compile to CI in torchtitan

Pull Request resolved: #321

Reviewed By: vkuzo

Differential Revision: D59892261

Pulled By: weifengpy

fbshipit-source-id: 6f9f5a4e2de06c347403f4c7c82b3978f37ff9eb
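For readers following along, here is a minimal, self-contained sketch of the shape mismatch behind the failing guard. The scale value is illustrative, not taken from a real run; the point is that a 0-dim tensor has size `()` and stride `()`, which is what `assert_size_stride(..., (), ())` checks, while a `[1, 1]`-shaped scale does not, and `.squeeze()` reconciles the two.

```python
import torch

# Iteration 1: the dynamic cast path produces a scalar (0-dim) scale,
# e.g. tensor(4674.8633). Its size and stride are both ().
scale_iter1 = torch.tensor(4674.8633)
assert scale_iter1.size() == () and scale_iter1.stride() == ()

# Iteration 2 (before this fix): the precomputed path yields a [1, 1]-shaped
# scale, e.g. tensor([[4674.8633]]), so the size/stride guard recorded in
# iteration 1 no longer holds and dynamo's assertion fails.
scale_iter2 = torch.tensor([[4674.8633]])
assert scale_iter2.size() == (1, 1)

# The fix: .squeeze() drops the singleton dims, so both code paths hand
# dynamo a scalar tensor and the guard keeps passing.
assert scale_iter2.squeeze().size() == ()
```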
1 parent ec8b46c commit 7f0d6bb

6 files changed, +24 -11 lines changed


README.md

Lines changed: 1 addition & 1 deletion
```diff
@@ -137,7 +137,7 @@ pytest test/test_numerics_integration.py
 ./test/test_dtensor.sh

 # run integration tests on the FSDP2 integration
-python test/test_fsdp2/test_fsdp2_eager.py
+python test/test_fsdp2/test_fsdp2.py

 # run all of these tests
 ./test/test_everything.sh
```

float8_experimental/float8_dynamic_utils.py

Lines changed: 0 additions & 2 deletions
```diff
@@ -4,8 +4,6 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.

-from typing import Any, Optional, Tuple
-
 import torch

 from float8_experimental.float8_tensor import (
```

float8_experimental/fsdp_utils.py

Lines changed: 4 additions & 2 deletions
```diff
@@ -64,7 +64,9 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
     scale_tensor = torch.clamp(scale_tensor, max=torch.finfo(torch.float16).max)
     scales = torch.split(scale_tensor, 1)  # Replicate
     for scale, float8_linear in zip(scales, float8_linears):
-        float8_linear.weight._local_tensor._precomputed_scale = scale._local_tensor
+        float8_linear.weight._local_tensor._precomputed_scale = (
+            scale._local_tensor.squeeze()
+        )


 # FSDP pads its local tensor on dim-0. The subclass should be preserved such
@@ -301,7 +303,7 @@ def __tensor_flatten__(self):
             ],
             {
                 "mm_config": self._mm_config,
-                "is_amax_initialized": is_amax_initialized,
+                "is_amax_initialized": self.is_amax_initialized,
             },
         )
```
test/test_everything.sh

Lines changed: 1 addition & 1 deletion
```diff
@@ -15,7 +15,7 @@ then
 ./test/test_fsdp.sh
 ./test/test_fsdp_compile.sh
 ./test/test_dtensor.sh
-pytest test/test_fsdp2/test_fsdp2_eager.py
+pytest test/test_fsdp2/test_fsdp2.py
 fi

 echo "all tests successful"
```

test/test_fsdp2/test_fsdp2_eager.py renamed to test/test_fsdp2/test_fsdp2.py

Lines changed: 12 additions & 3 deletions
```diff
@@ -89,6 +89,7 @@ def test_transformer_parity(self):
                     TensorScalingType.DYNAMIC,
                     TensorScalingType.DELAYED,
                 ],
+                "compile_transformer_block": [False, True],
             },
             self._test_transformer_parity,
         )
@@ -98,6 +99,7 @@ def _test_transformer_parity(
         enable_fsdp_fp8_all_gather: bool,
         precompute: bool,
         scaling_type_w: TensorScalingType,
+        compile_transformer_block: bool,
     ):
         if not enable_fsdp_fp8_all_gather and precompute:
             return
@@ -112,11 +114,17 @@ def _test_transformer_parity(
         module = self.init_transformer(weight_tying=weight_tying).cuda()
         ref_module = copy.deepcopy(module)
         swap_linear_with_float8_linear(ref_module, scaling_type_w=scaling_type_w)
+        if compile_transformer_block:
+            for layer_id, transformer_block in ref_module.layers.named_children():
+                transformer_block = torch.compile(transformer_block, dynamic=False)
+                ref_module.layers.register_module(layer_id, transformer_block)
         with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
             swap_linear_with_float8_linear(module, scaling_type_w=scaling_type_w)
-            for submodule in module.modules():
-                if isinstance(submodule, TransformerBlock):
-                    fully_shard(submodule)
+            for layer_id, transformer_block in module.layers.named_children():
+                if compile_transformer_block:
+                    transformer_block = torch.compile(transformer_block, dynamic=False)
+                fully_shard(transformer_block)
+                module.layers.register_module(layer_id, transformer_block)
         fully_shard(module)
         ref_optim = torch.optim.Adam(ref_module.parameters(), lr=1e-2)
         optim = torch.optim.Adam(module.parameters(), lr=1e-2, foreach=True)
@@ -132,6 +140,7 @@ def _test_transformer_parity(
             local_inp,
             precompute,
             scaling_type_w=scaling_type_w,
+            compile_transformer_block=compile_transformer_block,
         )

     @skip_if_lt_x_gpu(2)
```
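As a usage note, the per-block pattern the test exercises can be written standalone roughly as below. This is a sketch, not library code: it assumes a model whose transformer blocks live under `model.layers` (a ModuleDict/ModuleList, as in the test's Transformer) and that `fully_shard` is importable from `torch.distributed._composable.fsdp`, the FSDP2 entry point these tests rely on.

```python
import torch
import torch.nn as nn

# Assumed FSDP2 import path (per-module fully_shard, not the FSDP1 wrapper).
from torch.distributed._composable.fsdp import fully_shard


def compile_and_shard(model: nn.Module, compile_transformer_block: bool) -> None:
    # Mirror the test: optionally compile each transformer block, shard it as
    # its own FSDP2 unit, and re-register it so the compiled wrapper replaces
    # the original child module.
    for layer_id, block in model.layers.named_children():
        if compile_transformer_block:
            block = torch.compile(block, dynamic=False)
        fully_shard(block)
        model.layers.register_module(layer_id, block)
    # Root wrap so parameters outside the blocks are also sharded.
    fully_shard(model)
```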

test/test_fsdp2/test_fsdp2_common.py

Lines changed: 6 additions & 2 deletions
```diff
@@ -6,7 +6,7 @@
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from float8_experimental.float8_linear import Float8Linear, TensorScalingType
+from float8_experimental.float8_linear import TensorScalingType
 from float8_experimental.float8_linear_utils import (
     linear_requires_sync,
     sync_float8_amax_and_scale_history,
@@ -23,6 +23,7 @@ def check_parity_no_mp(
     local_inp: torch.Tensor,
     precompute: bool = False,
     scaling_type_w: TensorScalingType = TensorScalingType.DYNAMIC,
+    compile_transformer_block: bool = False,
 ):
     for iter_idx in range(10):
         losses: List[torch.Tensor] = []
@@ -46,7 +47,10 @@ def check_parity_no_mp(
             ):
                 precompute_float8_dynamic_scale_for_fsdp(model)

-        test_cls.assertEqual(losses[0], losses[1])
+        if compile_transformer_block:
+            test_cls.assertEqual(losses[0], losses[1], atol=1e-4, rtol=1e-4)
+        else:
+            test_cls.assertEqual(losses[0], losses[1])


 def check_parity_bf16_mp(
```
