This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 1d49dfd

[2/x] clean up casting functions: delayed scaling
Summary:

Removes delayed scaling from `float8_tensor.py`. After this PR, the invariant is that everything in `float8_tensor.py` requires the scale to be calculated elsewhere. This moves the codebase towards a separation of concerns: calculating the scale (via the various scaling strategies) is now decoupled from creating an instance of `Float8Tensor`. Stateful delayed scaling is the reason we need this separation.

Test Plan:

```
./test/test_everything.sh
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 39e6938
Pull Request resolved: #340
1 parent 96162b3 commit 1d49dfd
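
For orientation, the call pattern this commit moves toward looks roughly like the sketch below. It is pieced together from the diffs in this commit; the import paths for the helpers and the exact argument order (in particular where `amax_buffer` sits in `cast_to_float8_delayed`) are assumptions, not the verbatim API.

```python
# Hypothetical usage sketch based on this commit's diffs; import paths and
# exact argument order are assumptions, not the verbatim float8_experimental API.
import torch

from float8_experimental.float8_scaling_utils import cast_to_float8_delayed
from float8_experimental.float8_tensor import GemmInputRole, ToFloat8ConstrFunc
from float8_experimental.float8_utils import e4m3_dtype, tensor_to_scale

x = torch.randn(16, 16, dtype=torch.bfloat16)

# Invariant after this PR: the scale is always computed outside float8_tensor.py ...
x_scale = tensor_to_scale(x, e4m3_dtype).float()

# ... and ToFloat8ConstrFunc only consumes it; it no longer touches amax buffers.
x_fp8 = ToFloat8ConstrFunc.apply(
    x,
    x_scale,
    e4m3_dtype,
    None,  # linear_mm_config
    GemmInputRole.INPUT,
)

# Delayed scaling keeps its state (the amax buffer) in the scaling helper instead.
amax_buffer = torch.zeros(1)
x_fp8_delayed = cast_to_float8_delayed(x, x_scale, e4m3_dtype, amax_buffer)
```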

File tree

8 files changed (+13, -32 lines):

benchmarks/bench_padding.py
float8_experimental/float8_scaling_utils.py
float8_experimental/float8_tensor.py
float8_experimental/fsdp_utils.py
float8_experimental/inference.py
test/test_base.py
test/test_compile.py
test/test_dtensor.py


benchmarks/bench_padding.py

Lines changed: 0 additions & 2 deletions
```diff
@@ -62,15 +62,13 @@ def do_fp8_matmul(A, B, fp8_dtype, out_dtype):
         A,
         scale_a,
         fp8_dtype,
-        None,  # amax_buffer
         a_config,
         GemmInputRole.INPUT,
     )
     b_fp8 = ToFloat8ConstrFunc.apply(
         B,
         scale_b,
         fp8_dtype,
-        None,  # amax_buffer
         b_config,
         GemmInputRole.WEIGHT,
     )
```

float8_experimental/float8_scaling_utils.py

Lines changed: 1 addition & 2 deletions
```diff
@@ -44,7 +44,6 @@ def cast_to_float8_e4m3_dynamic(
         inpt_tensor,
         scale,
         e4m3_dtype,
-        None,  # amax_buffer
         linear_mm_config,
         gemm_input_role,
     )
@@ -59,11 +58,11 @@ def cast_to_float8_delayed(
     linear_mm_config: Optional[LinearMMConfig] = None,
     gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT,
 ):
+    amax_buffer.fill_(tensor_to_amax(tensor))
     return ToFloat8ConstrFunc.apply(
         tensor,
         scale,
         float8_dtype,
-        amax_buffer,
         linear_mm_config,
         gemm_input_role,
     )
```
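
After this hunk, delayed scaling's stateful amax update lives in the scaling helper rather than in `ToFloat8ConstrFunc`. Reconstructed from the hunk above, the function now reads roughly as the sketch below; the full parameter list, annotations, and import paths are assumptions.

```python
# Sketch of cast_to_float8_delayed after this change, reconstructed from the
# hunk above; the signature, defaults, and import paths are assumptions.
from typing import Optional

import torch

from float8_experimental.float8_tensor import (
    GemmInputRole,
    LinearMMConfig,
    ToFloat8ConstrFunc,
)
from float8_experimental.float8_utils import tensor_to_amax


def cast_to_float8_delayed(
    tensor: torch.Tensor,
    scale: torch.Tensor,
    float8_dtype: torch.dtype,
    amax_buffer: torch.Tensor,
    linear_mm_config: Optional[LinearMMConfig] = None,
    gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT,
):
    # Delayed scaling is stateful: record the current amax in the caller-owned
    # buffer here, then build the Float8Tensor from the already-computed scale.
    amax_buffer.fill_(tensor_to_amax(tensor))
    return ToFloat8ConstrFunc.apply(
        tensor,
        scale,
        float8_dtype,
        linear_mm_config,
        gemm_input_role,
    )
```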

float8_experimental/float8_tensor.py

Lines changed: 0 additions & 4 deletions
```diff
@@ -207,7 +207,6 @@ def forward(
         tensor: torch.Tensor,
         scale: torch.Tensor,
         float8_dtype=e4m3_dtype,
-        amax_buffer: Optional[torch.Tensor] = None,
         linear_mm_config: Optional[LinearMMConfig] = None,
         gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT,
     ):
@@ -216,11 +215,8 @@ def forward(
             tensor: the tensor to convert
             scale: the scale to use to convert the tensor
             float8_dtype: the float8 dtype either, torch.float8_e4m3fn or torch.float8_e5m2fn
-            amax_buffer: an Optional buffer buffer to store the amax value in prior to conversion
             emulate: whether to emulate the matmuls in fp32
         """
-        if amax_buffer is not None:
-            amax_buffer.fill_(tensor_to_amax(tensor))

         return to_fp8_no_autograd(
             tensor,
```

float8_experimental/fsdp_utils.py

Lines changed: 5 additions & 8 deletions
```diff
@@ -10,7 +10,10 @@
 import torch
 import torch.nn as nn
 import torch.utils._pytree as pytree
-from float8_experimental.float8_scaling_utils import cast_to_float8_e4m3_dynamic
+from float8_experimental.float8_scaling_utils import (
+    cast_to_float8_delayed,
+    cast_to_float8_e4m3_dynamic,
+)

 from float8_experimental.float8_tensor import (
     Float8Tensor,
@@ -168,7 +171,6 @@ def fsdp_pre_all_gather(self, mesh):
                 self._tensor,
                 self._precomputed_scale,
                 torch.float8_e4m3fn,
-                None,  # amax_buffer
                 self._linear_mm_config,
                 GemmInputRole.WEIGHT,
             )
@@ -352,12 +354,7 @@ def fsdp_pre_all_gather(self, mesh):
         )
         self.is_amax_initialized = True

-        # this will:
-        # 1. cast the tensor to float8 using `_scale_buffer`
-        # 2. populate `_amax_buffer` inplace
-        # TODO(future PR): clean up all the casting functions and clearly
-        # separate dynamic vs delayed, tech debt has accumulated
-        float8_tensor = ToFloat8ConstrFunc.apply(
+        float8_tensor = cast_to_float8_delayed(
             self._tensor,
             self._scale_buffer,
             e4m3_dtype,
```

float8_experimental/inference.py

Lines changed: 0 additions & 2 deletions
```diff
@@ -131,7 +131,6 @@ def quantize_weight(self, dtype: torch.dtype = e4m3_dtype) -> None:
             self.weight,
             scale,
             dtype,
-            None,  # amax_buffer
             self.linear_mm_config,
             GemmInputRole.WEIGHT,
         )
@@ -205,7 +204,6 @@ def cast_to_float8_e4m3_inference(
         inpt_tensor,
         scale,
         e4m3_dtype,
-        None,  # amax_buffer
         linear_mm_config,
         GemmInputRole.INPUT,
     )
```

test/test_base.py

Lines changed: 2 additions & 8 deletions
```diff
@@ -451,15 +451,13 @@ def test_different_configs_error(self):
             x_fp32,
             x_scale,
             fp8_dtype,
-            None,  # amax_buffer
             linear_config_a,
             GemmInputRole.INPUT,
         )
         b = ToFloat8ConstrFunc.apply(
             x_fp32,
             x_scale,
             fp8_dtype,
-            None,  # amax_buffer
             linear_config_b,
             GemmInputRole.WEIGHT,
         )
@@ -489,10 +487,10 @@ def test_pad_inner_dim(self, base_dtype, use_fast_accum):
         b_scale = tensor_to_scale(b, input_dtype).float()

         a_fp8 = ToFloat8ConstrFunc.apply(
-            a, a_scale, input_dtype, None, None, GemmInputRole.INPUT
+            a, a_scale, input_dtype, None, GemmInputRole.INPUT
         )
         b_fp8 = ToFloat8ConstrFunc.apply(
-            b, b_scale, input_dtype, None, None, GemmInputRole.WEIGHT
+            b, b_scale, input_dtype, None, GemmInputRole.WEIGHT
         )

         with pytest.raises(
@@ -512,15 +510,13 @@ def test_pad_inner_dim(self, base_dtype, use_fast_accum):
             a,
             a_scale,
             input_dtype,
-            None,  # amax_buffer
             pad_config,
             GemmInputRole.INPUT,
         )
         b_fp8 = ToFloat8ConstrFunc.apply(
             b,
             b_scale,
             input_dtype,
-            None,  # amax_buffer
             pad_config,
             GemmInputRole.WEIGHT,
         )
@@ -537,15 +533,13 @@ def test_pad_inner_dim(self, base_dtype, use_fast_accum):
             a,
             a_scale,
             input_dtype,
-            None,  # amax_buffer
             emulated_config,
             GemmInputRole.INPUT,
         )
         b_fp8 = ToFloat8ConstrFunc.apply(
             b,
             b_scale,
             input_dtype,
-            None,  # amax_buffer
             emulated_config,
             GemmInputRole.WEIGHT,
         )
```

test/test_compile.py

Lines changed: 3 additions & 2 deletions
```diff
@@ -20,8 +20,9 @@
     get_float8_layers,
     sync_float8_amax_and_scale_history,
 )
-from float8_experimental.float8_tensor import LinearMMConfig, ToFloat8ConstrFunc
+from float8_experimental.float8_tensor import LinearMMConfig
 from float8_experimental.float8_utils import e4m3_dtype
+from float8_experimental.float8_scaling_utils import cast_to_float8_delayed

 from torch._dynamo.test_case import TestCase as DynamoTestCase
 from torch._dynamo.testing import CompileCounterWithBackend
@@ -178,7 +179,7 @@ def __init__(self, graph_break: bool):
         self.graph_break = graph_break

     def forward(self, x):
-        x_fp8 = ToFloat8ConstrFunc.apply(
+        x_fp8 = cast_to_float8_delayed(
             x,
             self.fp8_scale_x,
             e4m3_dtype,
```

test/test_dtensor.py

Lines changed: 2 additions & 4 deletions
```diff
@@ -88,10 +88,10 @@ def test_scaled_mm(mesh: DeviceMesh, size=16):
     y_scale = tensor_to_scale(y_fp32, fp8_dtype).float()

     x_fp8 = ToFloat8ConstrFunc.apply(
-        x_fp32, x_scale, fp8_dtype, None, None, GemmInputRole.INPUT
+        x_fp32, x_scale, fp8_dtype, None, GemmInputRole.INPUT
     )
     y_fp8 = ToFloat8ConstrFunc.apply(
-        y_fp32, y_scale, fp8_dtype, None, None, GemmInputRole.WEIGHT
+        y_fp32, y_scale, fp8_dtype, None, GemmInputRole.WEIGHT
     )

     dist_x_fp8 = DTensor.from_local(x_fp8, mesh, [lhs_placement], run_check=False)
@@ -169,15 +169,13 @@ def test_dtensor_fp8_autograd(mesh: DeviceMesh, size=16):
         dist_x_scale,
         fp8_dtype,
         None,
-        None,
         GemmInputRole.INPUT,
     )
     dist_weight_fp8 = ToFloat8ConstrFunc.apply(
         dist_wight_fp32,
         dist_weight_scale,
         fp8_dtype,
         None,
-        None,
         GemmInputRole.WEIGHT,
     )

```
