This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit ed1693e

vkuzo authored and facebook-github-bot committed
rename DelayedScalingRecipe to DelayedScalingConfig (#333)
Summary:
Pull Request resolved: #333

1. rename `DelayedScalingRecipe` to `DelayedScalingConfig`
2. move it to `config.py` and make it user facing

Reviewed By: weifengpy

Differential Revision: D60252067

fbshipit-source-id: ec233df1e0d03fdc649a19de1722ee45d5029aa6
1 parent eff4ba6 commit ed1693e

File tree

4 files changed: +38, -31 lines changed

float8_experimental/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 # Lets define a few top level things here
 from float8_experimental.config import (
+    DelayedScalingConfig,
     Float8LinearConfig,
     Float8TensorCastConfig,
     TensorScalingType,
@@ -30,6 +31,7 @@

 __all__ = [
     # configuration
+    "DelayedScalingConfig",
     "TensorScalingType",
     "Float8LinearConfig",
     "Float8TensorCastConfig",

float8_experimental/config.py

Lines changed: 31 additions & 0 deletions
@@ -29,6 +29,30 @@ class Float8TensorCastConfig:
     scaling_type: TensorScalingType = TensorScalingType.DYNAMIC


+@dataclass(frozen=True)
+class DelayedScalingConfig:
+    """
+    Configuration for delayed scaling.
+
+    Note: for now, `history_len` values must be the same for all layers in the
+    model using delayed scaling.
+
+    TODO(future): serialization for recipes
+    """
+
+    # Controls the history length of amax buffers
+    history_len: int = 16
+
+    # Controls the way to calculate current scale from amax history
+    # TODO(future): add other functions as needed, hardcoded or user defined
+    scale_fn_name: str = "max"
+
+    def __post_init__(self):
+        assert (
+            self.scale_fn_name == "max"
+        ), f"{self.scale_fn_name} is not implemented yet. Only max is supported for now."
+
+
 @dataclass(frozen=True)
 class Float8LinearConfig:
     """
@@ -71,6 +95,13 @@ class Float8LinearConfig:
     # If True, emulation is used instead of hardware accelerated gemm
     emulate: bool = False

+    # Configuration for delayed scaling
+    # Note: this is actually applied per-tensor, but only using the same
+    # configuration for all tensors and layers in the model is currently
+    # supported. If in the future we add support for a more fine grained
+    # configuration, this field may move to per-tensor configs.
+    delayed_scaling_config: DelayedScalingConfig = DelayedScalingConfig()
+

     # If True, use 'fnuz' float8 types for calculations.
     # Currently, ROCm only supports fnuz variants.
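
Because both dataclasses above are frozen and `DelayedScalingConfig.__post_init__` validates `scale_fn_name`, the delayed scaling settings are fixed and checked at construction time. A short usage sketch: the default values come straight from the fields added above, while the assumption that `Float8LinearConfig` can be built from defaults plus this one keyword is not shown in this diff:

    from float8_experimental import DelayedScalingConfig, Float8LinearConfig

    # Defaults match the dataclass fields added in this commit.
    delayed = DelayedScalingConfig()
    assert delayed.history_len == 16
    assert delayed.scale_fn_name == "max"

    # Only "max" is implemented, so any other name trips the __post_init__ assert.
    try:
        DelayedScalingConfig(scale_fn_name="mean")
    except AssertionError as err:
        print(err)  # "mean is not implemented yet. Only max is supported for now."

    # Wiring a longer amax history into the linear config (assumed constructor usage).
    config = Float8LinearConfig(
        delayed_scaling_config=DelayedScalingConfig(history_len=64),
    )
    assert config.delayed_scaling_config.history_len == 64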

float8_experimental/float8_linear.py

Lines changed: 4 additions & 30 deletions
@@ -131,23 +131,6 @@ def backward(ctx, go):
         return res, *empty_grads


-@dataclasses.dataclass
-class DelayedScalingRecipe:
-    # Controls the history length of amax buffers
-    history_len: int
-
-    # Controls the way to calculate current scale from amax history
-    # TODO(future): add other functions as needed, hardcoded or user defined
-    scale_fn_name: str
-
-    def __init__(self, history_len: int = 16, scale_fn_name: str = "max"):
-        self.history_len = history_len
-        self.scale_fn_name = scale_fn_name
-        assert (
-            self.scale_fn_name == "max"
-        ), f"{self.scale_fn_name} is not implemented yet. Only max is supported for now."
-
-
 class Float8Linear(torch.nn.Linear):
     """
     Note: this is **not** a public API and is only intended to be used
@@ -161,13 +144,9 @@ class Float8Linear(torch.nn.Linear):
     def __init__(self, *args, **kwargs):
         """
         Additional arguments on top of `torch.nn.Linear`'s arguments:
-        * `delayed_scaling_recipe`: configuration for delayed scaling
         * `config`: Float8LinearConfig
         """

-        delayed_scaling_recipe = kwargs.pop(
-            "delayed_scaling_recipe", DelayedScalingRecipe()
-        )
         # Amax scales should always be kept as float32.
         self.always_float32_buffers = set()
         config = kwargs.pop("config")
@@ -187,11 +166,6 @@ def __init__(self, *args, **kwargs):

         self.config = config

-        # TODO(future): have a unique recipe per buffer instead of one per
-        # module, saving implementing that until we need it.
-        # TODO(future): serialization for recipes
-        self.recipe = delayed_scaling_recipe
-
         self.create_buffers()

         # TODO(future): user level configuration of gemms
@@ -237,7 +211,7 @@ def __init__(self, *args, **kwargs):

     def create_buffers(self):
         # Default values for history buffers, see above TODO
-        history_len = self.recipe.history_len
+        history_len = self.config.delayed_scaling_config.history_len
         device = self.weight.device
         # TODO(future PR): dtype values below don't have the other float8
         # flavors, fix it
@@ -307,7 +281,7 @@ def cast_x_to_float8(
             x = x.to(autocast_dtype)

         if self.scaling_type_input is TensorScalingType.DELAYED:
-            scale_fn_name = self.recipe.scale_fn_name
+            scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
             _maybe_initialize_amaxes_scales_for_float8_cast(
                 x,
                 self.fp8_amax_input,
@@ -338,7 +312,7 @@ def cast_w_to_float8(
         if isinstance(self.weight, Float8Tensor):  # cast by FSDP
             w_fp8 = self.weight
         else:
-            scale_fn_name = self.recipe.scale_fn_name
+            scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
             _maybe_initialize_amaxes_scales_for_float8_cast(
                 w,
                 self.fp8_amax_weight,
@@ -370,7 +344,7 @@ def cast_w_to_float8(

     def cast_y_to_float8_in_bw(self, y: torch.Tensor) -> torch.Tensor:
         if self.scaling_type_grad_output is TensorScalingType.DELAYED:
-            scale_fn_name = self.recipe.scale_fn_name
+            scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
             y = NoopFwToFloat8E5M2Bw.apply(
                 y,
                 self.fp8_amax_grad_output,
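
After this change, every delayed scaling knob that `Float8Linear` reads comes from `self.config.delayed_scaling_config` instead of the removed per-module `self.recipe`. For intuition about what `history_len` controls in `create_buffers`, here is a rough, self-contained sketch of sizing an amax history buffer; the helper name and single-buffer shape are illustrative assumptions, not the library's actual buffer layout:

    import torch

    def make_amax_history_buffer(history_len: int, device: torch.device) -> torch.Tensor:
        # One float32 slot per remembered amax value, per the "history length of
        # amax buffers" comment on DelayedScalingConfig.history_len.
        return torch.zeros(history_len, dtype=torch.float32, device=device)

    buf = make_amax_history_buffer(history_len=16, device=torch.device("cpu"))
    print(buf.shape)  # torch.Size([16])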

float8_experimental/float8_linear_utils.py

Lines changed: 1 addition & 1 deletion
@@ -237,7 +237,7 @@ def inner_func():
         fp8_grad_output_amax_history_stack[idx] = child.fp8_amax_history_grad_output

         x_dtypes.add(child.last_seen_input_dtype)
-        scale_fn_recipes.add(child.recipe.scale_fn_name)
+        scale_fn_recipes.add(child.config.delayed_scaling_config.scale_fn_name)

    # TODO This way to get the activation dtype is not ideal
    if len(x_dtypes) != 1:
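
The updated call site gathers each child module's `scale_fn_name` into `scale_fn_recipes`; together with the note on `DelayedScalingConfig` that all layers must share the same settings, the natural follow-up is a single-value check like the `len(x_dtypes) != 1` guard visible above. A hedged sketch of that pattern, with a plain list of configs standing in for iterating a model's `Float8Linear` children (constructor usage assumed as in the earlier sketch):

    from float8_experimental import DelayedScalingConfig, Float8LinearConfig

    # Stand-in for the per-child configs collected inside inner_func() above.
    child_configs = [
        Float8LinearConfig(delayed_scaling_config=DelayedScalingConfig()),
        Float8LinearConfig(delayed_scaling_config=DelayedScalingConfig(history_len=32)),
    ]

    scale_fn_recipes = set()
    for cfg in child_configs:
        scale_fn_recipes.add(cfg.delayed_scaling_config.scale_fn_name)

    # All layers are expected to use one scale function ("max" for now).
    assert len(scale_fn_recipes) == 1, f"expected one scale_fn_name, got {scale_fn_recipes}"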
