This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 9d5f892

vkuzo authored and facebook-github-bot committed
unify filtering functions (#322)
Summary:
Pull Request resolved: #322

Before this PR, we had two top level filtering affordances on `swap_linear_with_float8_linear`: `skip_fqn_list` and `linear_layer_filter`:

```
def swap_linear_with_float8_linear(
    ...,
    skip_fqn_list: Optional[List[str]] = None,
    linear_layer_filter: Optional[Callable[[nn.Linear], bool]] = None,
)
```

This PR unifies them into a single filtering method which is aware of both the FQN and the module instance. This should be more future proof and allow for more fine-grained filtering.

```
def swap_linear_with_float8_linear(
    ...,
    layer_filter_fn: Optional[Callable[[str, nn.Module], bool]] = None,
)
```

Note that the `filter_out_small_unaligned_layers` function is removed from the public API, and users are encouraged to write their own, again to make this more future proof. I'm not opposed to adding good utility functions back at some point, but IMO we should finalize the main UX first.

Note that in the future (as outlined in the Meta-only UX discussion doc), we may change this function from being just a filter to also providing the float8 per-module config. I'm saving that for a separate PR, which will also be BC-breaking (but we don't have BC yet).

Reviewed By: y-sq

Differential Revision: D60072597

fbshipit-source-id: 0b9482459e0826e46bae42dabe6f4f40b805bde7
1 parent c58fb5d commit 9d5f892
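Since `filter_out_small_unaligned_layers` is no longer exported, a caller who wants the old behavior can express it with the new unified filter. Below is a minimal sketch, assuming the post-PR API; the `size_limit` value and the toy `model` are illustrative, not part of the library:

```python
import torch.nn as nn

from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear

size_limit = 32  # illustrative threshold, mirrors the removed helper's argument

def module_filter_fn(fqn: str, mod: nn.Module) -> bool:
    # Keep only Linear layers that are large enough and have dimensions divisible
    # by 16, which is what filter_out_small_unaligned_layers(size_limit) used to do.
    return (
        isinstance(mod, nn.Linear)
        and mod.in_features >= size_limit
        and mod.out_features >= size_limit
        and mod.in_features % 16 == 0
        and mod.out_features % 16 == 0
    )

# Toy model for illustration; layer "1" fails the size check and stays nn.Linear.
model = nn.Sequential(nn.Linear(32, 64), nn.Linear(64, 16))
model = swap_linear_with_float8_linear(model, module_filter_fn=module_filter_fn)
```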

File tree: 4 files changed (+88, −83 lines)


README.md

Lines changed: 12 additions & 1 deletion
@@ -43,8 +43,19 @@ from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_f
 # create model
 m = Model(...)
 
+# optional: filter modules from being eligible for float8 conversion
+def module_filter_fn(fqn: str, mod: torch.nn.Module):
+    # don't convert the output module
+    if fqn == "output":
+        return False
+    # don't convert linear modules with weight dimensions not divisible by 16
+    if isinstance(mod, torch.nn.Linear):
+        if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
+            return False
+    return True
+
 # convert all `torch.nn.Linear` modules to `Float8Linear`
-swap_linear_with_float8_linear(m)
+swap_linear_with_float8_linear(m, module_filter_fn=module_filter_fn)
 
 # optional: use FSDP
 model = FSDP(model, use_orig_params=True)

float8_experimental/float8_linear_utils.py

Lines changed: 28 additions & 46 deletions
@@ -59,26 +59,11 @@ def _update_history_stack(
     amax_history_stack.copy_(new_amax_history_stack)
 
 
-def filter_out_small_unaligned_layers(size_limit: int) -> Callable[[nn.Linear], bool]:
-    """
-    Returns a callable that filters out small (dimensions less than the given `size_limit`)
-    and unaligned (dimenstions not divisible by 16) layers.
-    It can be passed as the `linear_layer_filter` argument to `swap_linear_with_float8_linear`.
-    """
-    return (
-        lambda linear_layer: linear_layer.in_features >= size_limit
-        and linear_layer.out_features >= size_limit
-        and linear_layer.in_features % 16 == 0
-        and linear_layer.out_features % 16 == 0
-    )
-
-
 def swap_linear_layers(
     module: nn.Module,
     from_float_func: Callable[[nn.Linear], nn.Linear],
     *,
-    skip_fqn_list: Optional[List[str]] = None,
-    linear_layer_filter: Optional[Callable[[nn.Linear], bool]] = None,
+    module_filter_fn: Optional[Callable[[str, nn.Module], bool]] = None,
 ) -> Optional[nn.Module]:
     """
     Generic function to swap linear layers in a module with a new type of linear layer.
@@ -90,18 +75,15 @@ def swap_linear_layers(
     Args:
         module: Module to modify.
         from_float_func: Function that accepts a linear layer and returns a new type of linear layer.
-        skip_fqn_list: If specified, a list of module FQNs to skip.
-        linear_layer_filter: If specified, only the linear layers
-            that pass the filter function will be swapped.
-        from_float_kwargs: Additional keyword arguments for from_float_func.
+        module_filter_fn: If specified, only the `torch.nn.Linear` subclasses that
+            that pass the filter function will be swapped. The inputs to the
+            filter function are the FQN and module instance.
 
     Returns:
         nn.Module: The modified module with swapped linear layers.
     """
-    module_names_to_skip = set(skip_fqn_list or [])
-
     if isinstance(module, nn.Linear) and (
-        linear_layer_filter is None or linear_layer_filter(module)
+        module_filter_fn is None or module_filter_fn("", module)
     ):
         if len(list(module.children())) > 0:
             raise AssertionError(
@@ -112,43 +94,44 @@ def swap_linear_layers(
             )
 
     root_module = module
-    visited_modules = {root_module}
-
-    for module_name, module in root_module.named_modules():
-        if module_name in module_names_to_skip:
-            visited_modules.add(module)
 
     def post_order_traversal(
-        module: nn.Module, module_name: str, parent_module: Optional[nn.Module]
+        module: nn.Module,
+        cur_fqn: Optional[str] = None,
+        parent_module: Optional[nn.Module] = None,
     ):
-        nonlocal visited_modules
+        if cur_fqn is None:
+            cur_fqn = ""
+
         for child_module_name, child_module in module.named_children():
-            if child_module not in visited_modules:
-                visited_modules.add(child_module)
-                post_order_traversal(child_module, child_module_name, module)
+            if cur_fqn == "":
+                new_fqn = child_module_name
+            else:
+                new_fqn = f"{cur_fqn}.{child_module_name}"
+
+            post_order_traversal(child_module, new_fqn, module)
 
         if isinstance(module, nn.Linear) and (
-            linear_layer_filter is None or linear_layer_filter(module)
+            # linear_layer_filter is None or linear_layer_filter(module)
+            module_filter_fn is None
+            or module_filter_fn(cur_fqn, module)
         ):
             assert (
                 parent_module is not None
            ), f"Linear root module should return early: {module}"
             new_linear_module = from_float_func(module)
-            setattr(parent_module, module_name, new_linear_module)
+            cur_module_name = cur_fqn.split(".")[-1]
+            setattr(parent_module, cur_module_name, new_linear_module)
 
-    post_order_traversal(root_module, "", None)
-    # Without this explicit `del`, this set only gets deleted upon an explicit
-    # garbage collection (not from when its refcount hits zero)
-    del visited_modules
+    post_order_traversal(root_module)
     return root_module
 
 
 def swap_linear_with_float8_linear(
     module: nn.Module,
     *,
-    skip_fqn_list: Optional[List[str]] = None,
     emulate: bool = False,
-    linear_layer_filter: Optional[Callable[[nn.Linear], bool]] = None,
+    module_filter_fn: Optional[Callable[[str, nn.Module], bool]] = None,
     scaling_type_x: TensorScalingType = TensorScalingType.DYNAMIC,
     scaling_type_w: TensorScalingType = TensorScalingType.DYNAMIC,
     scaling_type_dL_dY: TensorScalingType = TensorScalingType.DYNAMIC,
@@ -158,10 +141,10 @@ def swap_linear_with_float8_linear(
 
     Args:
         module: Module to modify.
-        skip_fqn_list: If specified, a list of module FQNs to skip.
         emulate: If True, emulation is used instead of hardware accelerated gemm
-        linear_layer_filter: If specified, only the linear layers
-            that pass the filter function will be swapped.
+        module_filter_fn: If specified, only the `torch.nn.Linear` subclasses that
+            that pass the filter function will be swapped. The inputs to the
+            filter function are the FQN and module instance.
         scaling_type_x (TensorScalingType): scaling type for `x`
         scaling_type_w (TensorScalingType): scaling type for `w`
         scaling_type_dL_dY (TensorScalingType): scaling type for `dL_dY`
@@ -179,8 +162,7 @@ def swap_linear_with_float8_linear(
     return swap_linear_layers(
         module,
         from_float,
-        skip_fqn_list=skip_fqn_list,
-        linear_layer_filter=linear_layer_filter,
+        module_filter_fn=module_filter_fn,
     )
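For reference, the FQNs passed to `module_filter_fn` by the traversal above are dot-separated child names relative to the root module. A minimal standalone sketch that reproduces that naming (the function name `collect_linear_fqns` and the toy model are illustrative, not part of the library):

```python
import torch.nn as nn

def collect_linear_fqns(module: nn.Module, cur_fqn: str = "") -> list:
    # Build dot-separated FQNs the same way the post-order traversal above does.
    fqns = []
    for child_name, child in module.named_children():
        new_fqn = child_name if cur_fqn == "" else f"{cur_fqn}.{child_name}"
        fqns.extend(collect_linear_fqns(child, new_fqn))
    if isinstance(module, nn.Linear):
        # This is the fqn a module_filter_fn would see for this Linear instance.
        fqns.append(cur_fqn)
    return fqns

m = nn.Sequential(nn.Linear(4, 4), nn.Sequential(nn.Linear(4, 4)))
print(collect_linear_fqns(m))  # ['0', '1.0']
```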

float8_experimental/inference.py

Lines changed: 6 additions & 4 deletions
@@ -10,7 +10,7 @@
 from dataclasses import dataclass
 
 from enum import auto, Enum
-from typing import List, Optional
+from typing import Callable, List, Optional
 
 import float8_experimental.config as config
 
@@ -209,7 +209,7 @@ def quantize_to_float8(
     module: nn.Module,
     quant_config: QuantConfig,
     *,
-    skip_fqn_list: Optional[List[str]] = None,
+    module_filter_fn: Optional[Callable[[str, nn.Module], bool]] = None,
     use_fast_accum: bool = True,
 ) -> Optional[nn.Module]:
     """
@@ -222,7 +222,9 @@ def quantize_to_float8(
     Args:
         module (nn.Module): The module to modify.
         quant_config (QuantConfig): Quantization configuration for Float8 conversion.
-        skip_fqn_list (List[str], optional): List of module FQNs to skip during conversion.
+        module_filter_fn: If specified, only the `torch.nn.Linear` subclasses that
+            that pass the filter function will be swapped. The inputs to the
+            filter function are the FQN and module instance.
         use_fast_accum : Whether to enable fast accumulation for the Float8InferenceLinear. Defaults to True.
 
     Returns:
@@ -234,5 +236,5 @@ def quantize_to_float8(
     return swap_linear_layers(
         module,
         lambda m: Float8InferenceLinear.from_float(m, quant_config, use_fast_accum),
-        skip_fqn_list=skip_fqn_list,
+        module_filter_fn=module_filter_fn,
     )
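A hedged usage sketch of the updated `quantize_to_float8` signature. The `QuantConfig(ActivationCasting.DYNAMIC)` construction, the toy model, and the choice to skip the layer with FQN "1" are assumptions for illustration, not prescribed by this PR:

```python
import torch.nn as nn

from float8_experimental.inference import (
    ActivationCasting,
    QuantConfig,
    quantize_to_float8,
)

model = nn.Sequential(nn.Linear(32, 32), nn.Linear(32, 32))

# Assumed config: dynamic activation casting; see inference.py for the actual fields.
quant_config = QuantConfig(ActivationCasting.DYNAMIC)

# Skip the second layer by FQN; every other nn.Linear module stays eligible.
module_filter_fn = lambda fqn, mod: fqn != "1"

model = quantize_to_float8(model, quant_config, module_filter_fn=module_filter_fn)
```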

test/test_base.py

Lines changed: 42 additions & 32 deletions
@@ -18,7 +18,6 @@
 
 from float8_experimental.float8_linear import Float8Linear, TensorScalingType
 from float8_experimental.float8_linear_utils import (
-    filter_out_small_unaligned_layers,
     linear_requires_sync,
     swap_linear_with_float8_linear,
     sync_float8_amax_and_scale_history,
@@ -631,24 +630,34 @@ def __init__(self, dim: int):
                 self.lin1 = nn.Linear(dim, 4 * dim)
                 self.lin2 = nn.Linear(4 * dim, 4 * dim)
 
-        for emulate in [True, False]:
-            model = nn.Sequential(MLP(8), nn.Linear(32, 32), MLP(40))
-            # filter out the linear layers whose shape is smaller than 32 or non-divisible by 16.
-            model = swap_linear_with_float8_linear(
-                model,
-                emulate=emulate,
-                linear_layer_filter=filter_out_small_unaligned_layers(32),
+        model = nn.Sequential(MLP(8), nn.Linear(32, 32), MLP(40))
+        # filter out the linear layers whose shape is smaller than 32 or non-divisible by 16.
+
+        size_limit = 32
+
+        def module_filter_fn(fqn, mod):
+            return (
+                mod.in_features >= size_limit
+                and mod.out_features >= size_limit
+                and mod.in_features % 16 == 0
+                and mod.out_features % 16 == 0
             )
-            # in_features=8, out_features=32, 8 is less than 32.
-            self.assertNotIsInstance(model[0].lin1, Float8Linear)
-            # in_features=32, out_features=32,
-            self.assertIsInstance(model[0].lin2, Float8Linear)
-            # in_features=32, out_features=32,
-            self.assertIsInstance(model[1], Float8Linear)
-            # in_features=40, out_features=160, 40 is not divisible by 16.
-            self.assertNotIsInstance(model[2].lin1, Float8Linear)
-            # in_features=160, out_features=160,
-            self.assertIsInstance(model[2].lin2, Float8Linear)
+
+        model = swap_linear_with_float8_linear(
+            model,
+            emulate=True,
+            module_filter_fn=module_filter_fn,
+        )
+        # in_features=8, out_features=32, 8 is less than 32.
+        self.assertNotIsInstance(model[0].lin1, Float8Linear)
+        # in_features=32, out_features=32,
+        self.assertIsInstance(model[0].lin2, Float8Linear)
+        # in_features=32, out_features=32,
+        self.assertIsInstance(model[1], Float8Linear)
+        # in_features=40, out_features=160, 40 is not divisible by 16.
+        self.assertNotIsInstance(model[2].lin1, Float8Linear)
+        # in_features=160, out_features=160,
+        self.assertIsInstance(model[2].lin2, Float8Linear)
 
     def test_swap_submodule_linears_with_skip(self):
         class MLP(nn.Module):
@@ -657,20 +666,21 @@ def __init__(self, dim: int):
                 self.lin1 = nn.Linear(dim, 4 * dim)
                 self.lin2 = nn.Linear(4 * dim, dim)
 
-        for emulate in [True, False]:
-            model = nn.Sequential(MLP(3), nn.Linear(3, 3), MLP(3))
-            skip_fqn_list = ["2", "0.lin2"]
-            model = swap_linear_with_float8_linear(
-                model, emulate=emulate, skip_fqn_list=skip_fqn_list
-            )
-            self.assertIsInstance(model[0].lin1, Float8Linear)
-            self.assertNotIsInstance(model[0].lin2, Float8Linear)
-            self.assertIsInstance(model[0].lin2, nn.Linear)
-            self.assertIsInstance(model[1], Float8Linear)
-            self.assertNotIsInstance(model[2].lin2, Float8Linear)
-            self.assertNotIsInstance(model[2].lin2, Float8Linear)
-            self.assertIsInstance(model[2].lin1, nn.Linear)
-            self.assertIsInstance(model[2].lin2, nn.Linear)
+        model = nn.Sequential(MLP(3), nn.Linear(3, 3), MLP(3))
+        module_filter_fn = lambda fqn, mod: fqn not in [
+            "0.lin2",
+            "2.lin1",
+        ]
+        model = swap_linear_with_float8_linear(
+            model,
+            emulate=True,
+            module_filter_fn=module_filter_fn,
+        )
+        self.assertTrue(type(model[0].lin1) is Float8Linear)
+        self.assertTrue(type(model[0].lin2) is nn.Linear)
+        self.assertTrue(type(model[1]) is Float8Linear)
+        self.assertTrue(type(model[2].lin1) is nn.Linear)
+        self.assertTrue(type(model[2].lin2) is Float8Linear)
 
     def test_fp8_tensor_statistics(self):
         hp_dtypes = (torch.float32, torch.float16, torch.bfloat16)
