@@ -131,23 +131,6 @@ def backward(ctx, go):
         return res, *empty_grads
 
 
-@dataclasses.dataclass
-class DelayedScalingRecipe:
-    # Controls the history length of amax buffers
-    history_len: int
-
-    # Controls the way to calculate current scale from amax history
-    # TODO(future): add other functions as needed, hardcoded or user defined
-    scale_fn_name: str
-
-    def __init__(self, history_len: int = 16, scale_fn_name: str = "max"):
-        self.history_len = history_len
-        self.scale_fn_name = scale_fn_name
-        assert (
-            self.scale_fn_name == "max"
-        ), f"{self.scale_fn_name} is not implemented yet. Only max is supported for now."
-
-
 class Float8Linear(torch.nn.Linear):
     """
     Note: this is **not** a public API and is only intended to be used
@@ -161,13 +144,9 @@ class Float8Linear(torch.nn.Linear):
     def __init__(self, *args, **kwargs):
         """
         Additional arguments on top of `torch.nn.Linear`'s arguments:
-        * `delayed_scaling_recipe`: configuration for delayed scaling
         * `config`: Float8LinearConfig
         """
 
-        delayed_scaling_recipe = kwargs.pop(
-            "delayed_scaling_recipe", DelayedScalingRecipe()
-        )
         # Amax scales should always be kept as float32.
         self.always_float32_buffers = set()
         config = kwargs.pop("config")
@@ -187,11 +166,6 @@ def __init__(self, *args, **kwargs):
 
         self.config = config
 
-        # TODO(future): have a unique recipe per buffer instead of one per
-        # module, saving implementing that until we need it.
-        # TODO(future): serialization for recipes
-        self.recipe = delayed_scaling_recipe
-
         self.create_buffers()
 
         # TODO(future): user level configuration of gemms
@@ -237,7 +211,7 @@ def __init__(self, *args, **kwargs):
 
     def create_buffers(self):
         # Default values for history buffers, see above TODO
-        history_len = self.recipe.history_len
+        history_len = self.config.delayed_scaling_config.history_len
         device = self.weight.device
         # TODO(future PR): dtype values below don't have the other float8
         # flavors, fix it
@@ -307,7 +281,7 @@ def cast_x_to_float8(
             x = x.to(autocast_dtype)
 
         if self.scaling_type_input is TensorScalingType.DELAYED:
-            scale_fn_name = self.recipe.scale_fn_name
+            scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
             _maybe_initialize_amaxes_scales_for_float8_cast(
                 x,
                 self.fp8_amax_input,
@@ -338,7 +312,7 @@ def cast_w_to_float8(
         if isinstance(self.weight, Float8Tensor):  # cast by FSDP
             w_fp8 = self.weight
         else:
-            scale_fn_name = self.recipe.scale_fn_name
+            scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
             _maybe_initialize_amaxes_scales_for_float8_cast(
                 w,
                 self.fp8_amax_weight,
@@ -370,7 +344,7 @@ def cast_w_to_float8(
 
     def cast_y_to_float8_in_bw(self, y: torch.Tensor) -> torch.Tensor:
         if self.scaling_type_grad_output is TensorScalingType.DELAYED:
-            scale_fn_name = self.recipe.scale_fn_name
+            scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
             y = NoopFwToFloat8E5M2Bw.apply(
                 y,
                 self.fp8_amax_grad_output,
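
For context on the API change above, here is a minimal sketch of how the delayed-scaling settings are supplied after this diff: through the `Float8LinearConfig` object passed via the `config` kwarg, rather than a separate `delayed_scaling_recipe` kwarg. The `DelayedScalingConfig` class name, its defaults, and the stand-in definitions below are assumptions for illustration only; the diff itself only confirms that `Float8Linear` reads `config.delayed_scaling_config.history_len` and `config.delayed_scaling_config.scale_fn_name`.

# Hypothetical stand-ins, assuming a dataclass-style config; only the two
# fields read by Float8Linear (history_len, scale_fn_name) come from the diff.
import dataclasses

@dataclasses.dataclass
class DelayedScalingConfig:
    history_len: int = 16       # length of the amax history buffers
    scale_fn_name: str = "max"  # how the current scale is derived from the amax history

@dataclasses.dataclass
class Float8LinearConfig:
    delayed_scaling_config: DelayedScalingConfig = dataclasses.field(
        default_factory=DelayedScalingConfig
    )

# Usage: all delayed-scaling knobs now travel on the single `config` argument.
config = Float8LinearConfig(
    delayed_scaling_config=DelayedScalingConfig(history_len=32, scale_fn_name="max")
)
assert config.delayed_scaling_config.history_len == 32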