You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
{{ message }}
This repository was archived by the owner on Aug 7, 2024. It is now read-only.
Summary:
Pull Request resolved: #322
Before this PR, we had two top level filtering affordances on `swap_linear_with_float8_linear`: `skip_fqn_list` and `linear_layer_filter`:
```
def swap_linear_with_float8_linear(
...,
skip_fqn_list: Optional[List[str]] = None,
linear_layer_filter: Optional[Callable[[nn.Linear], bool]] = None,
)
```
This PR unifies them into a single filtering method which is aware of both the FQN as well as the module instance. This should be more future proof and allow for more fine grained filtering.
```
def swap_linear_with_float8_linear(
...,
layer_filter_fn: Optional[Callable[[str, nn.Module], bool]] = None,
)
```
Note that the `filter_out_small_unaligned_layers` function is removed from the public API, and users are encouraged to write their own, again to make this more future proof. I'm not opposed to adding good utility functions back at some point, but IMO we should finalize the main UX first.
Note that in the future (as outlined in the Meta-only UX discussion doc), we may change this function from being just a filter to also providing the float8 per-module config. I'm saving that for a separate PR, which will also be BC breaking (but we don't have BC yet).
Reviewed By: y-sq
Differential Revision: D60072597
fbshipit-source-id: 0b9482459e0826e46bae42dabe6f4f40b805bde7
0 commit comments